Skip to content

Commit

Permalink
added VolumetricImagingExtractor to its own file
Browse files Browse the repository at this point in the history
  • Loading branch information
pauladkisson committed Sep 20, 2023
1 parent a96afc2 commit 6e0e83d
Showing 1 changed file with 148 additions and 0 deletions.
148 changes: 148 additions & 0 deletions src/roiextractors/volumetricimagingextractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
from typing import Tuple, List, Iterable, Optional
import numpy as np

from .extraction_tools import ArrayType, DtypeType
from .imagingextractor import ImagingExtractor


class VolumetricImagingExtractor(ImagingExtractor):
"""Class to combine multiple ImagingExtractor objects by depth plane."""

extractor_name = "VolumetricImaging"
installed = True
installatiuon_mesage = ""

def __init__(self, imaging_extractors: List[ImagingExtractor]):
"""Initialize a VolumetricImagingExtractor object from a list of ImagingExtractors.
Parameters
----------
imaging_extractors: list of ImagingExtractor
list of imaging extractor objects
"""
super().__init__()
assert isinstance(imaging_extractors, list), "Enter a list of ImagingExtractor objects as argument"
assert all(isinstance(imaging_extractor, ImagingExtractor) for imaging_extractor in imaging_extractors)
self._check_consistency_between_imaging_extractors(imaging_extractors)
self._imaging_extractors = imaging_extractors
self._num_planes = len(imaging_extractors)

def _check_consistency_between_imaging_extractors(self, imaging_extractors: List[ImagingExtractor]):
"""Check that essential properties are consistent between extractors so that they can be combined appropriately.
Parameters
----------
imaging_extractors: list of ImagingExtractor
list of imaging extractor objects
Raises
------
AssertionError
If any of the properties are not consistent between extractors.
Notes
-----
This method checks the following properties:
- sampling frequency
- image size
- number of channels
- channel names
- data type
- num_frames
"""
properties_to_check = dict(
get_sampling_frequency="The sampling frequency",
get_image_size="The size of a frame",
get_num_channels="The number of channels",
get_channel_names="The name of the channels",
get_dtype="The data type",
get_num_frames="The number of frames",
)
for method, property_message in properties_to_check.items():
values = [getattr(extractor, method)() for extractor in imaging_extractors]
unique_values = set(tuple(v) if isinstance(v, Iterable) else v for v in values)
assert (
len(unique_values) == 1
), f"{property_message} is not consistent over the files (found {unique_values})."

def get_video(self, start_frame: Optional[int] = None, end_frame: Optional[int] = None) -> np.ndarray:
"""Get the video frames.
Parameters
----------
start_frame: int, optional
Start frame index (inclusive).
end_frame: int, optional
End frame index (exclusive).
Returns
-------
video: numpy.ndarray
The 3D video frames (num_rows, num_columns, num_planes).
"""
start_frame = start_frame if start_frame is not None else 0
end_frame = end_frame if end_frame is not None else self.get_num_frames()

video = np.zeros((end_frame - start_frame, *self.get_image_size()), self.get_dtype())
for i, imaging_extractor in enumerate(self._imaging_extractors):
video[..., i] = imaging_extractor.get_video(start_frame, end_frame)
return video

def get_frames(self, frame_idxs: ArrayType) -> np.ndarray:
"""Get specific video frames from indices (not necessarily continuous).
Parameters
----------
frame_idxs: array-like
Indices of frames to return.
Returns
-------
frames: numpy.ndarray
The 3D video frames (num_rows, num_columns, num_planes).
"""
if isinstance(frame_idxs, int):
frame_idxs = [frame_idxs]

if not all(np.diff(frame_idxs) == 1):
frames = np.zeros((len(frame_idxs), *self.get_image_size()), self.get_dtype())
for i, imaging_extractor in enumerate(self._imaging_extractors):
frames[..., i] = imaging_extractor.get_frames(frame_idxs)
else:
return self.get_video(start_frame=frame_idxs[0], end_frame=frame_idxs[-1] + 1)

def get_image_size(self) -> Tuple:
"""Get the size of a single frame.
Returns
-------
image_size: tuple
The size of a single frame (num_rows, num_columns, num_planes).
"""
image_size = (*self._imaging_extractors[0].get_image_size(), self.get_num_planes())
return image_size

def get_num_planes(self) -> int:
"""Get the number of depth planes.
Returns
-------
_num_planes: int
The number of depth planes.
"""
return self._num_planes

def get_num_frames(self) -> int:
return self._imaging_extractors[0].get_num_frames()

def get_sampling_frequency(self) -> float:
return self._imaging_extractors[0].get_sampling_frequency()

def get_channel_names(self) -> list:
return self._imaging_extractors[0].get_channel_names()

def get_num_channels(self) -> int:
return self._imaging_extractors[0].get_num_channels()

def get_dtype(self) -> DtypeType:
return self._imaging_extractors[0].get_dtype()

0 comments on commit 6e0e83d

Please sign in to comment.