-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added VolumetricImagingExtractor to its own file
- Loading branch information
1 parent
a96afc2
commit 6e0e83d
Showing
1 changed file
with
148 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
from typing import Tuple, List, Iterable, Optional | ||
import numpy as np | ||
|
||
from .extraction_tools import ArrayType, DtypeType | ||
from .imagingextractor import ImagingExtractor | ||
|
||
|
||
class VolumetricImagingExtractor(ImagingExtractor): | ||
"""Class to combine multiple ImagingExtractor objects by depth plane.""" | ||
|
||
extractor_name = "VolumetricImaging" | ||
installed = True | ||
installatiuon_mesage = "" | ||
|
||
def __init__(self, imaging_extractors: List[ImagingExtractor]): | ||
"""Initialize a VolumetricImagingExtractor object from a list of ImagingExtractors. | ||
Parameters | ||
---------- | ||
imaging_extractors: list of ImagingExtractor | ||
list of imaging extractor objects | ||
""" | ||
super().__init__() | ||
assert isinstance(imaging_extractors, list), "Enter a list of ImagingExtractor objects as argument" | ||
assert all(isinstance(imaging_extractor, ImagingExtractor) for imaging_extractor in imaging_extractors) | ||
self._check_consistency_between_imaging_extractors(imaging_extractors) | ||
self._imaging_extractors = imaging_extractors | ||
self._num_planes = len(imaging_extractors) | ||
|
||
def _check_consistency_between_imaging_extractors(self, imaging_extractors: List[ImagingExtractor]): | ||
"""Check that essential properties are consistent between extractors so that they can be combined appropriately. | ||
Parameters | ||
---------- | ||
imaging_extractors: list of ImagingExtractor | ||
list of imaging extractor objects | ||
Raises | ||
------ | ||
AssertionError | ||
If any of the properties are not consistent between extractors. | ||
Notes | ||
----- | ||
This method checks the following properties: | ||
- sampling frequency | ||
- image size | ||
- number of channels | ||
- channel names | ||
- data type | ||
- num_frames | ||
""" | ||
properties_to_check = dict( | ||
get_sampling_frequency="The sampling frequency", | ||
get_image_size="The size of a frame", | ||
get_num_channels="The number of channels", | ||
get_channel_names="The name of the channels", | ||
get_dtype="The data type", | ||
get_num_frames="The number of frames", | ||
) | ||
for method, property_message in properties_to_check.items(): | ||
values = [getattr(extractor, method)() for extractor in imaging_extractors] | ||
unique_values = set(tuple(v) if isinstance(v, Iterable) else v for v in values) | ||
assert ( | ||
len(unique_values) == 1 | ||
), f"{property_message} is not consistent over the files (found {unique_values})." | ||
|
||
def get_video(self, start_frame: Optional[int] = None, end_frame: Optional[int] = None) -> np.ndarray: | ||
"""Get the video frames. | ||
Parameters | ||
---------- | ||
start_frame: int, optional | ||
Start frame index (inclusive). | ||
end_frame: int, optional | ||
End frame index (exclusive). | ||
Returns | ||
------- | ||
video: numpy.ndarray | ||
The 3D video frames (num_rows, num_columns, num_planes). | ||
""" | ||
start_frame = start_frame if start_frame is not None else 0 | ||
end_frame = end_frame if end_frame is not None else self.get_num_frames() | ||
|
||
video = np.zeros((end_frame - start_frame, *self.get_image_size()), self.get_dtype()) | ||
for i, imaging_extractor in enumerate(self._imaging_extractors): | ||
video[..., i] = imaging_extractor.get_video(start_frame, end_frame) | ||
return video | ||
|
||
def get_frames(self, frame_idxs: ArrayType) -> np.ndarray: | ||
"""Get specific video frames from indices (not necessarily continuous). | ||
Parameters | ||
---------- | ||
frame_idxs: array-like | ||
Indices of frames to return. | ||
Returns | ||
------- | ||
frames: numpy.ndarray | ||
The 3D video frames (num_rows, num_columns, num_planes). | ||
""" | ||
if isinstance(frame_idxs, int): | ||
frame_idxs = [frame_idxs] | ||
|
||
if not all(np.diff(frame_idxs) == 1): | ||
frames = np.zeros((len(frame_idxs), *self.get_image_size()), self.get_dtype()) | ||
for i, imaging_extractor in enumerate(self._imaging_extractors): | ||
frames[..., i] = imaging_extractor.get_frames(frame_idxs) | ||
else: | ||
return self.get_video(start_frame=frame_idxs[0], end_frame=frame_idxs[-1] + 1) | ||
|
||
def get_image_size(self) -> Tuple: | ||
"""Get the size of a single frame. | ||
Returns | ||
------- | ||
image_size: tuple | ||
The size of a single frame (num_rows, num_columns, num_planes). | ||
""" | ||
image_size = (*self._imaging_extractors[0].get_image_size(), self.get_num_planes()) | ||
return image_size | ||
|
||
def get_num_planes(self) -> int: | ||
"""Get the number of depth planes. | ||
Returns | ||
------- | ||
_num_planes: int | ||
The number of depth planes. | ||
""" | ||
return self._num_planes | ||
|
||
def get_num_frames(self) -> int: | ||
return self._imaging_extractors[0].get_num_frames() | ||
|
||
def get_sampling_frequency(self) -> float: | ||
return self._imaging_extractors[0].get_sampling_frequency() | ||
|
||
def get_channel_names(self) -> list: | ||
return self._imaging_extractors[0].get_channel_names() | ||
|
||
def get_num_channels(self) -> int: | ||
return self._imaging_extractors[0].get_num_channels() | ||
|
||
def get_dtype(self) -> DtypeType: | ||
return self._imaging_extractors[0].get_dtype() |