diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index cba025e..824ed6b 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -34,6 +34,9 @@ jobs:
           key: venv-${{ hashFiles('poetry.lock') }}
       - name: Install the project dependencies
         run: poetry install --with dev
+      - name: Run lint
+        run: |
+          make lint
       - name: Run tests
         run: |
           make tests
diff --git a/.gitignore b/.gitignore
index 424791f..28cefba 100644
--- a/.gitignore
+++ b/.gitignore
@@ -165,4 +165,4 @@ cython_debug/
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
\ No newline at end of file
+#.idea/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..88d2717
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,24 @@
+repos:
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  rev: v4.6.0
+  hooks:
+  - id: check-added-large-files
+    args: ['--maxkb=600']
+  - id: check-yaml
+  - id: check-json
+  - id: check-toml
+  - id: end-of-file-fixer
+  - id: trailing-whitespace
+  - id: check-docstring-first
+- repo: https://github.com/python-poetry/poetry
+  rev: 1.8.0
+  hooks:
+  - id: poetry-check
+- repo: https://github.com/astral-sh/ruff-pre-commit
+  rev: v0.4.2
+  hooks:
+  - id: ruff
+    args: [ --fix ]
+    exclude: (^plume_python/samples/|^tests/)
+  - id: ruff-format
+    exclude: (^plume_python/samples/|^tests/)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 186b9de..ccd4c70 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -23,4 +23,4 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
-- Fixed a bug where extracting samples by time range would throw an exception if the record contained timeless samples.
\ No newline at end of file
+- Fixed a bug where extracting samples by time range would throw an exception if the record contained timeless samples.
diff --git a/LICENSE b/LICENSE
index e72bfdd..f288702 100644
--- a/LICENSE
+++ b/LICENSE
@@ -671,4 +671,4 @@ into proprietary programs. If your program is a subroutine library, you
 may consider it more useful to permit linking proprietary applications with
 the library. If this is what you want to do, use the GNU Lesser General
 Public License instead of this License. But first, please read
-<https://www.gnu.org/licenses/why-not-lgpl.html>.
\ No newline at end of file
+<https://www.gnu.org/licenses/why-not-lgpl.html>.
diff --git a/Makefile b/Makefile
index 0f57f21..d3e4018 100644
--- a/Makefile
+++ b/Makefile
@@ -7,3 +7,8 @@ install:
 tests:
 	@echo "--- 🧪 Running tests ---"
 	poetry run pytest
+
+.PHONY: lint
+lint:
+	@echo "--- 🧹 Linting code ---"
+	poetry run pre-commit run --all-files
diff --git a/README.md b/README.md
index ae35544..86494a0 100644
--- a/README.md
+++ b/README.md
@@ -53,7 +53,7 @@ For more advanced usage, the package can be imported in a Python script:
 import plume_python as plm
 from plume_python.utils.dataframe import samples_to_dataframe, record_to_dataframes
 from plume_python.samples.unity import transform_pb2
-from plume_python.export import xdf_exporter 
+from plume_python.export import xdf_exporter
 from plume_python.utils.game_object import find_names_by_guid, find_first_identifier_by_name
 
 # Load a record file
@@ -136,4 +136,4 @@ Sophie VILLENAVE - sophie.villenave@ec-lyon.fr
 ```
 
 [Button Docs]: https://img.shields.io/badge/Explore%20the%20docs-%E2%86%92-brightgreen
-[Explore the docs]: https://liris-xr.github.io/PLUME/
\ No newline at end of file
+[Explore the docs]: https://liris-xr.github.io/PLUME/
diff --git a/plume_python/__init__.py b/plume_python/__init__.py
index d59f629..e69de29 100644
--- a/plume_python/__init__.py
+++ b/plume_python/__init__.py
@@ -1,7 +0,0 @@
-from . import file_reader
-from . import parser
-from . import record
-from . import utils
-from . import samples
-from . import export
-from . import cli
diff --git a/plume_python/cli.py b/plume_python/cli.py
index 869a7c5..a3b8c9e 100644
--- a/plume_python/cli.py
+++ b/plume_python/cli.py
@@ -1,13 +1,20 @@
 import os.path
 
-import click
+import click  # type: ignore
 
 from plume_python import parser
 from plume_python.export.xdf_exporter import export_xdf_from_record
 from plume_python.samples import sample_types_from_names
-from plume_python.utils.dataframe import record_to_dataframes, samples_to_dataframe, world_transforms_to_dataframe
-from plume_python.utils.game_object import find_names_by_guid, find_identifiers_by_name, \
-    find_identifier_by_game_object_id
+from plume_python.utils.dataframe import (
+    record_to_dataframes,
+    samples_to_dataframe,
+    world_transforms_to_dataframe,
+)
+from plume_python.utils.game_object import (
+    find_names_by_guid,
+    find_identifiers_by_name,
+    find_identifier_by_game_object_id,
+)
 from plume_python.utils.transform import compute_transform_time_series
 
 
@@ -17,33 +24,40 @@ def cli():
 
 
 @click.command()
-@click.argument('record_path', type=click.Path(exists=True, readable=True))
-@click.option('--xdf_output_path', type=click.Path(writable=True))
+@click.argument("record_path", type=click.Path(exists=True, readable=True))
+@click.option("--xdf_output_path", type=click.Path(writable=True))
 def export_xdf(record_path: str, xdf_output_path: str | None):
     """Export a XDF file including LSL samples and markers."""
-    if not record_path.endswith('.plm'):
+    if not record_path.endswith(".plm"):
         click.echo(err=True, message="Input file must be a .plm file")
         return
 
     if xdf_output_path is None:
-        xdf_output_path = record_path.replace('.plm', '.xdf')
+        xdf_output_path = record_path.replace(".plm", ".xdf")
 
     if os.path.exists(xdf_output_path):
-        if not click.confirm(f"File '{xdf_output_path}' already exists, do you want to overwrite it?"):
+        if not click.confirm(
+            f"File '{xdf_output_path}' already exists, do you want to overwrite it?"
+ ): return with open(xdf_output_path, "wb") as xdf_output_file: record = parser.parse_record_from_file(record_path) export_xdf_from_record(xdf_output_file, record) - click.echo('Exported xdf from record: ' + record_path + ' to ' + xdf_output_path) + click.echo( + "Exported xdf from record: " + + record_path + + " to " + + xdf_output_path + ) @click.command() -@click.argument('record_path', type=click.Path(exists=True, readable=True)) -@click.argument('guid', type=click.STRING) +@click.argument("record_path", type=click.Path(exists=True, readable=True)) +@click.argument("guid", type=click.STRING) def find_name(record_path: str, guid: str): """Find the name(s) of a GameObject with the given GUID in the record.""" - if not record_path.endswith('.plm'): + if not record_path.endswith(".plm"): click.echo(err=True, message="Input file must be a .plm file") return @@ -58,11 +72,11 @@ def find_name(record_path: str, guid: str): @click.command() -@click.argument('record_path', type=click.Path(exists=True, readable=True)) -@click.argument('name', type=click.STRING) +@click.argument("record_path", type=click.Path(exists=True, readable=True)) +@click.argument("name", type=click.STRING) def find_guid(record_path: str, name: str): """Find the GUID(s) of a GameObject by the given name.""" - if not record_path.endswith('.plm'): + if not record_path.endswith(".plm"): click.echo(err=True, message="Input file must be a .plm file") return @@ -77,11 +91,11 @@ def find_guid(record_path: str, name: str): @click.command() -@click.argument('record_path', type=click.Path(exists=True, readable=True)) -@click.argument('guid', type=click.STRING) +@click.argument("record_path", type=click.Path(exists=True, readable=True)) +@click.argument("guid", type=click.STRING) def export_world_transforms(record_path: str, guid: str): """Export world transforms of a GameObject with the given GUID to a CSV file.""" - if not record_path.endswith('.plm'): + if not record_path.endswith(".plm"): click.echo(err=True, message="Input file must be a .plm file") return @@ -92,40 +106,57 @@ def export_world_transforms(record_path: str, guid: str): click.echo(err=True, message=f"No identifier found for GUID {guid}") return - time_series = compute_transform_time_series(record, identifier.transform_id) + time_series = compute_transform_time_series( + record, identifier.transform_id + ) df = world_transforms_to_dataframe(time_series) - file_path = record_path.replace('.plm', f'_{guid}_world_transform.csv') + file_path = record_path.replace(".plm", f"_{guid}_world_transform.csv") df.to_csv(file_path) @click.command() -@click.argument('record_path', type=click.Path(exists=True, readable=True)) -@click.argument('output_dir', type=click.Path(exists=True, writable=True)) -@click.option('--filter', default="all", show_default=True, type=click.STRING, - help="Comma separated list of sample types to export (eg. 'TransformUpdate,GameObjectUpdate')") +@click.argument("record_path", type=click.Path(exists=True, readable=True)) +@click.argument("output_dir", type=click.Path(exists=True, writable=True)) +@click.option( + "--filter", + default="all", + show_default=True, + type=click.STRING, + help="Comma separated list of sample types to export (eg. 
'TransformUpdate,GameObjectUpdate')", +) def export_csv(record_path: str, output_dir: str | None, filter: str): """Export samples from the record to CSV files.""" - if not record_path.endswith('.plm'): + if not record_path.endswith(".plm"): click.echo(err=True, message="Input file must be a .plm file") return record = parser.parse_record_from_file(record_path) - filters = [d.strip() for d in filter.split(',')] + filters = [d.strip() for d in filter.split(",")] - if filters == ['all'] or filters == ['*']: + if filters == ["all"] or filters == ["*"]: dataframes = record_to_dataframes(record) for sample_type, df in dataframes.items(): - file_path = os.path.join(output_dir, sample_type.__name__ + '.csv') + file_path = os.path.join(output_dir, sample_type.__name__ + ".csv") df.to_csv(file_path) - click.echo('Exported CSV for sample type: ' + sample_type.__name__ + ' to ' + file_path) + click.echo( + "Exported CSV for sample type: " + + sample_type.__name__ + + " to " + + file_path + ) else: sample_types = sample_types_from_names(filters) for sample_type in sample_types: df = samples_to_dataframe(record.get_samples_by_type(sample_type)) - file_path = os.path.join(output_dir, sample_type.__name__ + '.csv') + file_path = os.path.join(output_dir, sample_type.__name__ + ".csv") df.to_csv(file_path) - click.echo('Exported CSV for sample type: ' + sample_type.__name__ + ' to ' + file_path) + click.echo( + "Exported CSV for sample type: " + + sample_type.__name__ + + " to " + + file_path + ) cli.add_command(export_csv) diff --git a/plume_python/export/xdf_exporter.py b/plume_python/export/xdf_exporter.py index 7d36173..9d81728 100644 --- a/plume_python/export/xdf_exporter.py +++ b/plume_python/export/xdf_exporter.py @@ -1,12 +1,26 @@ -from plume_python.export.xdf_writer import * +from plume_python.export.xdf_writer import ( + STR_ENCODING, + write_file_header, + write_stream_header, + write_stream_sample, + write_stream_footer, +) from plume_python.record import Record from plume_python.samples.common import marker_pb2 from plume_python.samples.lsl import lsl_stream_pb2 from typing import BinaryIO +import xml.etree.ElementTree as ET +import numpy as np + def export_xdf_from_record(output_file: BinaryIO, record: Record): - datetime_str = record.get_metadata().start_time.ToDatetime().astimezone().strftime('%Y-%m-%dT%H:%M:%S%z') + datetime_str = ( + record.get_metadata() + .start_time.ToDatetime() + .astimezone() + .strftime("%Y-%m-%dT%H:%M:%S%z") + ) # Add a colon separator to the offset segment datetime_str = "{0}:{1}".format(datetime_str[:-2], datetime_str[-2:]) @@ -26,22 +40,32 @@ def export_xdf_from_record(output_file: BinaryIO, record: Record): for lsl_open_stream in record[lsl_stream_pb2.StreamOpen]: xml_header = ET.fromstring(lsl_open_stream.payload.xml_header) - stream_id = np.uint64(lsl_open_stream.payload.stream_id) + 1 # reserve id = 1 for the marker stream + stream_id = ( + np.uint64(lsl_open_stream.payload.stream_id) + 1 + ) # reserve id = 1 for the marker stream channel_format = xml_header.find("channel_format").text stream_channel_format[stream_id] = channel_format stream_min_time[stream_id] = None stream_max_time[stream_id] = None stream_sample_count[stream_id] = 0 - write_stream_header(output_file, lsl_open_stream.payload.xml_header, stream_id) + write_stream_header( + output_file, lsl_open_stream.payload.xml_header, stream_id + ) for lsl_sample in record[lsl_stream_pb2.StreamSample]: stream_id = np.uint64(lsl_sample.payload.stream_id) + 1 channel_format = 
stream_channel_format[stream_id] t = lsl_sample.timestamp / 1_000_000_000.0 # convert time to seconds - if stream_min_time[stream_id] is None or t < stream_min_time[stream_id]: + if ( + stream_min_time[stream_id] is None + or t < stream_min_time[stream_id] + ): stream_min_time[stream_id] = t - if stream_max_time[stream_id] is None or t > stream_max_time[stream_id]: + if ( + stream_max_time[stream_id] is None + or t > stream_max_time[stream_id] + ): stream_max_time[stream_id] = t if stream_id not in stream_sample_count: @@ -50,30 +74,59 @@ def export_xdf_from_record(output_file: BinaryIO, record: Record): stream_sample_count[stream_id] += 1 if channel_format == "string": - val = np.array([x for x in lsl_sample.payload.string_values.value], dtype=np.str_) + val = np.array( + [x for x in lsl_sample.payload.string_values.value], + dtype=np.str_, + ) elif channel_format == "int8": - val = np.array([x for x in lsl_sample.payload.int8_values.value], dtype=np.int8) + val = np.array( + [x for x in lsl_sample.payload.int8_values.value], + dtype=np.int8, + ) elif channel_format == "int16": - val = np.array([x for x in lsl_sample.payload.int16_values.value], dtype=np.int16) + val = np.array( + [x for x in lsl_sample.payload.int16_values.value], + dtype=np.int16, + ) elif channel_format == "int32": - val = np.array([x for x in lsl_sample.payload.int32_values.value], dtype=np.int32) + val = np.array( + [x for x in lsl_sample.payload.int32_values.value], + dtype=np.int32, + ) elif channel_format == "int64": - val = np.array([x for x in lsl_sample.payload.int64_values.value], dtype=np.int64) + val = np.array( + [x for x in lsl_sample.payload.int64_values.value], + dtype=np.int64, + ) elif channel_format == "float32": - val = np.array([x for x in lsl_sample.payload.float_values.value], dtype=np.float32) + val = np.array( + [x for x in lsl_sample.payload.float_values.value], + dtype=np.float32, + ) elif channel_format == "double64": - val = np.array([x for x in lsl_sample.payload.double_values.value], dtype=np.float64) + val = np.array( + [x for x in lsl_sample.payload.double_values.value], + dtype=np.float64, + ) else: raise ValueError(f"Unsupported channel format: {channel_format}") write_stream_sample(output_file, val, t, channel_format, stream_id) for marker_sample in record[marker_pb2.Marker]: - t = marker_sample.timestamp / 1_000_000_000.0 # convert time to seconds - - if stream_min_time[marker_stream_id] is None or t < stream_min_time[marker_stream_id]: + t = ( + marker_sample.timestamp / 1_000_000_000.0 + ) # convert time to seconds + + if ( + stream_min_time[marker_stream_id] is None + or t < stream_min_time[marker_stream_id] + ): stream_min_time[marker_stream_id] = t - if stream_max_time[marker_stream_id] is None or t > stream_max_time[marker_stream_id]: + if ( + stream_max_time[marker_stream_id] is None + or t > stream_max_time[marker_stream_id] + ): stream_max_time[marker_stream_id] = t if marker_stream_id not in stream_sample_count: @@ -87,13 +140,23 @@ def export_xdf_from_record(output_file: BinaryIO, record: Record): for lsl_close_stream in record[lsl_stream_pb2.StreamClose]: stream_id = np.uint64(lsl_close_stream.payload.stream_id) + 1 sample_count = stream_sample_count[stream_id] - write_stream_footer(output_file, stream_min_time[stream_id], stream_max_time[stream_id], sample_count, - stream_id) + write_stream_footer( + output_file, + stream_min_time[stream_id], + stream_max_time[stream_id], + sample_count, + stream_id, + ) # Write marker stream footer # stream_id = 1 is reserved for the 
marker stream - write_stream_footer(output_file, stream_min_time[marker_stream_id], stream_max_time[marker_stream_id], - stream_sample_count[marker_stream_id], marker_stream_id) + write_stream_footer( + output_file, + stream_min_time[marker_stream_id], + stream_max_time[marker_stream_id], + stream_sample_count[marker_stream_id], + marker_stream_id, + ) def write_marker_stream_header(output_buf, marker_stream_id): @@ -109,4 +172,6 @@ def write_marker_stream_header(output_buf, marker_stream_id): channel_count_el.text = "1" nominal_srate_el.text = "0.0" xml = ET.tostring(info_el, encoding=STR_ENCODING, xml_declaration=True) - write_stream_header(output_buf, xml, marker_stream_id) # stream_id = 1 is reserved for the marker stream + write_stream_header( + output_buf, xml, marker_stream_id + ) # stream_id = 1 is reserved for the marker stream diff --git a/plume_python/export/xdf_writer.py b/plume_python/export/xdf_writer.py index 6a1cabb..ba5521e 100644 --- a/plume_python/export/xdf_writer.py +++ b/plume_python/export/xdf_writer.py @@ -18,7 +18,19 @@ int64=np.int64, ) -DataType = np.int8 | np.int16 | np.int32 | np.int64 | np.uint8 | np.uint16 | np.uint32 | np.uint64 | np.float32 | np.float64 | str +DataType = ( + np.int8 + | np.int16 + | np.int32 + | np.int64 + | np.uint8 + | np.uint16 + | np.uint32 + | np.uint64 + | np.float32 + | np.float64 + | str +) class ChunkTag(Enum): @@ -32,18 +44,24 @@ class ChunkTag(Enum): def write_file_header(output: BinaryIO, version: str, datetime: str): - output.write(b'XDF:') + output.write(b"XDF:") info_element = ET.Element("info") version_element = ET.SubElement(info_element, "version") datetime_element = ET.SubElement(info_element, "datetime") version_element.text = version datetime_element.text = datetime xml_str = ET.tostring( - info_element, xml_declaration=True, encoding=STR_ENCODING) + info_element, xml_declaration=True, encoding=STR_ENCODING + ) write_chunk(output, ChunkTag.FILE_HEADER, xml_str) -def write_chunk(output: BinaryIO, chunk_tag: ChunkTag, content: bytes, stream_id: np.uint32 = None): +def write_chunk( + output: BinaryIO, + chunk_tag: ChunkTag, + content: bytes, + stream_id: np.uint32 = None, +): if not isinstance(content, bytes): raise Exception("Content should be bytes.") @@ -61,15 +79,27 @@ def write_chunk(output: BinaryIO, chunk_tag: ChunkTag, content: bytes, stream_id write(output, content) -def write_stream_header(output: BinaryIO, xml_header: str | bytes, stream_id: np.uint32 = None): +def write_stream_header( + output: BinaryIO, xml_header: str | bytes, stream_id: np.uint32 = None +): if isinstance(xml_header, str): xml_header = bytes(xml_header, encoding=STR_ENCODING) - write_chunk(output, ChunkTag.STREAM_HEADER, xml_header, None if stream_id is None else np.uint32(stream_id)) - - -def write_stream_footer(output: BinaryIO, first_timestamp: float, last_timestamp: float, - sample_count: int, stream_id: np.uint32 = None): + write_chunk( + output, + ChunkTag.STREAM_HEADER, + xml_header, + None if stream_id is None else np.uint32(stream_id), + ) + + +def write_stream_footer( + output: BinaryIO, + first_timestamp: float, + last_timestamp: float, + sample_count: int, + stream_id: np.uint32 = None, +): first_timestamp = np.float64(first_timestamp) last_timestamp = np.float64(last_timestamp) sample_count = np.uint64(sample_count) @@ -82,24 +112,49 @@ def write_stream_footer(output: BinaryIO, first_timestamp: float, last_timestamp sample_count_element.text = str(sample_count) xml_str = ET.tostring( - info_element, xml_declaration=True, 
encoding=STR_ENCODING) - write_chunk(output, ChunkTag.STREAM_FOOTER, xml_str, None if stream_id is None else np.uint32(stream_id)) - - -def write_stream_sample(output: BinaryIO, sample: np.ndarray, timestamp: float, channel_format: str, - stream_id: np.uint32 = None): + info_element, xml_declaration=True, encoding=STR_ENCODING + ) + write_chunk( + output, + ChunkTag.STREAM_FOOTER, + xml_str, + None if stream_id is None else np.uint32(stream_id), + ) + + +def write_stream_sample( + output: BinaryIO, + sample: np.ndarray, + timestamp: float, + channel_format: str, + stream_id: np.uint32 = None, +): if channel_format not in formats: - raise Exception("Unsupported channel format '{}'".format(channel_format)) + raise Exception( + "Unsupported channel format '{}'".format(channel_format) + ) fmt = formats[channel_format] - write_stream_sample_chunk(output, np.array([sample], dtype=fmt), [timestamp], - channel_format, None if stream_id is None else np.uint32(stream_id)) - - -def write_stream_sample_chunk(output: BinaryIO, chunk: np.ndarray, timestamps: list[float], channel_format: str, - stream_id: np.uint32 = None): + write_stream_sample_chunk( + output, + np.array([sample], dtype=fmt), + [timestamp], + channel_format, + None if stream_id is None else np.uint32(stream_id), + ) + + +def write_stream_sample_chunk( + output: BinaryIO, + chunk: np.ndarray, + timestamps: list[float], + channel_format: str, + stream_id: np.uint32 = None, +): if channel_format not in formats: - raise Exception("Unsupported channel format '{}'".format(channel_format)) + raise Exception( + "Unsupported channel format '{}'".format(channel_format) + ) fmt = formats[channel_format] chunk = np.array(chunk, dtype=fmt) @@ -116,7 +171,9 @@ def write_stream_sample_chunk(output: BinaryIO, chunk: np.ndarray, timestamps: l if sample.ndim == 0: if isinstance(sample, str): str_bytes = bytes(sample, STR_ENCODING) - write_variable_length_integer(tmp_output, np.uint64(len(str_bytes))) + write_variable_length_integer( + tmp_output, np.uint64(len(str_bytes)) + ) write(tmp_output, str_bytes) elif isinstance(sample, DataType): write(tmp_output, sample) @@ -126,14 +183,23 @@ def write_stream_sample_chunk(output: BinaryIO, chunk: np.ndarray, timestamps: l for channel in sample: if isinstance(channel, str): str_bytes = bytes(channel, STR_ENCODING) - write_variable_length_integer(tmp_output, np.uint64(len(str_bytes))) + write_variable_length_integer( + tmp_output, np.uint64(len(str_bytes)) + ) write(tmp_output, str_bytes) elif isinstance(channel, DataType): write(tmp_output, channel) else: - raise Exception("Unsupported data type " + str(type(channel))) + raise Exception( + "Unsupported data type " + str(type(channel)) + ) - write_chunk(output, ChunkTag.SAMPLES, tmp_output.getvalue(), None if stream_id is None else np.uint32(stream_id)) + write_chunk( + output, + ChunkTag.SAMPLES, + tmp_output.getvalue(), + None if stream_id is None else np.uint32(stream_id), + ) def write_timestamp(output: BinaryIO, timestamp: Optional[float] = None): @@ -159,9 +225,28 @@ def write_variable_length_integer(output: BinaryIO, val: np.uint64): write(output, np.uint64(val)) -def write_fixed_length_integer(output: BinaryIO, - val: np.int8 | np.int16 | np.int32 | np.int64 | np.uint8 | np.uint16 | np.uint32 | np.uint64): - if not isinstance(val, np.int8 | np.int16 | np.int32 | np.int64 | np.uint8 | np.uint16 | np.uint32 | np.uint64): +def write_fixed_length_integer( + output: BinaryIO, + val: np.int8 + | np.int16 + | np.int32 + | np.int64 + | np.uint8 + | np.uint16 + | 
np.uint32 + | np.uint64, +): + if not isinstance( + val, + np.int8 + | np.int16 + | np.int32 + | np.int64 + | np.uint8 + | np.uint16 + | np.uint32 + | np.uint64, + ): raise Exception("Unsupported data type " + str(type(val))) write(output, np.uint8(np.dtype(val).itemsize)) diff --git a/plume_python/file_reader.py b/plume_python/file_reader.py index c896755..70423f1 100644 --- a/plume_python/file_reader.py +++ b/plume_python/file_reader.py @@ -10,7 +10,7 @@ def _is_lz4_compressed(raw_bytes: bytes) -> bool: def read_file(filepath: str) -> BinaryIO: - if not filepath.endswith('.plm'): + if not filepath.endswith(".plm"): raise ValueError("File must be a .plm file") with open(filepath, "rb") as file: diff --git a/plume_python/parser.py b/plume_python/parser.py index acc414e..e65d4d5 100644 --- a/plume_python/parser.py +++ b/plume_python/parser.py @@ -2,10 +2,10 @@ from warnings import warn from tqdm import tqdm -from delimited_protobuf import read as _read_delimited -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf.message import Message -from google.protobuf.message_factory import GetMessageClass +from delimited_protobuf import read as _read_delimited # type: ignore +from google.protobuf import descriptor_pool as _descriptor_pool # type: ignore +from google.protobuf.message import Message # type: ignore +from google.protobuf.message_factory import GetMessageClass # type: ignore from . import file_reader from .record import Record, Sample, FrameDataSample, FrameInfo @@ -20,11 +20,15 @@ def unpack_any( ) -> Optional[Message]: """Unpacks an Any message into its original type using the provided descriptor pool.""" try: - descriptor = descriptor_pool.FindMessageTypeByName(any_payload.TypeName()) + descriptor = descriptor_pool.FindMessageTypeByName( + any_payload.TypeName() + ) unpacked = GetMessageClass(descriptor)() success = any_payload.Unpack(unpacked) if not success: - warn(f"Failed to unpack payload with type name {any_payload.TypeName()}") + warn( + f"Failed to unpack payload with type name {any_payload.TypeName()}" + ) return None return unpacked except KeyError: @@ -40,18 +44,25 @@ def parse_record_from_stream(data_stream: BinaryIO) -> Record: first_timestamp: Optional[int] = None last_timestamp: Optional[int] = None - for packed_sample in tqdm(packed_samples, desc="Unpacking samples", unit="samples"): - - unpacked_payload = unpack_any(packed_sample.payload, default_descriptor_pool) + for packed_sample in tqdm( + packed_samples, desc="Unpacking samples", unit="samples" + ): + unpacked_payload = unpack_any( + packed_sample.payload, default_descriptor_pool + ) if unpacked_payload is None: continue timestamp = ( - packed_sample.timestamp if packed_sample.HasField("timestamp") else None + packed_sample.timestamp + if packed_sample.HasField("timestamp") + else None ) if timestamp is not None: last_timestamp = ( - timestamp if last_timestamp is None else max(last_timestamp, timestamp) + timestamp + if last_timestamp is None + else max(last_timestamp, timestamp) ) if first_timestamp is None: first_timestamp = timestamp @@ -80,7 +91,9 @@ def parse_record_from_stream(data_stream: BinaryIO) -> Record: else: sample = Sample(timestamp=timestamp, payload=unpacked_payload) payload_type = type(unpacked_payload) - samples_by_type.setdefault(payload_type, list[Sample[T]]()).append(sample) + samples_by_type.setdefault(payload_type, list[Sample[T]]()).append( + sample + ) return Record( samples_by_type=samples_by_type, @@ -103,10 +116,14 @@ def 
parse_packed_samples_from_stream( packed_samples = [] pbar = tqdm( - desc="Parsing packed samples", unit="bytes", total=len(data_stream.getbuffer()) + desc="Parsing packed samples", + unit="bytes", + total=len(data_stream.getbuffer()), ) while data_stream.tell() < len(data_stream.getbuffer()): - packed_sample = _read_delimited(data_stream, packed_sample_pb2.PackedSample) + packed_sample = _read_delimited( + data_stream, packed_sample_pb2.PackedSample + ) if packed_sample is not None: packed_samples.append(packed_sample) diff --git a/plume_python/record.py b/plume_python/record.py index 7d12064..3087493 100644 --- a/plume_python/record.py +++ b/plume_python/record.py @@ -2,9 +2,8 @@ from typing import TypeVar, Generic, Optional, Type from .samples.record_pb2 import RecordMetadata -from google.protobuf.message import Message +from google.protobuf.message import Message # type: ignore -from .samples import record_pb2 T = TypeVar("T", bound=Message) @@ -65,7 +64,6 @@ def get_samples_in_time_range( samples_in_time_range = {} for payload_type, samples in self.samples_by_type.items(): - samples = [ sample for sample in samples diff --git a/plume_python/utils/dataframe.py b/plume_python/utils/dataframe.py index e08c4f5..e15c808 100644 --- a/plume_python/utils/dataframe.py +++ b/plume_python/utils/dataframe.py @@ -8,10 +8,12 @@ from plume_python.record import Sample, FrameDataSample, Record -T = TypeVar('T', bound=Message) +T = TypeVar("T", bound=Message) -def world_transforms_to_dataframe(world_transforms: list[TimestampedTransform]) -> pd.DataFrame: +def world_transforms_to_dataframe( + world_transforms: list[TimestampedTransform], +) -> pd.DataFrame: if len(world_transforms) == 0: return pd.DataFrame() @@ -21,18 +23,21 @@ def world_transforms_to_dataframe(world_transforms: list[TimestampedTransform]) world_position = world_transform.get_world_position() world_rotation = world_transform.get_world_rotation() world_scale = world_transform.get_world_scale() - world_transform_data.append({"timestamp": world_transform.timestamp, - "position_x": world_position[0], - "position_y": world_position[1], - "position_z": world_position[2], - "rotation_x": world_rotation.x, - "rotation_y": world_rotation.y, - "rotation_z": world_rotation.z, - "rotation_w": world_rotation.w, - "scale_x": world_scale[0], - "scale_y": world_scale[1], - "scale_z": world_scale[2] - }) + world_transform_data.append( + { + "timestamp": world_transform.timestamp, + "position_x": world_position[0], + "position_y": world_position[1], + "position_z": world_position[2], + "rotation_x": world_rotation.x, + "rotation_y": world_rotation.y, + "rotation_z": world_rotation.z, + "rotation_w": world_rotation.w, + "scale_x": world_scale[0], + "scale_y": world_scale[1], + "scale_z": world_scale[2], + } + ) return pd.json_normalize(world_transform_data) @@ -46,14 +51,24 @@ def samples_to_dataframe(samples: list[Sample[T]]) -> pd.DataFrame: if isinstance(samples[0], FrameDataSample): frame_samples = cast(list[FrameDataSample[T]], samples) for frame_sample in frame_samples: - sample_payload_fields_value = MessageToDict(frame_sample.payload, True) - sample_data.append({"timestamp": frame_sample.timestamp, - "frame_number": frame_sample.frame_number} | sample_payload_fields_value) + sample_payload_fields_value = MessageToDict( + frame_sample.payload, True + ) + sample_data.append( + { + "timestamp": frame_sample.timestamp, + "frame_number": frame_sample.frame_number, + } + | sample_payload_fields_value + ) else: for sample in samples: 
sample_payload_fields_value = MessageToDict(sample.payload, True) if sample.is_timestamped(): - sample_data.append({"timestamp": sample.timestamp} | sample_payload_fields_value) + sample_data.append( + {"timestamp": sample.timestamp} + | sample_payload_fields_value + ) else: sample_data.append(sample_payload_fields_value) diff --git a/plume_python/utils/game_object.py b/plume_python/utils/game_object.py index a0fa438..1b5aa22 100644 --- a/plume_python/utils/game_object.py +++ b/plume_python/utils/game_object.py @@ -24,7 +24,9 @@ def find_first_name_by_guid(record: Record, guid: str) -> Optional[str]: return None -def find_identifier_by_game_object_id(record: Record, game_object_id: str) -> Optional[GameObjectIdentifier]: +def find_identifier_by_game_object_id( + record: Record, game_object_id: str +) -> Optional[GameObjectIdentifier]: for go_update_sample in record[GameObjectUpdate]: go_update = go_update_sample.payload if go_update.id.game_object_id == game_object_id: @@ -32,20 +34,27 @@ def find_identifier_by_game_object_id(record: Record, game_object_id: str) -> Op return None -def find_identifiers_by_name(record: Record, name: str) -> list[GameObjectIdentifier]: +def find_identifiers_by_name( + record: Record, name: str +) -> list[GameObjectIdentifier]: identifiers: list[GameObjectIdentifier] = [] known_guids: set[str] = set() for go_update_sample in record[GameObjectUpdate]: go_update = go_update_sample.payload if go_update.HasField("name"): - if name == go_update.name and go_update.id.game_object_id not in known_guids: + if ( + name == go_update.name + and go_update.id.game_object_id not in known_guids + ): identifiers.append(go_update.id) known_guids.add(go_update.id.game_object_id) return identifiers -def find_first_identifier_by_name(record: Record, name: str) -> Optional[GameObjectIdentifier]: +def find_first_identifier_by_name( + record: Record, name: str +) -> Optional[GameObjectIdentifier]: for go_update_sample in record[GameObjectUpdate]: go_update = go_update_sample.payload if go_update.HasField("name"): diff --git a/plume_python/utils/transform.py b/plume_python/utils/transform.py index 6429654..e27a825 100644 --- a/plume_python/utils/transform.py +++ b/plume_python/utils/transform.py @@ -15,18 +15,34 @@ @dataclass(slots=True) class Transform: _guid: str - _local_position: np.ndarray = field(default_factory=lambda: np.array(3, dtype=np.float32)) - _local_rotation: quaternion.quaternion = field(default_factory=lambda: quaternion.quaternion(1, 0, 0, 0)) - _local_scale: np.ndarray = field(default_factory=lambda: np.array(4, dtype=np.float32)) - _local_T_mtx: np.ndarray = field(default_factory=lambda: np.eye(4, dtype=np.float32)) - _local_R_mtx: np.ndarray = field(default_factory=lambda: np.eye(4, dtype=np.float32)) - _local_S_mtx: np.ndarray = field(default_factory=lambda: np.eye(4, dtype=np.float32)) + _local_position: np.ndarray = field( + default_factory=lambda: np.array(3, dtype=np.float32) + ) + _local_rotation: quaternion.quaternion = field( + default_factory=lambda: quaternion.quaternion(1, 0, 0, 0) + ) + _local_scale: np.ndarray = field( + default_factory=lambda: np.array(4, dtype=np.float32) + ) + _local_T_mtx: np.ndarray = field( + default_factory=lambda: np.eye(4, dtype=np.float32) + ) + _local_R_mtx: np.ndarray = field( + default_factory=lambda: np.eye(4, dtype=np.float32) + ) + _local_S_mtx: np.ndarray = field( + default_factory=lambda: np.eye(4, dtype=np.float32) + ) _local_to_world_mtx: np.ndarray = None _parent: Optional[Transform] = None _dirty: bool = True 
def _is_dirty(self) -> bool: - return self._dirty or self._parent is not None and self._parent._is_dirty() + return ( + self._dirty + or self._parent is not None + and self._parent._is_dirty() + ) def get_guid(self) -> str: return self._guid @@ -37,7 +53,9 @@ def set_local_position(self, local_position: np.ndarray): self._dirty = True def set_local_rotation(self, local_rotation: quaternion): - self._local_R_mtx[0:3, 0:3] = quaternion.as_rotation_matrix(local_rotation) + self._local_R_mtx[0:3, 0:3] = quaternion.as_rotation_matrix( + local_rotation + ) self._local_rotation = local_rotation self._dirty = True @@ -66,13 +84,14 @@ def get_parent(self) -> Optional[Transform]: def get_local_to_world_matrix(self) -> np.ndarray: if self._is_dirty() or self._local_to_world_mtx is None: - trs_mtx = self._local_T_mtx @ self._local_R_mtx @ self._local_S_mtx if self._parent is None: self._local_to_world_mtx = trs_mtx else: - self._local_to_world_mtx = self._parent.get_local_to_world_matrix() @ trs_mtx + self._local_to_world_mtx = ( + self._parent.get_local_to_world_matrix() @ trs_mtx + ) self._dirty = False @@ -82,7 +101,9 @@ def get_world_position(self) -> np.ndarray: return self.get_local_to_world_matrix()[0:3, 3].transpose() def get_world_rotation(self) -> quaternion: - return quaternion.from_rotation_matrix(self.get_local_to_world_matrix()[0:3, 0:3]) + return quaternion.from_rotation_matrix( + self.get_local_to_world_matrix()[0:3, 0:3] + ) def get_world_scale(self) -> np.ndarray: local_to_world_mtx = self.get_local_to_world_matrix() @@ -108,7 +129,9 @@ def get_world_position(self) -> np.ndarray: return self.local_to_world_mtx[0:3, 3].transpose() def get_world_rotation(self) -> quaternion: - return quaternion.from_rotation_matrix(self.local_to_world_mtx[0:3, 0:3]) + return quaternion.from_rotation_matrix( + self.local_to_world_mtx[0:3, 0:3] + ) def get_world_scale(self) -> np.ndarray: sx = np.linalg.norm(self.local_to_world_mtx[0:3, 0]) @@ -118,27 +141,42 @@ def get_world_scale(self) -> np.ndarray: return scale -def compute_transform_time_series(record: Record, guid: str) -> list[TimestampedTransform]: +def compute_transform_time_series( + record: Record, guid: str +) -> list[TimestampedTransform]: transform_time_series = compute_transforms_time_series(record, {guid}) return transform_time_series.get(guid, {}) -def compute_transforms_time_series(record: Record, included_guids: set[str] = None) \ - -> dict[str, list[TimestampedTransform]]: +def compute_transforms_time_series( + record: Record, included_guids: set[str] = None +) -> dict[str, list[TimestampedTransform]]: result: dict[str, list[TimestampedTransform]] = {} current_transforms: dict[str, Transform] = {} - creation_samples: dict[int, list[FrameDataSample[transform_pb2.TransformCreate]]] = {} - destruction_samples: dict[int, list[FrameDataSample[transform_pb2.TransformDestroy]]] = {} - update_samples: dict[int, list[FrameDataSample[transform_pb2.TransformUpdate]]] = {} - - for frame_number, s in groupby(record[transform_pb2.TransformCreate], lambda x: x.frame_number): + creation_samples: dict[ + int, list[FrameDataSample[transform_pb2.TransformCreate]] + ] = {} + destruction_samples: dict[ + int, list[FrameDataSample[transform_pb2.TransformDestroy]] + ] = {} + update_samples: dict[ + int, list[FrameDataSample[transform_pb2.TransformUpdate]] + ] = {} + + for frame_number, s in groupby( + record[transform_pb2.TransformCreate], lambda x: x.frame_number + ): creation_samples[frame_number] = list(s) - for frame_number, s in 
groupby(record[transform_pb2.TransformDestroy], lambda x: x.frame_number): + for frame_number, s in groupby( + record[transform_pb2.TransformDestroy], lambda x: x.frame_number + ): destruction_samples[frame_number] = list(s) - for frame_number, s in groupby(record[transform_pb2.TransformUpdate], lambda x: x.frame_number): + for frame_number, s in groupby( + record[transform_pb2.TransformUpdate], lambda x: x.frame_number + ): update_samples[frame_number] = list(s) for frame in tqdm(record.frames_info, desc="Computing world positions"): @@ -161,22 +199,43 @@ def compute_transforms_time_series(record: Record, included_guids: set[str] = No for update_sample in update_samples[frame.frame_number]: guid = update_sample.payload.id.component_id local_transform = current_transforms[guid] - if update_sample.payload.HasField('local_position'): + if update_sample.payload.HasField("local_position"): local_position = update_sample.payload.local_position - local_transform.set_local_position(np.array([local_position.x, local_position.y, local_position.z])) - if update_sample.payload.HasField('local_rotation'): + local_transform.set_local_position( + np.array( + [ + local_position.x, + local_position.y, + local_position.z, + ] + ) + ) + if update_sample.payload.HasField("local_rotation"): local_rotation = update_sample.payload.local_rotation - q = quaternion.quaternion(local_rotation.w, local_rotation.x, local_rotation.y, local_rotation.z) + q = quaternion.quaternion( + local_rotation.w, + local_rotation.x, + local_rotation.y, + local_rotation.z, + ) local_transform.set_local_rotation(q) - if update_sample.payload.HasField('local_scale'): + if update_sample.payload.HasField("local_scale"): local_scale = update_sample.payload.local_scale - local_transform.set_local_scale(np.array([local_scale.x, local_scale.y, local_scale.z])) - if update_sample.payload.HasField('parent_transform_id'): - parent_guid = update_sample.payload.parent_transform_id.component_id - if parent_guid == "00000000000000000000000000000000": # null guid + local_transform.set_local_scale( + np.array([local_scale.x, local_scale.y, local_scale.z]) + ) + if update_sample.payload.HasField("parent_transform_id"): + parent_guid = ( + update_sample.payload.parent_transform_id.component_id + ) + if ( + parent_guid == "00000000000000000000000000000000" + ): # null guid local_transform.set_parent(None) elif parent_guid in current_transforms: - local_transform.set_parent(current_transforms[parent_guid]) + local_transform.set_parent( + current_transforms[parent_guid] + ) else: parent = Transform(_guid=parent_guid) local_transform.set_parent(parent) @@ -185,17 +244,25 @@ def compute_transforms_time_series(record: Record, included_guids: set[str] = No if included_guids is None: included_transforms = current_transforms.values() else: - included_transforms = [current_transforms[guid] for guid in included_guids if guid in current_transforms] + included_transforms = [ + current_transforms[guid] + for guid in included_guids + if guid in current_transforms + ] for t in included_transforms: - timestamped_transform = TimestampedTransform(timestamp=frame.timestamp, - frame_number=frame.frame_number, - guid=t.get_guid(), - parent_guid=t.get_parent().get_guid(), - local_scale=t.get_local_scale(), - local_position=t.get_local_position(), - local_rotation=t.get_local_rotation(), - local_to_world_mtx=t.get_local_to_world_matrix()) - result.setdefault(t.get_guid(), list[TimestampedTransform]()).append(timestamped_transform) + timestamped_transform = 
TimestampedTransform( + timestamp=frame.timestamp, + frame_number=frame.frame_number, + guid=t.get_guid(), + parent_guid=t.get_parent().get_guid(), + local_scale=t.get_local_scale(), + local_position=t.get_local_position(), + local_rotation=t.get_local_rotation(), + local_to_world_mtx=t.get_local_to_world_matrix(), + ) + result.setdefault( + t.get_guid(), list[TimestampedTransform]() + ).append(timestamped_transform) return result diff --git a/poetry.lock b/poetry.lock index 10a2ebd..3a67d48 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,16 @@ # This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +[[package]] +name = "cfgv" +version = "3.4.0" +description = "Validate configuration and produce human readable error messages." +optional = false +python-versions = ">=3.8" +files = [ + {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"}, + {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"}, +] + [[package]] name = "click" version = "8.1.7" @@ -39,6 +50,17 @@ files = [ [package.dependencies] protobuf = "*" +[[package]] +name = "distlib" +version = "0.3.8" +description = "Distribution utilities" +optional = false +python-versions = "*" +files = [ + {file = "distlib-0.3.8-py2.py3-none-any.whl", hash = "sha256:034db59a0b96f8ca18035f36290806a9a6e6bd9d1ff91e45a7f172eb17e51784"}, + {file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"}, +] + [[package]] name = "exceptiongroup" version = "1.2.1" @@ -53,6 +75,36 @@ files = [ [package.extras] test = ["pytest (>=6)"] +[[package]] +name = "filelock" +version = "3.15.4" +description = "A platform independent file lock." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "filelock-3.15.4-py3-none-any.whl", hash = "sha256:6ca1fffae96225dab4c6eaf1c4f4f28cd2568d3ec2a44e15a08520504de468e7"}, + {file = "filelock-3.15.4.tar.gz", hash = "sha256:2207938cbc1844345cb01a5a95524dae30f0ce089eba5b00378295a17e3e90cb"}, +] + +[package.extras] +docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"] +typing = ["typing-extensions (>=4.8)"] + +[[package]] +name = "identify" +version = "2.5.36" +description = "File identification library for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "identify-2.5.36-py2.py3-none-any.whl", hash = "sha256:37d93f380f4de590500d9dba7db359d0d3da95ffe7f9de1753faa159e71e7dfa"}, + {file = "identify-2.5.36.tar.gz", hash = "sha256:e5e00f54165f9047fbebeb4a560f9acfb8af4c88232be60a488e9b68d122745d"}, +] + +[package.extras] +license = ["ukkonen"] + [[package]] name = "iniconfig" version = "2.0.0" @@ -114,6 +166,17 @@ docs = ["sphinx (>=1.6.0)", "sphinx-bootstrap-theme"] flake8 = ["flake8"] tests = ["psutil", "pytest (!=3.3.0)", "pytest-cov"] +[[package]] +name = "nodeenv" +version = "1.9.1" +description = "Node.js virtual environment builder" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +files = [ + {file = "nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9"}, + {file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"}, +] + [[package]] name = "numpy" version = "1.26.4" @@ -310,6 +373,22 @@ sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-d test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] xml = ["lxml (>=4.9.2)"] +[[package]] +name = "platformdirs" +version = "4.2.2" +description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." +optional = false +python-versions = ">=3.8" +files = [ + {file = "platformdirs-4.2.2-py3-none-any.whl", hash = "sha256:2d7a1657e36a80ea911db832a8a6ece5ee53d8de21edd5cc5879af6530b1bfee"}, + {file = "platformdirs-4.2.2.tar.gz", hash = "sha256:38b7b51f512eed9e84a22788b4bce1de17c0adb134d6becb09836e37d8654cd3"}, +] + +[package.extras] +docs = ["furo (>=2023.9.10)", "proselint (>=0.13)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] +test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)"] +type = ["mypy (>=1.8)"] + [[package]] name = "pluggy" version = "1.5.0" @@ -325,6 +404,24 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "pre-commit" +version = "3.7.1" +description = "A framework for managing and maintaining multi-language pre-commit hooks." 
+optional = false +python-versions = ">=3.9" +files = [ + {file = "pre_commit-3.7.1-py2.py3-none-any.whl", hash = "sha256:fae36fd1d7ad7d6a5a1c0b0d5adb2ed1a3bda5a21bf6c3e5372073d7a11cd4c5"}, + {file = "pre_commit-3.7.1.tar.gz", hash = "sha256:8ca3ad567bc78a4972a3f1a477e94a79d4597e8140a6e0b651c5e33899c3654a"}, +] + +[package.dependencies] +cfgv = ">=2.0.0" +identify = ">=1.0.0" +nodeenv = ">=0.11.1" +pyyaml = ">=5.1" +virtualenv = ">=20.10.0" + [[package]] name = "protobuf" version = "5.27.1" @@ -392,6 +489,66 @@ files = [ {file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"}, ] +[[package]] +name = "pyyaml" +version = "6.0.1" +description = "YAML parser and emitter for Python" +optional = false +python-versions = ">=3.6" +files = [ + {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, + {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, + {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, + {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, + {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, + {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = 
"PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, + {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"}, + {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"}, + {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, + {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, + {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = 
"sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, + {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, + {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, +] + [[package]] name = "six" version = "1.16.0" @@ -445,7 +602,27 @@ files = [ {file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"}, ] +[[package]] +name = "virtualenv" +version = "20.26.3" +description = "Virtual Python Environment builder" +optional = false +python-versions = ">=3.7" +files = [ + {file = "virtualenv-20.26.3-py3-none-any.whl", hash = "sha256:8cc4a31139e796e9a7de2cd5cf2489de1217193116a8fd42328f1bd65f434589"}, + {file = "virtualenv-20.26.3.tar.gz", hash = "sha256:4c43a2a236279d9ea36a0d76f98d84bd6ca94ac4e0f4a3b9d46d05e10fea542a"}, +] + +[package.dependencies] +distlib = ">=0.3.7,<1" +filelock = ">=3.12.2,<4" +platformdirs = ">=3.9.1,<5" + +[package.extras] +docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] +test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] + [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "62afe993ca87613b19b6253ce713518c998f3ad9b363f1bd29592e4d102dcdbd" +content-hash = "0389fcf30f061f817d308ced1e3155a12d1f6676196c6335bfff3b7ee5bcb70a" diff --git a/pyproject.toml b/pyproject.toml index 900e80a..aa64118 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ optional = true [tool.poetry.group.dev.dependencies] pytest = "^7.4.4" +pre-commit = "^3.7.1" [tool.setuptools.packages.find] exclude = ["tests"] @@ -30,6 +31,10 @@ exclude = ["tests"] [tool.poetry.scripts] plume-python = "plume_python.cli:cli" +[tool.ruff] +line-length = 79 +target-version = "py39" + [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" diff --git a/tests/.gitignore b/tests/.gitignore index 9cf4042..69341fd 100644 --- a/tests/.gitignore +++ b/tests/.gitignore @@ -1 +1 @@ -test_outputs/ \ No newline at end of file 
+test_outputs/
diff --git a/tests/test_compute_world_transform.py b/tests/test_compute_world_transform.py
index 95ac7f0..2ae32fb 100644
--- a/tests/test_compute_world_transform.py
+++ b/tests/test_compute_world_transform.py
@@ -1,11 +1,18 @@
 import pytest
 
-from plume_python.utils.transform import compute_transforms_time_series, compute_transform_time_series
+from plume_python.utils.transform import (
+    compute_transforms_time_series,
+    compute_transform_time_series,
+)
 
 
 def test_compute_single_transform_time_series():
-    transform_time_series = compute_transform_time_series(pytest.record, "4a3f40e37eaf4c0a9d5d88ac993c0ebc")
+    transform_time_series = compute_transform_time_series(
+        pytest.record, "4a3f40e37eaf4c0a9d5d88ac993c0ebc"
+    )
 
 
 def test_compute_all_transforms_time_series():
-    transform_time_series = compute_transforms_time_series(pytest.record, {"4a3f40e37eaf4c0a9d5d88ac993c0ebc"})
+    transform_time_series = compute_transforms_time_series(
+        pytest.record, {"4a3f40e37eaf4c0a9d5d88ac993c0ebc"}
+    )
diff --git a/tests/test_export_xdf.py b/tests/test_export_xdf.py
index 9b18686..4cd4993 100644
--- a/tests/test_export_xdf.py
+++ b/tests/test_export_xdf.py
@@ -7,6 +7,6 @@ def test_export_xdf():
     record = pytest.record
 
     # create directory tests/test_outputs if it does not exist
-    Path('tests/test_outputs').mkdir(parents=True, exist_ok=True)
-    with open('tests/test_outputs/test.xdf', 'wb') as f:
+    Path("tests/test_outputs").mkdir(parents=True, exist_ok=True)
+    with open("tests/test_outputs/test.xdf", "wb") as f:
         export_xdf_from_record(f, record)
diff --git a/tests/test_find_game_object_identifiers_by_name.py b/tests/test_find_game_object_identifiers_by_name.py
index 9dccf8b..467be5b 100644
--- a/tests/test_find_game_object_identifiers_by_name.py
+++ b/tests/test_find_game_object_identifiers_by_name.py
@@ -1,6 +1,9 @@
 import pytest
 
-from plume_python.utils.game_object import find_identifiers_by_name, find_first_identifier_by_name
+from plume_python.utils.game_object import (
+    find_identifiers_by_name,
+    find_first_identifier_by_name,
+)
 
 
 def test_find_first_game_object_identifier():
diff --git a/tests/test_parser.py b/tests/test_parser.py
index eda071f..1ec67f4 100644
--- a/tests/test_parser.py
+++ b/tests/test_parser.py
@@ -22,7 +22,13 @@ def test_parse_record():
 
 def test_simple_filtering():
     record = cast(parser.Record, pytest.record)
-    filtered_lsl = [lsl_sample for lsl_sample in record[lsl_stream_pb2.StreamSample] if
-                    0 <= lsl_sample.timestamp <= 5_000_000_000]
-    filtered_transform_updates = [frame for frame in record[transform_pb2.TransformUpdate] if
-                    0 <= frame.timestamp <= 5_000_000_000]
+    filtered_lsl = [
+        lsl_sample
+        for lsl_sample in record[lsl_stream_pb2.StreamSample]
+        if 0 <= lsl_sample.timestamp <= 5_000_000_000
+    ]
+    filtered_transform_updates = [
+        frame
+        for frame in record[transform_pb2.TransformUpdate]
+        if 0 <= frame.timestamp <= 5_000_000_000
+    ]
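
Note on local usage — a minimal sketch of exercising the lint setup introduced above, assuming the Poetry dev group and Makefile target from this diff; `pre-commit install` is a standard pre-commit command and is not part of the diff itself:

    poetry install --with dev        # install dev dependencies, including pre-commit
    poetry run pre-commit install    # optional: run the hooks automatically on each commit
    make lint                        # runs `pre-commit run --all-files` via Poetry, as the CI lint step does
    make tests                       # existing test target, also invoked by the CI workflow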