Update ruff settings
imenelydiaker authored and cjaverliat committed Jun 23, 2024
1 parent f671233 · commit 6533294
Showing 9 changed files with 60 additions and 181 deletions.
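The pattern across the diffs below is uniform: statements that the previous ruff line-length limit had forced onto several wrapped lines are joined back onto single lines, and "# type: ignore" markers are added to imports from untyped third-party packages. The ruff settings file itself is among the nine changed files, but its diff did not render in this view, so the following pyproject.toml excerpt is only a hedged sketch of what such a change could look like: the line-length key is real ruff configuration, but the value is an assumed illustration, not taken from this commit.

    # Hypothetical sketch only; the actual settings diff is not rendered in this commit view.
    [tool.ruff]
    line-length = 120  # assumed value; any limit long enough to keep the joined lines below legal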
24 changes: 5 additions & 19 deletions plume_python/cli.py
@@ -1,6 +1,6 @@
 import os.path
 
-import click
+import click  # type: ignore
 
 from plume_python import parser
 from plume_python.export.xdf_exporter import export_xdf_from_record
@@ -36,17 +36,13 @@ def export_xdf(record_path: str, xdf_output_path: str | None):
         xdf_output_path = record_path.replace(".plm", ".xdf")
 
     if os.path.exists(xdf_output_path):
-        if not click.confirm(
-            f"File '{xdf_output_path}' already exists, do you want to overwrite it?"
-        ):
+        if not click.confirm(f"File '{xdf_output_path}' already exists, do you want to overwrite it?"):
             return
 
     with open(xdf_output_path, "wb") as xdf_output_file:
         record = parser.parse_record_from_file(record_path)
         export_xdf_from_record(xdf_output_file, record)
-        click.echo(
-            "Exported xdf from record: " + record_path + " to " + xdf_output_path
-        )
+        click.echo("Exported xdf from record: " + record_path + " to " + xdf_output_path)
 
 
 @click.command()
@@ -134,24 +130,14 @@ def export_csv(record_path: str, output_dir: str | None, filter: str):
         for sample_type, df in dataframes.items():
             file_path = os.path.join(output_dir, sample_type.__name__ + ".csv")
             df.to_csv(file_path)
-            click.echo(
-                "Exported CSV for sample type: "
-                + sample_type.__name__
-                + " to "
-                + file_path
-            )
+            click.echo("Exported CSV for sample type: " + sample_type.__name__ + " to " + file_path)
     else:
         sample_types = sample_types_from_names(filters)
         for sample_type in sample_types:
             df = samples_to_dataframe(record.get_samples_by_type(sample_type))
             file_path = os.path.join(output_dir, sample_type.__name__ + ".csv")
             df.to_csv(file_path)
-            click.echo(
-                "Exported CSV for sample type: "
-                + sample_type.__name__
-                + " to "
-                + file_path
-            )
+            click.echo("Exported CSV for sample type: " + sample_type.__name__ + " to " + file_path)
 
 
 cli.add_command(export_csv)
53 changes: 12 additions & 41 deletions plume_python/export/xdf_exporter.py
@@ -15,12 +15,7 @@
 
 
 def export_xdf_from_record(output_file: BinaryIO, record: Record):
-    datetime_str = (
-        record.get_metadata()
-        .start_time.ToDatetime()
-        .astimezone()
-        .strftime("%Y-%m-%dT%H:%M:%S%z")
-    )
+    datetime_str = record.get_metadata().start_time.ToDatetime().astimezone().strftime("%Y-%m-%dT%H:%M:%S%z")
     # Add a colon separator to the offset segment
     datetime_str = "{0}:{1}".format(datetime_str[:-2], datetime_str[-2:])
 
@@ -40,9 +35,7 @@ def export_xdf_from_record(output_file: BinaryIO, record: Record):
 
     for lsl_open_stream in record[lsl_stream_pb2.StreamOpen]:
         xml_header = ET.fromstring(lsl_open_stream.payload.xml_header)
-        stream_id = (
-            np.uint64(lsl_open_stream.payload.stream_id) + 1
-        )  # reserve id = 1 for the marker stream
+        stream_id = np.uint64(lsl_open_stream.payload.stream_id) + 1  # reserve id = 1 for the marker stream
         channel_format = xml_header.find("channel_format").text
         stream_channel_format[stream_id] = channel_format
         stream_min_time[stream_id] = None
@@ -66,33 +59,19 @@ def export_xdf_from_record(output_file: BinaryIO, record: Record):
             stream_sample_count[stream_id] += 1
 
             if channel_format == "string":
-                val = np.array(
-                    [x for x in lsl_sample.payload.string_values.value], dtype=np.str_
-                )
+                val = np.array([x for x in lsl_sample.payload.string_values.value], dtype=np.str_)
             elif channel_format == "int8":
-                val = np.array(
-                    [x for x in lsl_sample.payload.int8_values.value], dtype=np.int8
-                )
+                val = np.array([x for x in lsl_sample.payload.int8_values.value], dtype=np.int8)
             elif channel_format == "int16":
-                val = np.array(
-                    [x for x in lsl_sample.payload.int16_values.value], dtype=np.int16
-                )
+                val = np.array([x for x in lsl_sample.payload.int16_values.value], dtype=np.int16)
             elif channel_format == "int32":
-                val = np.array(
-                    [x for x in lsl_sample.payload.int32_values.value], dtype=np.int32
-                )
+                val = np.array([x for x in lsl_sample.payload.int32_values.value], dtype=np.int32)
             elif channel_format == "int64":
-                val = np.array(
-                    [x for x in lsl_sample.payload.int64_values.value], dtype=np.int64
-                )
+                val = np.array([x for x in lsl_sample.payload.int64_values.value], dtype=np.int64)
             elif channel_format == "float32":
-                val = np.array(
-                    [x for x in lsl_sample.payload.float_values.value], dtype=np.float32
-                )
+                val = np.array([x for x in lsl_sample.payload.float_values.value], dtype=np.float32)
             elif channel_format == "double64":
-                val = np.array(
-                    [x for x in lsl_sample.payload.double_values.value], dtype=np.float64
-                )
+                val = np.array([x for x in lsl_sample.payload.double_values.value], dtype=np.float64)
             else:
                 raise ValueError(f"Unsupported channel format: {channel_format}")
 
@@ -101,15 +80,9 @@ def export_xdf_from_record(output_file: BinaryIO, record: Record):
     for marker_sample in record[marker_pb2.Marker]:
         t = marker_sample.timestamp / 1_000_000_000.0  # convert time to seconds
 
-        if (
-            stream_min_time[marker_stream_id] is None
-            or t < stream_min_time[marker_stream_id]
-        ):
+        if stream_min_time[marker_stream_id] is None or t < stream_min_time[marker_stream_id]:
             stream_min_time[marker_stream_id] = t
-        if (
-            stream_max_time[marker_stream_id] is None
-            or t > stream_max_time[marker_stream_id]
-        ):
+        if stream_max_time[marker_stream_id] is None or t > stream_max_time[marker_stream_id]:
            stream_max_time[marker_stream_id] = t
 
         if marker_stream_id not in stream_sample_count:
@@ -155,6 +128,4 @@ def write_marker_stream_header(output_buf, marker_stream_id):
     channel_count_el.text = "1"
     nominal_srate_el.text = "0.0"
     xml = ET.tostring(info_el, encoding=STR_ENCODING, xml_declaration=True)
-    write_stream_header(
-        output_buf, xml, marker_stream_id
-    )  # stream_id = 1 is reserved for the marker stream
+    write_stream_header(output_buf, xml, marker_stream_id)  # stream_id = 1 is reserved for the marker stream
26 changes: 4 additions & 22 deletions plume_python/export/xdf_writer.py
@@ -54,9 +54,7 @@ def write_file_header(output: BinaryIO, version: str, datetime: str):
     write_chunk(output, ChunkTag.FILE_HEADER, xml_str)
 
 
-def write_chunk(
-    output: BinaryIO, chunk_tag: ChunkTag, content: bytes, stream_id: np.uint32 = None
-):
+def write_chunk(output: BinaryIO, chunk_tag: ChunkTag, content: bytes, stream_id: np.uint32 = None):
     if not isinstance(content, bytes):
         raise Exception("Content should be bytes.")
 
@@ -74,9 +72,7 @@ def write_chunk(
     write(output, content)
 
 
-def write_stream_header(
-    output: BinaryIO, xml_header: str | bytes, stream_id: np.uint32 = None
-):
+def write_stream_header(output: BinaryIO, xml_header: str | bytes, stream_id: np.uint32 = None):
     if isinstance(xml_header, str):
         xml_header = bytes(xml_header, encoding=STR_ENCODING)
 
@@ -210,25 +206,11 @@ def write_variable_length_integer(output: BinaryIO, val: np.uint64):
 
 def write_fixed_length_integer(
     output: BinaryIO,
-    val: np.int8
-    | np.int16
-    | np.int32
-    | np.int64
-    | np.uint8
-    | np.uint16
-    | np.uint32
-    | np.uint64,
+    val: np.int8 | np.int16 | np.int32 | np.int64 | np.uint8 | np.uint16 | np.uint32 | np.uint64,
 ):
     if not isinstance(
         val,
-        np.int8
-        | np.int16
-        | np.int32
-        | np.int64
-        | np.uint8
-        | np.uint16
-        | np.uint32
-        | np.uint64,
+        np.int8 | np.int16 | np.int32 | np.int64 | np.uint8 | np.uint16 | np.uint32 | np.uint64,
     ):
         raise Exception("Unsupported data type " + str(type(val)))
 
36 changes: 11 additions & 25 deletions plume_python/parser.py
@@ -2,10 +2,10 @@
 from warnings import warn
 from tqdm import tqdm
 
-from delimited_protobuf import read as _read_delimited
-from google.protobuf import descriptor_pool as _descriptor_pool
-from google.protobuf.message import Message
-from google.protobuf.message_factory import GetMessageClass
+from delimited_protobuf import read as _read_delimited  # type: ignore
+from google.protobuf import descriptor_pool as _descriptor_pool  # type: ignore
+from google.protobuf.message import Message  # type: ignore
+from google.protobuf.message_factory import GetMessageClass  # type: ignore
 
 from . import file_reader
 from .record import Record, Sample, FrameDataSample, FrameInfo
@@ -15,9 +15,7 @@
 T = TypeVar("T", bound=Message)
 
 
-def unpack_any(
-    any_payload: Any, descriptor_pool: _descriptor_pool.DescriptorPool
-) -> Optional[Message]:
+def unpack_any(any_payload: Any, descriptor_pool: _descriptor_pool.DescriptorPool) -> Optional[Message]:
     """Unpacks an Any message into its original type using the provided descriptor pool."""
     try:
         descriptor = descriptor_pool.FindMessageTypeByName(any_payload.TypeName())
@@ -44,27 +42,19 @@ def parse_record_from_stream(data_stream: BinaryIO) -> Record:
         unpacked_payload = unpack_any(packed_sample.payload, default_descriptor_pool)
         if unpacked_payload is None:
             continue
-        timestamp = (
-            packed_sample.timestamp if packed_sample.HasField("timestamp") else None
-        )
+        timestamp = packed_sample.timestamp if packed_sample.HasField("timestamp") else None
 
         if timestamp is not None:
-            last_timestamp = (
-                timestamp if last_timestamp is None else max(last_timestamp, timestamp)
-            )
+            last_timestamp = timestamp if last_timestamp is None else max(last_timestamp, timestamp)
             if first_timestamp is None:
                 first_timestamp = timestamp
 
         if isinstance(unpacked_payload, frame_pb2.Frame):
             frame = cast(frame_pb2.Frame, unpacked_payload)
-            frames_info.append(
-                FrameInfo(frame_number=frame.frame_number, timestamp=timestamp)
-            )
+            frames_info.append(FrameInfo(frame_number=frame.frame_number, timestamp=timestamp))
 
             for packed_frame_data in frame.data:
-                unpacked_frame_data = unpack_any(
-                    packed_frame_data, default_descriptor_pool
-                )
+                unpacked_frame_data = unpack_any(packed_frame_data, default_descriptor_pool)
                 if unpacked_frame_data is None:
                     continue
                 frame_data_sample = FrameDataSample(
@@ -73,9 +63,7 @@ def parse_record_from_stream(data_stream: BinaryIO) -> Record:
                     payload=unpacked_frame_data,
                 )
                 payload_type = type(unpacked_frame_data)
-                samples_by_type.setdefault(
-                    payload_type, list[FrameDataSample[T]]()
-                ).append(frame_data_sample)
+                samples_by_type.setdefault(payload_type, list[FrameDataSample[T]]()).append(frame_data_sample)
         else:
             sample = Sample(timestamp=timestamp, payload=unpacked_payload)
             payload_type = type(unpacked_payload)
@@ -101,9 +89,7 @@ def parse_packed_samples_from_stream(
     """Parses packed samples from a binary stream and returns a list of packed samples."""
     packed_samples = []
 
-    pbar = tqdm(
-        desc="Parsing packed samples", unit="bytes", total=len(data_stream.getbuffer())
-    )
+    pbar = tqdm(desc="Parsing packed samples", unit="bytes", total=len(data_stream.getbuffer()))
     while data_stream.tell() < len(data_stream.getbuffer()):
         packed_sample = _read_delimited(data_stream, packed_sample_pb2.PackedSample)
 
11 changes: 3 additions & 8 deletions plume_python/record.py
@@ -2,7 +2,7 @@
 from typing import TypeVar, Generic, Optional, Type
 from .samples.record_pb2 import RecordMetadata
 
-from google.protobuf.message import Message
+from google.protobuf.message import Message  # type: ignore
 
 
 T = TypeVar("T", bound=Message)
@@ -52,15 +52,10 @@ def get_samples_by_type(self, payload_type: Type[T]) -> list[Sample[T]]:
 
     def get_timeless_samples(self) -> list[Sample[T]]:
         return [
-            sample
-            for samples in self.samples_by_type.values()
-            for sample in samples
-            if not sample.is_timestamped()
+            sample for samples in self.samples_by_type.values() for sample in samples if not sample.is_timestamped()
         ]
 
-    def get_samples_in_time_range(
-        self, start: Optional[int], end: Optional[int]
-    ) -> dict[Type[T], list[Sample[T]]]:
+    def get_samples_in_time_range(self, start: Optional[int], end: Optional[int]) -> dict[Type[T], list[Sample[T]]]:
         samples_in_time_range = {}
 
         for payload_type, samples in self.samples_by_type.items():
4 changes: 1 addition & 3 deletions plume_python/utils/dataframe.py
@@ -63,9 +63,7 @@ def samples_to_dataframe(samples: list[Sample[T]]) -> pd.DataFrame:
     for sample in samples:
         sample_payload_fields_value = MessageToDict(sample.payload, True)
         if sample.is_timestamped():
-            sample_data.append(
-                {"timestamp": sample.timestamp} | sample_payload_fields_value
-            )
+            sample_data.append({"timestamp": sample.timestamp} | sample_payload_fields_value)
         else:
             sample_data.append(sample_payload_fields_value)
 
13 changes: 3 additions & 10 deletions plume_python/utils/game_object.py
@@ -24,9 +24,7 @@ def find_first_name_by_guid(record: Record, guid: str) -> Optional[str]:
     return None
 
 
-def find_identifier_by_game_object_id(
-    record: Record, game_object_id: str
-) -> Optional[GameObjectIdentifier]:
+def find_identifier_by_game_object_id(record: Record, game_object_id: str) -> Optional[GameObjectIdentifier]:
     for go_update_sample in record[GameObjectUpdate]:
         go_update = go_update_sample.payload
         if go_update.id.game_object_id == game_object_id:
@@ -41,18 +39,13 @@ def find_identifiers_by_name(record: Record, name: str) -> list[GameObjectIdentifier]:
     for go_update_sample in record[GameObjectUpdate]:
         go_update = go_update_sample.payload
         if go_update.HasField("name"):
-            if (
-                name == go_update.name
-                and go_update.id.game_object_id not in known_guids
-            ):
+            if name == go_update.name and go_update.id.game_object_id not in known_guids:
                 identifiers.append(go_update.id)
                 known_guids.add(go_update.id.game_object_id)
     return identifiers
 
 
-def find_first_identifier_by_name(
-    record: Record, name: str
-) -> Optional[GameObjectIdentifier]:
+def find_first_identifier_by_name(record: Record, name: str) -> Optional[GameObjectIdentifier]:
     for go_update_sample in record[GameObjectUpdate]:
         go_update = go_update_sample.payload
         if go_update.HasField("name"):
[Diffs for the remaining two changed files did not load in this view.]
