diff --git a/plume_python/cli.py b/plume_python/cli.py
index e883d44..de77f3c 100644
--- a/plume_python/cli.py
+++ b/plume_python/cli.py
@@ -1,6 +1,6 @@
 import os.path
 
-import click
+import click  # type: ignore
 
 from plume_python import parser
 from plume_python.export.xdf_exporter import export_xdf_from_record
@@ -36,17 +36,13 @@ def export_xdf(record_path: str, xdf_output_path: str | None):
         xdf_output_path = record_path.replace(".plm", ".xdf")
 
     if os.path.exists(xdf_output_path):
-        if not click.confirm(
-            f"File '{xdf_output_path}' already exists, do you want to overwrite it?"
-        ):
+        if not click.confirm(f"File '{xdf_output_path}' already exists, do you want to overwrite it?"):
             return
 
     with open(xdf_output_path, "wb") as xdf_output_file:
         record = parser.parse_record_from_file(record_path)
         export_xdf_from_record(xdf_output_file, record)
-        click.echo(
-            "Exported xdf from record: " + record_path + " to " + xdf_output_path
-        )
+        click.echo("Exported xdf from record: " + record_path + " to " + xdf_output_path)
 
 
 @click.command()
@@ -134,24 +130,14 @@ def export_csv(record_path: str, output_dir: str | None, filter: str):
         for sample_type, df in dataframes.items():
             file_path = os.path.join(output_dir, sample_type.__name__ + ".csv")
             df.to_csv(file_path)
-            click.echo(
-                "Exported CSV for sample type: "
-                + sample_type.__name__
-                + " to "
-                + file_path
-            )
+            click.echo("Exported CSV for sample type: " + sample_type.__name__ + " to " + file_path)
     else:
         sample_types = sample_types_from_names(filters)
         for sample_type in sample_types:
             df = samples_to_dataframe(record.get_samples_by_type(sample_type))
             file_path = os.path.join(output_dir, sample_type.__name__ + ".csv")
             df.to_csv(file_path)
-            click.echo(
-                "Exported CSV for sample type: "
-                + sample_type.__name__
-                + " to "
-                + file_path
-            )
+            click.echo("Exported CSV for sample type: " + sample_type.__name__ + " to " + file_path)
 
 
 cli.add_command(export_csv)
diff --git a/plume_python/export/xdf_exporter.py b/plume_python/export/xdf_exporter.py
index 80d1ad9..1489203 100644
--- a/plume_python/export/xdf_exporter.py
+++ b/plume_python/export/xdf_exporter.py
@@ -15,12 +15,7 @@ def export_xdf_from_record(output_file: BinaryIO, record: Record):
-    datetime_str = (
-        record.get_metadata()
-        .start_time.ToDatetime()
-        .astimezone()
-        .strftime("%Y-%m-%dT%H:%M:%S%z")
-    )
+    datetime_str = record.get_metadata().start_time.ToDatetime().astimezone().strftime("%Y-%m-%dT%H:%M:%S%z")
     # Add a colon separator to the offset segment
     datetime_str = "{0}:{1}".format(datetime_str[:-2], datetime_str[-2:])
 
@@ -40,9 +35,7 @@ def export_xdf_from_record(output_file: BinaryIO, record: Record):
 
     for lsl_open_stream in record[lsl_stream_pb2.StreamOpen]:
         xml_header = ET.fromstring(lsl_open_stream.payload.xml_header)
-        stream_id = (
-            np.uint64(lsl_open_stream.payload.stream_id) + 1
-        )  # reserve id = 1 for the marker stream
+        stream_id = np.uint64(lsl_open_stream.payload.stream_id) + 1  # reserve id = 1 for the marker stream
         channel_format = xml_header.find("channel_format").text
         stream_channel_format[stream_id] = channel_format
         stream_min_time[stream_id] = None
@@ -66,33 +59,19 @@ def export_xdf_from_record(output_file: BinaryIO, record: Record):
             stream_sample_count[stream_id] += 1
 
             if channel_format == "string":
-                val = np.array(
-                    [x for x in lsl_sample.payload.string_values.value], dtype=np.str_
-                )
+                val = np.array([x for x in lsl_sample.payload.string_values.value], dtype=np.str_)
             elif channel_format == "int8":
-                val = np.array(
-                    [x for x in lsl_sample.payload.int8_values.value], dtype=np.int8
-                )
+                val = np.array([x for x in lsl_sample.payload.int8_values.value], dtype=np.int8)
             elif channel_format == "int16":
-                val = np.array(
-                    [x for x in lsl_sample.payload.int16_values.value], dtype=np.int16
-                )
+                val = np.array([x for x in lsl_sample.payload.int16_values.value], dtype=np.int16)
             elif channel_format == "int32":
-                val = np.array(
-                    [x for x in lsl_sample.payload.int32_values.value], dtype=np.int32
-                )
+                val = np.array([x for x in lsl_sample.payload.int32_values.value], dtype=np.int32)
             elif channel_format == "int64":
-                val = np.array(
-                    [x for x in lsl_sample.payload.int64_values.value], dtype=np.int64
-                )
+                val = np.array([x for x in lsl_sample.payload.int64_values.value], dtype=np.int64)
             elif channel_format == "float32":
-                val = np.array(
-                    [x for x in lsl_sample.payload.float_values.value], dtype=np.float32
-                )
+                val = np.array([x for x in lsl_sample.payload.float_values.value], dtype=np.float32)
             elif channel_format == "double64":
-                val = np.array(
-                    [x for x in lsl_sample.payload.double_values.value], dtype=np.float64
-                )
+                val = np.array([x for x in lsl_sample.payload.double_values.value], dtype=np.float64)
             else:
                 raise ValueError(f"Unsupported channel format: {channel_format}")
 
@@ -101,15 +80,9 @@ def export_xdf_from_record(output_file: BinaryIO, record: Record):
 
     for marker_sample in record[marker_pb2.Marker]:
         t = marker_sample.timestamp / 1_000_000_000.0  # convert time to seconds
-        if (
-            stream_min_time[marker_stream_id] is None
-            or t < stream_min_time[marker_stream_id]
-        ):
+        if stream_min_time[marker_stream_id] is None or t < stream_min_time[marker_stream_id]:
             stream_min_time[marker_stream_id] = t
-        if (
-            stream_max_time[marker_stream_id] is None
-            or t > stream_max_time[marker_stream_id]
-        ):
+        if stream_max_time[marker_stream_id] is None or t > stream_max_time[marker_stream_id]:
            stream_max_time[marker_stream_id] = t
 
         if marker_stream_id not in stream_sample_count:
@@ -155,6 +128,4 @@ def write_marker_stream_header(output_buf, marker_stream_id):
     channel_count_el.text = "1"
     nominal_srate_el.text = "0.0"
     xml = ET.tostring(info_el, encoding=STR_ENCODING, xml_declaration=True)
-    write_stream_header(
-        output_buf, xml, marker_stream_id
-    )  # stream_id = 1 is reserved for the marker stream
+    write_stream_header(output_buf, xml, marker_stream_id)  # stream_id = 1 is reserved for the marker stream
diff --git a/plume_python/export/xdf_writer.py b/plume_python/export/xdf_writer.py
index db5b855..9f49be6 100644
--- a/plume_python/export/xdf_writer.py
+++ b/plume_python/export/xdf_writer.py
@@ -54,9 +54,7 @@ def write_file_header(output: BinaryIO, version: str, datetime: str):
     write_chunk(output, ChunkTag.FILE_HEADER, xml_str)
 
 
-def write_chunk(
-    output: BinaryIO, chunk_tag: ChunkTag, content: bytes, stream_id: np.uint32 = None
-):
+def write_chunk(output: BinaryIO, chunk_tag: ChunkTag, content: bytes, stream_id: np.uint32 = None):
     if not isinstance(content, bytes):
         raise Exception("Content should be bytes.")
 
@@ -74,9 +72,7 @@ def write_chunk(
     write(output, content)
 
 
-def write_stream_header(
-    output: BinaryIO, xml_header: str | bytes, stream_id: np.uint32 = None
-):
+def write_stream_header(output: BinaryIO, xml_header: str | bytes, stream_id: np.uint32 = None):
     if isinstance(xml_header, str):
         xml_header = bytes(xml_header, encoding=STR_ENCODING)
 
@@ -210,25 +206,11 @@ def write_variable_length_integer(output: BinaryIO, val: np.uint64):
 
 
 def write_fixed_length_integer(
     output: BinaryIO,
-    val: np.int8
-    | np.int16
-    | np.int32
-    | np.int64
-    | np.uint8
-    | np.uint16
-    | np.uint32
-    | np.uint64,
+    val: np.int8 | np.int16 | np.int32 | np.int64 | np.uint8 | np.uint16 | np.uint32 | np.uint64,
 ):
     if not isinstance(
         val,
-        np.int8
-        | np.int16
-        | np.int32
-        | np.int64
-        | np.uint8
-        | np.uint16
-        | np.uint32
-        | np.uint64,
+        np.int8 | np.int16 | np.int32 | np.int64 | np.uint8 | np.uint16 | np.uint32 | np.uint64,
     ):
         raise Exception("Unsupported data type " + str(type(val)))
diff --git a/plume_python/parser.py b/plume_python/parser.py
index c4a50f3..9ae6db3 100644
--- a/plume_python/parser.py
+++ b/plume_python/parser.py
@@ -2,10 +2,10 @@
 from warnings import warn
 
 from tqdm import tqdm
-from delimited_protobuf import read as _read_delimited
-from google.protobuf import descriptor_pool as _descriptor_pool
-from google.protobuf.message import Message
-from google.protobuf.message_factory import GetMessageClass
+from delimited_protobuf import read as _read_delimited  # type: ignore
+from google.protobuf import descriptor_pool as _descriptor_pool  # type: ignore
+from google.protobuf.message import Message  # type: ignore
+from google.protobuf.message_factory import GetMessageClass  # type: ignore
 
 from . import file_reader
 from .record import Record, Sample, FrameDataSample, FrameInfo
@@ -15,9 +15,7 @@
 T = TypeVar("T", bound=Message)
 
 
-def unpack_any(
-    any_payload: Any, descriptor_pool: _descriptor_pool.DescriptorPool
-) -> Optional[Message]:
+def unpack_any(any_payload: Any, descriptor_pool: _descriptor_pool.DescriptorPool) -> Optional[Message]:
     """Unpacks an Any message into its original type using the provided descriptor pool."""
     try:
         descriptor = descriptor_pool.FindMessageTypeByName(any_payload.TypeName())
@@ -44,27 +42,19 @@ def parse_record_from_stream(data_stream: BinaryIO) -> Record:
         unpacked_payload = unpack_any(packed_sample.payload, default_descriptor_pool)
         if unpacked_payload is None:
             continue
 
-        timestamp = (
-            packed_sample.timestamp if packed_sample.HasField("timestamp") else None
-        )
+        timestamp = packed_sample.timestamp if packed_sample.HasField("timestamp") else None
 
         if timestamp is not None:
-            last_timestamp = (
-                timestamp if last_timestamp is None else max(last_timestamp, timestamp)
-            )
+            last_timestamp = timestamp if last_timestamp is None else max(last_timestamp, timestamp)
             if first_timestamp is None:
                 first_timestamp = timestamp
 
         if isinstance(unpacked_payload, frame_pb2.Frame):
             frame = cast(frame_pb2.Frame, unpacked_payload)
-            frames_info.append(
-                FrameInfo(frame_number=frame.frame_number, timestamp=timestamp)
-            )
+            frames_info.append(FrameInfo(frame_number=frame.frame_number, timestamp=timestamp))
             for packed_frame_data in frame.data:
-                unpacked_frame_data = unpack_any(
-                    packed_frame_data, default_descriptor_pool
-                )
+                unpacked_frame_data = unpack_any(packed_frame_data, default_descriptor_pool)
                 if unpacked_frame_data is None:
                     continue
                 frame_data_sample = FrameDataSample(
@@ -73,9 +63,7 @@ def parse_record_from_stream(data_stream: BinaryIO) -> Record:
                     payload=unpacked_frame_data,
                 )
                 payload_type = type(unpacked_frame_data)
-                samples_by_type.setdefault(
-                    payload_type, list[FrameDataSample[T]]()
-                ).append(frame_data_sample)
+                samples_by_type.setdefault(payload_type, list[FrameDataSample[T]]()).append(frame_data_sample)
         else:
             sample = Sample(timestamp=timestamp, payload=unpacked_payload)
             payload_type = type(unpacked_payload)
@@ -101,9 +89,7 @@ def parse_packed_samples_from_stream(
     """Parses packed samples from a binary stream and returns a list of packed samples."""
     packed_samples = []
 
-    pbar = tqdm(
-        desc="Parsing packed samples", unit="bytes", total=len(data_stream.getbuffer())
-    )
+    pbar = tqdm(desc="Parsing packed samples", unit="bytes", total=len(data_stream.getbuffer()))
 
     while data_stream.tell() < len(data_stream.getbuffer()):
         packed_sample = _read_delimited(data_stream, packed_sample_pb2.PackedSample)
diff --git a/plume_python/record.py b/plume_python/record.py
index a1f49c0..abf7070 100644
--- a/plume_python/record.py
+++ b/plume_python/record.py
@@ -2,7 +2,7 @@
 from typing import TypeVar, Generic, Optional, Type
 
 from .samples.record_pb2 import RecordMetadata
-from google.protobuf.message import Message
+from google.protobuf.message import Message  # type: ignore
 
 T = TypeVar("T", bound=Message)
 
@@ -52,15 +52,10 @@ def get_samples_by_type(self, payload_type: Type[T]) -> list[Sample[T]]:
 
     def get_timeless_samples(self) -> list[Sample[T]]:
         return [
-            sample
-            for samples in self.samples_by_type.values()
-            for sample in samples
-            if not sample.is_timestamped()
+            sample for samples in self.samples_by_type.values() for sample in samples if not sample.is_timestamped()
         ]
 
-    def get_samples_in_time_range(
-        self, start: Optional[int], end: Optional[int]
-    ) -> dict[Type[T], list[Sample[T]]]:
+    def get_samples_in_time_range(self, start: Optional[int], end: Optional[int]) -> dict[Type[T], list[Sample[T]]]:
         samples_in_time_range = {}
 
         for payload_type, samples in self.samples_by_type.items():
diff --git a/plume_python/utils/dataframe.py b/plume_python/utils/dataframe.py
index e7736f1..fdc9eee 100644
--- a/plume_python/utils/dataframe.py
+++ b/plume_python/utils/dataframe.py
@@ -63,9 +63,7 @@ def samples_to_dataframe(samples: list[Sample[T]]) -> pd.DataFrame:
     for sample in samples:
         sample_payload_fields_value = MessageToDict(sample.payload, True)
         if sample.is_timestamped():
-            sample_data.append(
-                {"timestamp": sample.timestamp} | sample_payload_fields_value
-            )
+            sample_data.append({"timestamp": sample.timestamp} | sample_payload_fields_value)
         else:
             sample_data.append(sample_payload_fields_value)
 
diff --git a/plume_python/utils/game_object.py b/plume_python/utils/game_object.py
index 015a1a0..a0fa438 100644
--- a/plume_python/utils/game_object.py
+++ b/plume_python/utils/game_object.py
@@ -24,9 +24,7 @@ def find_first_name_by_guid(record: Record, guid: str) -> Optional[str]:
     return None
 
 
-def find_identifier_by_game_object_id(
-    record: Record, game_object_id: str
-) -> Optional[GameObjectIdentifier]:
+def find_identifier_by_game_object_id(record: Record, game_object_id: str) -> Optional[GameObjectIdentifier]:
     for go_update_sample in record[GameObjectUpdate]:
         go_update = go_update_sample.payload
         if go_update.id.game_object_id == game_object_id:
@@ -41,18 +39,13 @@ def find_identifiers_by_name(record: Record, name: str) -> list[GameObjectIdenti
     for go_update_sample in record[GameObjectUpdate]:
         go_update = go_update_sample.payload
         if go_update.HasField("name"):
-            if (
-                name == go_update.name
-                and go_update.id.game_object_id not in known_guids
-            ):
+            if name == go_update.name and go_update.id.game_object_id not in known_guids:
                 identifiers.append(go_update.id)
                 known_guids.add(go_update.id.game_object_id)
     return identifiers
 
 
-def find_first_identifier_by_name(
-    record: Record, name: str
-) -> Optional[GameObjectIdentifier]:
+def find_first_identifier_by_name(record: Record, name: str) -> Optional[GameObjectIdentifier]:
     for go_update_sample in record[GameObjectUpdate]:
         go_update = go_update_sample.payload
         if go_update.HasField("name"):
diff --git a/plume_python/utils/transform.py b/plume_python/utils/transform.py
index 8a8481e..ed56162 100644
--- a/plume_python/utils/transform.py
+++ b/plume_python/utils/transform.py
@@ -15,24 +15,12 @@
 @dataclass(slots=True)
 class Transform:
     _guid: str
-    _local_position: np.ndarray = field(
-        default_factory=lambda: np.array(3, dtype=np.float32)
-    )
-    _local_rotation: quaternion.quaternion = field(
-        default_factory=lambda: quaternion.quaternion(1, 0, 0, 0)
-    )
-    _local_scale: np.ndarray = field(
-        default_factory=lambda: np.array(4, dtype=np.float32)
-    )
-    _local_T_mtx: np.ndarray = field(
-        default_factory=lambda: np.eye(4, dtype=np.float32)
-    )
-    _local_R_mtx: np.ndarray = field(
-        default_factory=lambda: np.eye(4, dtype=np.float32)
-    )
-    _local_S_mtx: np.ndarray = field(
-        default_factory=lambda: np.eye(4, dtype=np.float32)
-    )
+    _local_position: np.ndarray = field(default_factory=lambda: np.array(3, dtype=np.float32))
+    _local_rotation: quaternion.quaternion = field(default_factory=lambda: quaternion.quaternion(1, 0, 0, 0))
+    _local_scale: np.ndarray = field(default_factory=lambda: np.array(4, dtype=np.float32))
+    _local_T_mtx: np.ndarray = field(default_factory=lambda: np.eye(4, dtype=np.float32))
+    _local_R_mtx: np.ndarray = field(default_factory=lambda: np.eye(4, dtype=np.float32))
+    _local_S_mtx: np.ndarray = field(default_factory=lambda: np.eye(4, dtype=np.float32))
     _local_to_world_mtx: np.ndarray = None
     _parent: Optional[Transform] = None
     _dirty: bool = True
@@ -83,9 +71,7 @@ def get_local_to_world_matrix(self) -> np.ndarray:
             if self._parent is None:
                 self._local_to_world_mtx = trs_mtx
             else:
-                self._local_to_world_mtx = (
-                    self._parent.get_local_to_world_matrix() @ trs_mtx
-                )
+                self._local_to_world_mtx = self._parent.get_local_to_world_matrix() @ trs_mtx
 
             self._dirty = False
 
@@ -95,9 +81,7 @@ def get_local_to_world_matrix(self) -> np.ndarray:
         return self.get_local_to_world_matrix()[0:3, 3].transpose()
 
     def get_world_rotation(self) -> quaternion:
-        return quaternion.from_rotation_matrix(
-            self.get_local_to_world_matrix()[0:3, 0:3]
-        )
+        return quaternion.from_rotation_matrix(self.get_local_to_world_matrix()[0:3, 0:3])
 
     def get_world_scale(self) -> np.ndarray:
         local_to_world_mtx = self.get_local_to_world_matrix()
@@ -133,9 +117,7 @@ def get_world_scale(self) -> np.ndarray:
     return scale
 
 
-def compute_transform_time_series(
-    record: Record, guid: str
-) -> list[TimestampedTransform]:
+def compute_transform_time_series(record: Record, guid: str) -> list[TimestampedTransform]:
     transform_time_series = compute_transforms_time_series(record, {guid})
     return transform_time_series.get(guid, {})
 
@@ -146,27 +128,17 @@ def compute_transforms_time_series(
     result: dict[str, list[TimestampedTransform]] = {}
     current_transforms: dict[str, Transform] = {}
-    creation_samples: dict[
-        int, list[FrameDataSample[transform_pb2.TransformCreate]]
-    ] = {}
-    destruction_samples: dict[
-        int, list[FrameDataSample[transform_pb2.TransformDestroy]]
-    ] = {}
+    creation_samples: dict[int, list[FrameDataSample[transform_pb2.TransformCreate]]] = {}
+    destruction_samples: dict[int, list[FrameDataSample[transform_pb2.TransformDestroy]]] = {}
     update_samples: dict[int, list[FrameDataSample[transform_pb2.TransformUpdate]]] = {}
 
-    for frame_number, s in groupby(
-        record[transform_pb2.TransformCreate], lambda x: x.frame_number
-    ):
+    for frame_number, s in groupby(record[transform_pb2.TransformCreate], lambda x: x.frame_number):
         creation_samples[frame_number] = list(s)
 
-    for frame_number, s in groupby(
-        record[transform_pb2.TransformDestroy], lambda x: x.frame_number
-    ):
+    for frame_number, s in groupby(record[transform_pb2.TransformDestroy], lambda x: x.frame_number):
         destruction_samples[frame_number] = list(s)
 
-    for frame_number, s in groupby(
-        record[transform_pb2.TransformUpdate], lambda x: x.frame_number
-    ):
+    for frame_number, s in groupby(record[transform_pb2.TransformUpdate], lambda x: x.frame_number):
         update_samples[frame_number] = list(s)
 
     for frame in tqdm(record.frames_info, desc="Computing world positions"):
@@ -205,9 +177,7 @@ def compute_transforms_time_series(
                 local_transform.set_local_rotation(q)
             if update_sample.payload.HasField("local_scale"):
                 local_scale = update_sample.payload.local_scale
-                local_transform.set_local_scale(
-                    np.array([local_scale.x, local_scale.y, local_scale.z])
-                )
+                local_transform.set_local_scale(np.array([local_scale.x, local_scale.y, local_scale.z]))
             if update_sample.payload.HasField("parent_transform_id"):
                 parent_guid = update_sample.payload.parent_transform_id.component_id
                 if parent_guid == "00000000000000000000000000000000":  # null guid
@@ -222,11 +192,7 @@ def compute_transforms_time_series(
         if included_guids is None:
             included_transforms = current_transforms.values()
         else:
-            included_transforms = [
-                current_transforms[guid]
-                for guid in included_guids
-                if guid in current_transforms
-            ]
+            included_transforms = [current_transforms[guid] for guid in included_guids if guid in current_transforms]
 
         for t in included_transforms:
             timestamped_transform = TimestampedTransform(
@@ -239,8 +205,6 @@ def compute_transforms_time_series(
                 local_rotation=t.get_local_rotation(),
                 local_to_world_mtx=t.get_local_to_world_matrix(),
             )
-            result.setdefault(t.get_guid(), list[TimestampedTransform]()).append(
-                timestamped_transform
-            )
+            result.setdefault(t.get_guid(), list[TimestampedTransform]()).append(timestamped_transform)
 
     return result
diff --git a/pyproject.toml b/pyproject.toml
index d7f3dc6..0e3f341 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,6 +31,10 @@ exclude = ["tests"]
 [tool.poetry.scripts]
 plume-python = "plume_python.cli:cli"
 
+[tool.ruff]
+line-length = 119
+target-version = "py39"
+
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"