diff --git a/.github/workflows/mkdocs-release-caller.yml b/.github/workflows/mkdocs-release-caller.yml new file mode 100644 index 0000000..14e6f4f --- /dev/null +++ b/.github/workflows/mkdocs-release-caller.yml @@ -0,0 +1,9 @@ +name: mkdocs-release +on: + workflow_dispatch: + +jobs: + mkdocs_release: + uses: datajoint/.github/.github/workflows/mkdocs_release.yaml@main + permissions: + contents: write \ No newline at end of file diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml new file mode 100644 index 0000000..954c84a --- /dev/null +++ b/.github/workflows/release.yaml @@ -0,0 +1,10 @@ +name: Release +on: + workflow_dispatch: +jobs: + make_github_release: + uses: datajoint/.github/.github/workflows/make_github_release.yaml@main + mkdocs_release: + uses: datajoint/.github/.github/workflows/mkdocs_release.yaml@main + permissions: + contents: write \ No newline at end of file diff --git a/.github/workflows/semantic-release-caller.yml b/.github/workflows/semantic-release-caller.yml new file mode 100644 index 0000000..2aa3cd1 --- /dev/null +++ b/.github/workflows/semantic-release-caller.yml @@ -0,0 +1,7 @@ +name: semantic-release +on: + workflow_dispatch: + +jobs: + call_semantic_release: + uses: datajoint/.github/.github/workflows/semantic-release.yaml@main \ No newline at end of file diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml new file mode 100644 index 0000000..bd23395 --- /dev/null +++ b/.github/workflows/test.yaml @@ -0,0 +1,35 @@ +name: Test +on: + push: + pull_request: + workflow_dispatch: + schedule: + - cron: "0 8 * * 1" +jobs: + devcontainer-build: + uses: datajoint/.github/.github/workflows/devcontainer-build.yaml@main + tests: + runs-on: ubuntu-latest + strategy: + matrix: + py_ver: ["3.9", "3.10"] + mysql_ver: ["8.0", "5.7"] + include: + - py_ver: "3.8" + mysql_ver: "5.7" + - py_ver: "3.7" + mysql_ver: "5.7" + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{matrix.py_ver}} + uses: actions/setup-python@v4 + with: + python-version: ${{matrix.py_ver}} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install flake8 "black[jupyter]" + - name: Run style tests + run: | + python_version=${{matrix.py_ver}} + black element_interface --check --verbose --target-version py${python_version//.} \ No newline at end of file diff --git a/.github/workflows/u24_element_before_release.yaml b/.github/workflows/u24_element_before_release.yaml deleted file mode 100644 index 692cf82..0000000 --- a/.github/workflows/u24_element_before_release.yaml +++ /dev/null @@ -1,17 +0,0 @@ -name: u24_element_before_release -on: - pull_request: - push: - branches: - - '**' - tags-ignore: - - '**' - workflow_dispatch: -jobs: - call_context_check: - uses: dj-sciops/djsciops-cicd/.github/workflows/context_check.yaml@main - call_u24_elements_build_alpine: - uses: dj-sciops/djsciops-cicd/.github/workflows/u24_element_build.yaml@main - with: - py_ver: 3.9 - image: djbase diff --git a/.github/workflows/u24_element_release_call.yaml b/.github/workflows/u24_element_release_call.yaml deleted file mode 100644 index 4324cca..0000000 --- a/.github/workflows/u24_element_release_call.yaml +++ /dev/null @@ -1,28 +0,0 @@ -name: u24_element_release_call -on: - workflow_run: - workflows: ["u24_element_tag_to_release"] - types: - - completed -jobs: - call_context_check: - uses: dj-sciops/djsciops-cicd/.github/workflows/context_check.yaml@main - test_call_u24_elements_release_alpine: - if: >- - github.event.workflow_run.conclusion == 
'success' && ( contains(github.event.workflow_run.head_branch, 'test') || (github.event.workflow_run.event == 'pull_request')) - uses: dj-sciops/djsciops-cicd/.github/workflows/u24_element_release.yaml@main - with: - py_ver: 3.9 - twine_repo: testpypi - secrets: - TWINE_USERNAME: ${{secrets.TWINE_TEST_USERNAME}} - TWINE_PASSWORD: ${{secrets.TWINE_TEST_PASSWORD}} - call_u24_elements_release_alpine: - if: >- - github.event.workflow_run.conclusion == 'success' && github.repository_owner == 'datajoint' && !contains(github.event.workflow_run.head_branch, 'test') - uses: dj-sciops/djsciops-cicd/.github/workflows/u24_element_release.yaml@main - with: - py_ver: 3.9 - secrets: - TWINE_USERNAME: ${{secrets.TWINE_USERNAME}} - TWINE_PASSWORD: ${{secrets.TWINE_PASSWORD}} diff --git a/.github/workflows/u24_element_tag_to_release.yaml b/.github/workflows/u24_element_tag_to_release.yaml deleted file mode 100644 index 57334e9..0000000 --- a/.github/workflows/u24_element_tag_to_release.yaml +++ /dev/null @@ -1,14 +0,0 @@ -name: u24_element_tag_to_release -on: - push: - tags: - - '*.*.*' - - 'test*.*.*' -jobs: - call_context_check: - uses: dj-sciops/djsciops-cicd/.github/workflows/context_check.yaml@main - call_u24_elements_build_alpine: - uses: dj-sciops/djsciops-cicd/.github/workflows/u24_element_build.yaml@main - with: - py_ver: 3.9 - image: djbase diff --git a/CHANGELOG.md b/CHANGELOG.md index b7dfd0b..815952c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,12 +3,18 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention. +## [0.7.0] - 2024-08-09 + ++ Add - `memoized_result` decorator to cache function results ++ Update - `prairie_view_loader.py` to create big tiff files from `.ome.tif` files ++ Update - `run_caiman.py` to run latest version of CaImAn ++ Update - `caiman_loader.py` to process output of latest version of CaImAn ++ Fix - general fixes and improvements ## [0.6.1] - 2023-08-02 + Update DANDI upload funtionality to improve useability - ## [0.6.0] - 2023-07-26 + Update - `prairieviewreader.py` -> `prairie_view_loader.py` @@ -83,6 +89,8 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and + Add - Readers for: `ScanImage`, `Suite2p`, `CaImAn`. + +[0.7.0]: https://github.com/datajoint/element-interface/releases/tag/0.7.0 [0.6.0]: https://github.com/datajoint/element-interface/releases/tag/0.6.0 [0.5.4]: https://github.com/datajoint/element-interface/releases/tag/0.5.4 [0.5.3]: https://github.com/datajoint/element-interface/releases/tag/0.5.3 diff --git a/LICENSE b/LICENSE index d394fe3..6872305 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2021 DataJoint NEURO +Copyright (c) 2024 DataJoint Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/docs/src/concepts.md b/docs/src/concepts.md index c7a592c..0ae80c2 100644 --- a/docs/src/concepts.md +++ b/docs/src/concepts.md @@ -24,6 +24,12 @@ how to use various Elements. `utils.str_to_bool` converts a set of strings to boolean True or False. This is implemented as the equivalent item in Python's `distutils` which will be removed in future versions. +`utils.memoized_result` is a decorator that caches the result of a function call based + on input parameters and the state of the output. 
If the function is called with the same + parameters and the output files in the directory remain unchanged, it returns the + cached results; otherwise, it executes the function and caches the new results along + with metadata. + ### Suite2p This Element provides functions to independently run Suite2p's motion correction, @@ -46,13 +52,15 @@ Requirements: ### PrairieView Reader -This Element provides a function to read the PrairieView Scanner's metadata file. The -PrairieView software generates one `.ome.tif` imaging file per frame acquired. The -metadata for all frames is contained in one `.xml` file. This function locates the -`.xml` file and generates a dictionary necessary to populate the DataJoint ScanInfo and -Field tables. PrairieView works with resonance scanners with a single field, does not -support bidirectional x and y scanning, and the `.xml` file does not contain ROI -information. +This Element provides a `PrairieViewMeta` class to handle different types of output from + the PrairieView Scanner. The PrairieView software either generates one `.ome.tif` + imaging file per frame acquired or multi-page `.ome.tif` files. The metadata for all + frames is contained in one `.xml` file. This class contains methods that locate the + `.xml` file and generate a dictionary necessary to populate the DataJoint ScanInfo and + Field tables. The class also contains methods to create a big tiff file from the + individual `.ome.tif` files. PrairieView works with resonance scanners with a single + field, does not support bidirectional x and y scanning, and the `.xml` file does not + contain ROI information. ## Element Architecture diff --git a/docs/src/index.md b/docs/src/index.md index c3eea7b..52b39f9 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -16,4 +16,6 @@ a number of other Elements. - Data ingestion, see [`ingest_csv_to_table` API](./api/element_interface/utils/#element_interface.utils.ingest_csv_to_table) +- Code execution, see [`memoized_result` API](./api/element_interface/utils/#element_interface.utils.memoized_result) + Visit the [Concepts page](./concepts.md) for more information on these tools. diff --git a/element_interface/caiman_loader.py b/element_interface/caiman_loader.py index 3726afd..7d1e06f 100644 --- a/element_interface/caiman_loader.py +++ b/element_interface/caiman_loader.py @@ -1,7 +1,7 @@ import os import pathlib from datetime import datetime - +import re import caiman as cm import h5py import numpy as np @@ -18,6 +18,314 @@ class CaImAn: + """ + Loader class for CaImAn analysis results + A top level aggregator of multiple set of CaImAn results (e.g. 
multi-plane analysis) + Calling _CaImAn (see below) under the hood + """ + + def __init__(self, caiman_dir: str): + """Initialize CaImAn loader class + + Args: + caiman_dir (str): string, absolute file path to CaIman directory + + Raises: + FileNotFoundError: No CaImAn analysis output file found + FileNotFoundError: No CaImAn analysis output found, missing required fields + """ + # ---- Search and verify CaImAn output file exists ---- + caiman_dir = pathlib.Path(caiman_dir) + if not caiman_dir.exists(): + raise FileNotFoundError("CaImAn directory not found: {}".format(caiman_dir)) + + caiman_subdirs = [] + for fp in caiman_dir.rglob("*.hdf5"): + with h5py.File(fp, "r") as h5f: + if all(s in h5f for s in _required_hdf5_fields): + caiman_subdirs.append(fp.parent) + + if not caiman_subdirs: + raise FileNotFoundError( + "No CaImAn analysis output file found at {}" + " containg all required fields ({})".format( + caiman_dir, _required_hdf5_fields + ) + ) + + # Extract CaImAn results from all planes, sorted by plane index + _planes_caiman = {} + for idx, caiman_subdir in enumerate(sorted(caiman_subdirs)): + pln_cm = _CaImAn(caiman_subdir.as_posix()) + pln_idx_match = re.search(r"pln(\d+)_.*", caiman_subdir.stem) + pln_idx = pln_idx_match.groups()[0] if pln_idx_match else idx + pln_cm.plane_idx = pln_idx + _planes_caiman[pln_idx] = pln_cm + sorted_pln_ind = sorted(list(_planes_caiman.keys())) + self.planes = {k: _planes_caiman[k] for k in sorted_pln_ind} + + self.creation_time = min( + [p.creation_time for p in self.planes.values()] + ) # ealiest file creation time + self.curation_time = max( + [p.curation_time for p in self.planes.values()] + ) # most recent curation time + + # is this 3D CaImAn analyis or multiple 2D per-plane analysis + if len(self.planes) > 1: + # if more than one set of caiman result, likely to be multiple 2D per-plane + # assert that the "is3D" value are all False for each of the caiman result + assert all(p.params.motion["is3D"] is False for p in self.planes.values()) + self.is3D = False + self.is_multiplane = True + else: + self.is3D = list(self.planes.values())[0].params.motion["is3D"] + self.is_multiplane = False + + if self.is_multiplane and self.is3D: + raise NotImplementedError( + f"Unable to load CaImAn results mixed between 3D and multi-plane analysis" + ) + + self._motion_correction = None + self._masks = None + self._ref_image = None + self._mean_image = None + self._max_proj_image = None + self._correlation_map = None + + @property + def is_pw_rigid(self): + pw_rigid = set(p.params.motion["pw_rigid"] for p in self.planes.values()) + assert ( + len(pw_rigid) == 1 + ), f"Unable to load CaImAn results mixed between rigid and pw_rigid motion correction" + return pw_rigid.pop() + + @property + def motion_correction(self): + if self._motion_correction is None: + self._motion_correction = ( + self.extract_pw_rigid_mc() + if self.is_pw_rigid + else self.extract_rigid_mc() + ) + return self._motion_correction + + def extract_rigid_mc(self): + # -- rigid motion correction -- + rigid_correction = {} + for pln_idx, (plane, pln_cm) in enumerate(self.planes.items()): + if pln_idx == 0: + rigid_correction = { + "x_shifts": pln_cm.motion_correction["shifts_rig"][:, 0], + "y_shifts": pln_cm.motion_correction["shifts_rig"][:, 1], + } + rigid_correction["x_std"] = np.nanstd( + rigid_correction["x_shifts"].flatten() + ) + rigid_correction["y_std"] = np.nanstd( + rigid_correction["y_shifts"].flatten() + ) + else: + rigid_correction["x_shifts"] = np.vstack( + [ + 
rigid_correction["x_shifts"], + pln_cm.motion_correction["shifts_rig"][:, 0], + ] + ) + rigid_correction["x_std"] = np.nanstd( + rigid_correction["x_shifts"].flatten() + ) + rigid_correction["y_shifts"] = np.vstack( + [ + rigid_correction["y_shifts"], + pln_cm.motion_correction["shifts_rig"][:, 1], + ] + ) + rigid_correction["y_std"] = np.nanstd( + rigid_correction["y_shifts"].flatten() + ) + + if not self.is_multiplane: + pln_cm = list(self.planes.values())[0] + rigid_correction["z_shifts"] = ( + pln_cm.motion_correction["shifts_rig"][:, 2] + if self.is3D + else np.full_like(rigid_correction["x_shifts"], 0) + ) + rigid_correction["z_std"] = ( + np.nanstd(pln_cm.motion_correction["shifts_rig"][:, 2]) + if self.is3D + else np.nan + ) + else: + rigid_correction["z_shifts"] = np.full_like(rigid_correction["x_shifts"], 0) + rigid_correction["z_std"] = np.nan + + rigid_correction["outlier_frames"] = None + + return rigid_correction + + def extract_pw_rigid_mc(self): + # -- piece-wise rigid motion correction -- + nonrigid_correction, nonrigid_blocks = {} + for pln_idx, (plane, pln_cm) in enumerate(self.planes.items()): + block_count = len(nonrigid_blocks) + if pln_idx == 0: + nonrigid_correction = { + "block_height": ( + pln_cm.params.motion["strides"][0] + + pln_cm.params.motion["overlaps"][0] + ), + "block_width": ( + pln_cm.params.motion["strides"][1] + + pln_cm.params.motion["overlaps"][1] + ), + "block_depth": 1, + "block_count_x": len( + set(pln_cm.motion_correction["coord_shifts_els"][:, 0]) + ), + "block_count_y": len( + set(pln_cm.motion_correction["coord_shifts_els"][:, 2]) + ), + "block_count_z": len(self.planes), + "outlier_frames": None, + } + for b_id in range(len(pln_cm.motion_correction["x_shifts_els"][0, :])): + b_id += block_count + nonrigid_blocks[b_id] = { + "block_id": b_id, + "block_x": np.arange( + *pln_cm.motion_correction["coord_shifts_els"][b_id, 0:2] + ), + "block_y": np.arange( + *pln_cm.motion_correction["coord_shifts_els"][b_id, 2:4] + ), + "block_z": ( + np.arange( + *pln_cm.motion_correction["coord_shifts_els"][b_id, 4:6] + ) + if self.is3D + else np.full_like( + np.arange( + *pln_cm.motion_correction["coord_shifts_els"][b_id, 0:2] + ), + pln_idx, + ) + ), + "x_shifts": pln_cm.motion_correction["x_shifts_els"][:, b_id], + "y_shifts": pln_cm.motion_correction["y_shifts_els"][:, b_id], + "z_shifts": ( + pln_cm.motion_correction["z_shifts_els"][:, b_id] + if self.is3D + else np.full_like( + pln_cm.motion_correction["x_shifts_els"][:, b_id], + 0, + ) + ), + "x_std": np.nanstd( + pln_cm.motion_correction["x_shifts_els"][:, b_id] + ), + "y_std": np.nanstd( + pln_cm.motion_correction["y_shifts_els"][:, b_id] + ), + "z_std": ( + np.nanstd(pln_cm.motion_correction["z_shifts_els"][:, b_id]) + if self.is3D + else np.nan + ), + } + + if not self.is_multiplane and self.is3D: + pln_cm = list(self.planes.values())[0] + nonrigid_correction["block_depth"] = ( + pln_cm.params.motion["strides"][2] + pln_cm.params.motion["overlaps"][2] + ) + nonrigid_correction["block_count_z"] = len( + set(pln_cm.motion_correction["coord_shifts_els"][:, 4]) + ) + + return nonrigid_correction, nonrigid_blocks + + @property + def masks(self): + if self._masks is None: + all_masks = [] + for pln_idx, pln_cm in sorted(self.planes.items()): + mask_count = len(all_masks) # increment mask id from all "plane" + all_masks.extend( + [ + { + **m, + "mask_id": m["mask_id"] + mask_count, + "orig_mask_id": m["mask_id"], + "accepted": ( + m["mask_id"] in pln_cm.cnmf.estimates.idx_components + if 
pln_cm.cnmf.estimates.idx_components is not None + else False + ), + } + for m in pln_cm.masks + ] + ) + + self._masks = all_masks + return self._masks + + @property + def alignment_channel(self): + return 0 # hard-code to channel index 0 + + @property + def segmentation_channel(self): + return 0 # hard-code to channel index 0 + + # -- image property -- + + def _get_image(self, img_type): + if not self.is_multiplane: + pln_cm = list(self.planes.values())[0] + img_ = ( + pln_cm.motion_correction[img_type].transpose() + if self.is3D + else pln_cm.motion_correction[img_type][...][..., np.newaxis] + ) + else: + img_ = np.dstack( + [ + pln_cm.motion_correction[img_type][...] + for pln_cm in self.planes.values() + ] + ) + return img_ + + @property + def ref_image(self): + if self._ref_image is None: + self._ref_image = self._get_image("reference_image") + return self._ref_image + + @property + def mean_image(self): + if self._mean_image is None: + self._mean_image = self._get_image("average_image") + return self._mean_image + + @property + def max_proj_image(self): + if self._max_proj_image is None: + self._max_proj_image = self._get_image("max_image") + return self._max_proj_image + + @property + def correlation_map(self): + if self._correlation_map is None: + self._correlation_map = self._get_image("correlation_image") + return self._correlation_map + + +class _CaImAn: """Parse the CaImAn output file [CaImAn results doc](https://caiman.readthedocs.io/en/master/Getting_Started.html#result-variables-for-2p-batch-analysis) @@ -54,6 +362,7 @@ class CaImAn: motion_correction: h5f "motion_correction" property params: cnmf.params segmentation_channel: hard-coded to 0 + plane_idx: N/A if `is3D` else hard-coded to 0 """ def __init__(self, caiman_dir: str): @@ -89,13 +398,20 @@ def __init__(self, caiman_dir: str): self.params = self.cnmf.params self.h5f = h5py.File(self.caiman_fp, "r") - self.motion_correction = self.h5f["motion_correction"] + self.plane_idx = None if self.params.motion["is3D"] else 0 + self._motion_correction = None self._masks = None # ---- Metainfo ---- self.creation_time = datetime.fromtimestamp(os.stat(self.caiman_fp).st_ctime) self.curation_time = datetime.fromtimestamp(os.stat(self.caiman_fp).st_ctime) + @property + def motion_correction(self): + if self._motion_correction is None: + self._motion_correction = self.h5f["motion_correction"] + return self._motion_correction + @property def masks(self): if self._masks is None: @@ -139,7 +455,7 @@ def extract_masks(self) -> dict: else: xpix, ypix = np.unravel_index(ind, self.cnmf.dims, order="F") center_x, center_y = comp_contour["CoM"].astype(int) - center_z = 0 + center_z = self.plane_idx zpix = np.full(len(weights), center_z) masks.append( @@ -161,7 +477,7 @@ def extract_masks(self) -> dict: return masks -def _process_scanimage_tiff(scan_filenames, output_dir="./"): +def _process_scanimage_tiff(scan_filenames, output_dir="./", split_depths=False): """ Read ScanImage TIFF - reshape into volumetric data based on scanning depths/channels Save new TIFF files for each channel - with shape (frame x height x width x depth) @@ -216,7 +532,12 @@ def _process_scanimage_tiff(scan_filenames, output_dir="./"): imsave(save_fp.as_posix(), chn_vol) -def _save_mc(mc, caiman_fp: str, is3D: bool): +def _save_mc( + mc, + caiman_fp: str, + is3D: bool, + summary_images: dict = None, +): """Save motion correction to hdf5 output Run these commands after the CaImAn analysis has completed. 
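A minimal usage sketch of the new multi-plane loader class defined above; the results directory and per-plane folder layout are hypothetical, and only attributes introduced in this patch are touched:

from element_interface.caiman_loader import CaImAn

# hypothetical directory containing one CaImAn .hdf5 result per plane (e.g. pln0_*/, pln1_*/ subfolders)
caiman_results = CaImAn("/data/scan01/caiman")

print(caiman_results.is_multiplane, caiman_results.is3D)
print(len(caiman_results.planes))        # per-plane _CaImAn loaders, keyed by plane index
masks = caiman_results.masks             # mask_id offset across planes, orig_mask_id preserved
mean_img = caiman_results.mean_image     # summary image stacked as (height x width x planes)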
@@ -229,21 +550,13 @@ def _save_mc(mc, caiman_fp: str, is3D: bool): shifts_rig : Rigid transformation x and y shifts per frame x_shifts_els : Non rigid transformation x shifts per frame per block y_shifts_els : Non rigid transformation y shifts per frame per block - caiman_fp (str): CaImAn output (*.hdf5) file path + caiman_fp (str): CaImAn output (*.hdf5) file path - append if exists, else create new one + is3D (bool): the data is 3D + summary_images(dict): dict of summary images (average_image, max_image, correlation_image) - if None, will be computed, if provided as empty dict, will not be computed """ - - # Load motion corrected mmap image - mc_image = cm.load(mc.mmap_file, is3D=is3D) - - # Compute motion corrected summary images - average_image = np.mean(mc_image, axis=0) - max_image = np.max(mc_image, axis=0) - - # Compute motion corrected correlation image - correlation_image = cm.local_correlations( - mc_image.transpose((1, 2, 3, 0) if is3D else (1, 2, 0)) - ) - correlation_image[np.isnan(correlation_image)] = 0 + Yr, dims, T = cm.mmapping.load_memmap(mc.mmap_file[0]) + # Load the first frame of the movie + mc_image = np.reshape(Yr[: np.product(dims), :1], [1] + list(dims), order="F") # Compute mc.coord_shifts_els grid = [] @@ -275,7 +588,8 @@ def _save_mc(mc, caiman_fp: str, is3D: bool): ) # Open hdf5 file and create 'motion_correction' group - h5f = h5py.File(caiman_fp, "r+") + caiman_fp = pathlib.Path(caiman_fp) + h5f = h5py.File(caiman_fp, "r+" if caiman_fp.exists() else "w") h5g = h5f.require_group("motion_correction") # Write motion correction shifts and motion corrected summary images to hdf5 file @@ -307,7 +621,7 @@ def _save_mc(mc, caiman_fp: str, is3D: bool): # For CaImAn, reference image is still a 2D array even for the case of 3D # Assume that the same ref image is used for all the planes reference_image = ( - np.tile(mc.total_template_els, (1, 1, correlation_image.shape[-1])) + np.tile(mc.total_template_els, (1, 1, dims[-1])) if is3D else mc.total_template_els ) @@ -322,32 +636,45 @@ def _save_mc(mc, caiman_fp: str, is3D: bool): "coord_shifts_rig", shape=np.shape(grid), data=grid, dtype=type(grid[0][0]) ) reference_image = ( - np.tile(mc.total_template_rig, (1, 1, correlation_image.shape[-1])) + np.tile(mc.total_template_rig, (1, 1, dims[-1])) if is3D else mc.total_template_rig ) + if summary_images is None: + # Load motion corrected mmap image + mc_image = cm.load(mc.mmap_file, is3D=is3D) + + # Compute motion corrected summary images + average_image = np.mean(mc_image, axis=0) + max_image = np.max(mc_image, axis=0) + + # Compute motion corrected correlation image + correlation_image = cm.local_correlations( + mc_image.transpose((1, 2, 3, 0) if is3D else (1, 2, 0)) + ) + correlation_image[np.isnan(correlation_image)] = 0 + + summary_images = { + "average_image": average_image, + "max_image": max_image, + "correlation_image": correlation_image, + } + + for img_type, img in summary_images.items(): + h5g.require_dataset( + img_type, + shape=np.shape(img), + data=img, + dtype=img.dtype, + ) + h5g.require_dataset( "reference_image", shape=np.shape(reference_image), data=reference_image, dtype=reference_image.dtype, ) - h5g.require_dataset( - "correlation_image", - shape=np.shape(correlation_image), - data=correlation_image, - dtype=correlation_image.dtype, - ) - h5g.require_dataset( - "average_image", - shape=np.shape(average_image), - data=average_image, - dtype=average_image.dtype, - ) - h5g.require_dataset( - "max_image", shape=np.shape(max_image), data=max_image, 
dtype=max_image.dtype - ) # Close hdf5 file h5f.close() diff --git a/element_interface/extract_trigger.py b/element_interface/extract_trigger.py index 103e3e0..2600a93 100644 --- a/element_interface/extract_trigger.py +++ b/element_interface/extract_trigger.py @@ -43,7 +43,8 @@ def __init__( def write_matlab_run_script(self): """Compose a matlab script and save it with the name run_extract.m. - The composed script is basically the formatted version of the m_template attribute.""" + The composed script is basically the formatted version of the m_template attribute. + """ self.output_fullpath = ( self.output_dir / f"{self.scanfile.stem}_extract_output.mat" @@ -53,11 +54,15 @@ def write_matlab_run_script(self): **dict( parameters_list_string="\n".join( [ - f"config.{k} = '{v}';" - if isinstance(v, str) - else f"config.{k} = {str(v).lower()};" - if isinstance(v, bool) - else f"config.{k} = {v};" + ( + f"config.{k} = '{v}';" + if isinstance(v, str) + else ( + f"config.{k} = {str(v).lower()};" + if isinstance(v, bool) + else f"config.{k} = {v};" + ) + ) for k, v in self.parameters.items() ] ), diff --git a/element_interface/prairie_view_loader.py b/element_interface/prairie_view_loader.py index ee306cd..33fe1a9 100644 --- a/element_interface/prairie_view_loader.py +++ b/element_interface/prairie_view_loader.py @@ -1,12 +1,16 @@ +import os import pathlib from pathlib import Path import xml.etree.ElementTree as ET from datetime import datetime import numpy as np +import tifffile +import logging +logger = logging.getLogger(__name__) -class PrairieViewMeta: +class PrairieViewMeta: def __init__(self, prairieview_dir: str): """Initialize PrairieViewMeta loader class @@ -35,27 +39,194 @@ def __init__(self, prairieview_dir: str): def meta(self): if self._meta is None: self._meta = _extract_prairieview_metadata(self.xml_file) + # adjust for the different definition of "frames" + # from the ome meta - "frame" refers to an image at a given scanning depth, time step combination + # in the imaging pipeline - "frame" refers to video frames - i.e. 
time steps + num_frames = int(self._meta.pop("num_frames") / self._meta["num_planes"]) + self._meta["num_frames"] = num_frames + self._meta["frame_rate"] = num_frames / self._meta["scan_duration"] + return self._meta - def get_prairieview_files(self, plane_idx=None, channel=None): + def get_prairieview_filenames( + self, plane_idx=None, channel=None, return_pln_chn=False + ): + """ + Extract from metadata the set of tiff files specific to the specified "plane_idx" and "channel" + Args: + plane_idx: int - plane index + channel: int - channel + return_pln_chn: bool - if True, returns (filenames, plane_idx, channel), else returns `filenames` + + Returns: List[str] - the set of tiff files specific to the specified "plane_idx" and "channel" + """ if plane_idx is None: - if self.meta['num_planes'] > 1: - raise ValueError(f"Please specify 'plane_idx' - Plane indices: {self.meta['plane_indices']}") + if self.meta["num_planes"] > 1: + raise ValueError( + f"Please specify 'plane_idx' - Plane indices: {self.meta['plane_indices']}" + ) else: - plane_idx = self.meta['plane_indices'][0] + plane_idx = self.meta["plane_indices"][0] else: - assert plane_idx in self.meta['plane_indices'], f"Invalid 'plane_idx' - Plane indices: {self.meta['plane_indices']}" + assert ( + plane_idx in self.meta["plane_indices"] + ), f"Invalid 'plane_idx' - Plane indices: {self.meta['plane_indices']}" if channel is None: - if self.meta['num_channels'] > 1: - raise ValueError(f"Please specify 'channel' - Channels: {self.meta['channels']}") + if self.meta["num_channels"] > 1: + raise ValueError( + f"Please specify 'channel' - Channels: {self.meta['channels']}" + ) else: - plane_idx = self.meta['channels'][0] + channel = self.meta["channels"][0] else: - assert channel in self.meta['channels'], f"Invalid 'channel' - Channels: {self.meta['channels']}" + assert ( + channel in self.meta["channels"] + ), f"Invalid 'channel' - Channels: {self.meta['channels']}" - frames = self._xml_root.findall(f".//Sequence/Frame/[@index='{plane_idx}']/File/[@channel='{channel}']") - return [f.attrib['filename'] for f in frames] + # single-plane ome.tif does not have "@index" under Frame to search for + plane_search = f"/[@index='{plane_idx}']" if self.meta["num_planes"] > 1 else "" + # ome.tif does have "@channel" under File regardless of single or multi channel + channel_search = f"/[@channel='{channel}']" + + frames = self._xml_root.findall( + f".//Sequence/Frame{plane_search}/File{channel_search}" + ) + + fnames = np.unique([f.attrib["filename"] for f in frames]).tolist() + return fnames if not return_pln_chn else (fnames, plane_idx, channel) + + def write_single_bigtiff( + self, + plane_idx=None, + channel=None, + output_prefix=None, + output_dir="./", + caiman_compatible=False, # if True, save the movie as a single page (frame x height x width) + overwrite=False, + gb_per_file=None, + ): + logger.warning( + "Deprecation warning: `caiman_compatible` argument will no longer have any effect and will be removed in the future. `write_single_bigtiff` will return multi-page tiff, which is compatible with CaImAn." 
+ ) + + tiff_names, plane_idx, channel = self.get_prairieview_filenames( + plane_idx=plane_idx, channel=channel, return_pln_chn=True + ) + if output_prefix is None: + output_prefix = os.path.commonprefix(tiff_names) + output_tiff_stem = f"{output_prefix}_pln{plane_idx}_chn{channel}" + + output_dir = Path(output_dir) + output_tiff_list = list(output_dir.glob(f"{output_tiff_stem}*.tif")) + if len(output_tiff_list) and not overwrite: + return output_tiff_list[0] if gb_per_file is None else output_tiff_list + + # delete old tif files if overwrite is True + [f.unlink() for f in output_tiff_list] + + output_tiff_list = [] + if self.meta["is_multipage"]: + if gb_per_file is not None: + logger.warning( + "Ignoring `gb_per_file` argument for multi-page tiff (NotYetImplemented)" + ) + # For multi-page tiff - the pages are organized as: + # (channel x slice x frame) - each page is (height x width) + # - TODO: verify this is the case for Bruker multi-page tiff + # This implementation is partially based on the reference code from `scanreader` package - https://github.com/atlab/scanreader + # See: https://github.com/atlab/scanreader/blob/2a021a85fca011c17e553d0e1c776998d3f2b2d8/scanreader/scans.py#L337 + slice_step = self.meta["num_channels"] + frame_step = self.meta["num_channels"] * self.meta["num_planes"] + slice_idx = self.meta["plane_indices"].index(plane_idx) + channel_idx = self.meta["channels"].index(channel) + + page_indices = [ + frame_idx * frame_step + slice_idx * slice_step + channel_idx + for frame_idx in range(self.meta["num_frames"]) + ] + + combined_data = np.empty( + [ + self.meta["num_frames"], + self.meta["height_in_pixels"], + self.meta["width_in_pixels"], + ], + dtype=int, + ) + start_page = 0 + try: + for input_file in tiff_names: + with tifffile.TiffFile(self.prairieview_dir / input_file) as tffl: + # Get indices in this tiff file and in output array + final_page_in_file = start_page + len(tffl.pages) + is_page_in_file = lambda page: page in range( + start_page, final_page_in_file + ) + pages_in_file = filter(is_page_in_file, page_indices) + file_indices = [page - start_page for page in pages_in_file] + global_indices = [ + is_page_in_file(page) for page in page_indices + ] + + # Read from this tiff file (if needed) + if len(file_indices) > 0: + # this line looks a bit ugly but is memory efficient. 
Do not separate + combined_data[global_indices] = tffl.asarray( + key=file_indices + ) + start_page += len(tffl.pages) + except Exception as e: + raise Exception(f"Error in processing tiff file {input_file}: {e}") + + output_tiff_fullpath = output_dir / f"{output_tiff_stem}.tif" + tifffile.imwrite( + output_tiff_fullpath, + combined_data, + metadata={"axes": "TYX", "'fps'": self.meta["frame_rate"]}, + bigtiff=True, + ) + output_tiff_list.append(output_tiff_fullpath) + else: + while len(tiff_names): + output_tiff_fullpath = ( + output_dir / f"{output_tiff_stem}_{len(output_tiff_list):04}.tif" + ) + with tifffile.TiffWriter( + output_tiff_fullpath, + bigtiff=True, + ) as tiff_writer: + while len(tiff_names): + input_file = tiff_names.pop(0) + try: + with tifffile.TiffFile( + self.prairieview_dir / input_file + ) as tffl: + assert len(tffl.pages) == 1 + tiff_writer.write( + tffl.pages[0].asarray(), + metadata={ + "axes": "YX", + "'fps'": self.meta["frame_rate"], + }, + ) + # additional safeguard to close the file and delete the object + # in the attempt to prevent error: `not a TIFF file b''` + tffl.close() + del tffl + except Exception as e: + raise Exception( + f"Error in processing tiff file {input_file}: {e}" + ) + if ( + gb_per_file + and output_tiff_fullpath.stat().st_size + >= gb_per_file * 1024**3 + ): + break + output_tiff_list.append(output_tiff_fullpath) + + return output_tiff_list[0] if gb_per_file is None else output_tiff_list def _extract_prairieview_metadata(xml_filepath: str): @@ -67,7 +238,7 @@ def _extract_prairieview_metadata(xml_filepath: str): bidirectional_scan = False # Does not support bidirectional roi = 0 - n_fields = 1 # Always contains 1 field + is_multipage = xml_root.find(".//Sequence/Frame/File/[@page]") is not None recording_start_time = xml_root.find(".//Sequence/[@cycle='1']").attrib.get("time") # Get all channels and find unique values @@ -78,9 +249,9 @@ def _extract_prairieview_metadata(xml_filepath: str): channels = set(channel_list) n_channels = len(channels) n_frames = len(xml_root.findall(".//Sequence/Frame")) - framerate = 1 / float( + frame_period = float( xml_root.findall('.//PVStateValue/[@key="framePeriod"]')[0].attrib.get("value") - ) # rate = 1/framePeriod + ) usec_per_line = ( float( @@ -129,7 +300,7 @@ def _extract_prairieview_metadata(xml_filepath: str): if ( xml_root.find( - ".//Sequence/[@cycle='1']/Frame/PVStateShard/PVStateValue/[@key='positionCurrent']/SubindexedValues/[@index='ZAxis']" + ".//Sequence/[@cycle='2']/Frame/PVStateShard/PVStateValue/[@key='positionCurrent']/SubindexedValues/[@index='ZAxis']" ) is None ): @@ -156,22 +327,22 @@ def _extract_prairieview_metadata(xml_filepath: str): n_depths = len(plane_indices) z_controllers = xml_root.findall( - ".//Sequence/[@cycle='1']/Frame/[@index='1']/PVStateShard/PVStateValue/[@key='positionCurrent']/SubindexedValues/[@index='ZAxis']/SubindexedValue" + ".//Sequence/[@cycle='2']/Frame/[@index='1']/PVStateShard/PVStateValue/[@key='positionCurrent']/SubindexedValues/[@index='ZAxis']/SubindexedValue" ) - # If more than one Z-axis controllers are found, + # If more than one Z-axis controllers are found, # check which controller is changing z_field depth. Only 1 controller # must change depths. 
if len(z_controllers) > 1: z_repeats = [] for controller in xml_root.findall( - ".//Sequence/[@cycle='1']/Frame/[@index='1']/PVStateShard/PVStateValue/[@key='positionCurrent']/SubindexedValues/[@index='ZAxis']/" + ".//Sequence/[@cycle='2']/Frame/[@index='1']/PVStateShard/PVStateValue/[@key='positionCurrent']/SubindexedValues/[@index='ZAxis']/" ): z_repeats.append( [ float(z.attrib.get("value")) for z in xml_root.findall( - ".//Sequence/[@cycle='1']/Frame/PVStateShard/PVStateValue/[@key='positionCurrent']/SubindexedValues/[@index='ZAxis']/SubindexedValue/[@subindex='{0}']".format( + ".//Sequence/[@cycle='2']/Frame/PVStateShard/PVStateValue/[@key='positionCurrent']/SubindexedValues/[@index='ZAxis']/SubindexedValue/[@subindex='{0}']".format( controller.attrib.get("subindex") ) ) @@ -191,7 +362,7 @@ def _extract_prairieview_metadata(xml_filepath: str): z_fields = [ z.attrib.get("value") for z in xml_root.findall( - ".//Sequence/[@cycle='1']/Frame/PVStateShard/PVStateValue/[@key='positionCurrent']/SubindexedValues/[@index='ZAxis']/SubindexedValue/[@subindex='0']" + ".//Sequence/[@cycle='2']/Frame/PVStateShard/PVStateValue/[@key='positionCurrent']/SubindexedValues/[@index='ZAxis']/SubindexedValue/[@subindex='0']" ) ] @@ -200,7 +371,7 @@ def _extract_prairieview_metadata(xml_filepath: str): ), "Number of z fields does not match number of depths." metainfo = dict( - num_fields=n_fields, + num_fields=n_depths, num_channels=n_channels, num_planes=n_depths, num_frames=n_frames, @@ -208,9 +379,10 @@ def _extract_prairieview_metadata(xml_filepath: str): x_pos=None, y_pos=None, z_pos=None, - frame_rate=framerate, + frame_period=frame_period, bidirectional=bidirectional_scan, bidirectional_z=bidirection_z, + is_multipage=is_multipage, scan_datetime=scan_datetime, usecs_per_line=usec_per_line, scan_duration=total_scan_duration, diff --git a/element_interface/run_caiman.py b/element_interface/run_caiman.py index eb480a9..4eefadc 100644 --- a/element_interface/run_caiman.py +++ b/element_interface/run_caiman.py @@ -1,12 +1,23 @@ -import pathlib - import cv2 +import os +import pathlib +import shutil +import numpy as np +import multiprocessing try: cv2.setNumThreads(0) except: # noqa E722 pass # TODO: remove bare except +try: + import torch + + cuda_is_available = torch.cuda.is_available() +except: + cuda_is_available = False + pass + import caiman as cm from caiman.source_extraction.cnmf import params as params from caiman.source_extraction.cnmf.cnmf import CNMF @@ -35,21 +46,48 @@ def run_caiman( parameters["fnames"] = file_paths parameters["fr"] = sampling_rate - opts = params.CNMFParams(params_dict=parameters) + use_cuda = parameters.get("use_cuda") + parameters["use_cuda"] = cuda_is_available if use_cuda is None else use_cuda - c, dview, n_processes = cm.cluster.setup_cluster( - backend="local", n_processes=None, single_thread=False - ) + if "indices" in parameters: + indices = parameters.pop( + "indices" + ) # Indices that restrict FOV for motion correction. 
+ indices = slice(*indices[0]), slice(*indices[1]) + parameters["motion"] = {**parameters.get("motion", {}), "indices": indices} - cnm = CNMF(n_processes, params=opts, dview=dview) - cnmf_output, mc_output = cnm.fit_file( - motion_correct=True, include_eval=True, output_dir=output_dir, return_mc=True + caiman_temp = os.environ.get("CAIMAN_TEMP") + os.environ["CAIMAN_TEMP"] = str(output_dir) + + # use 80% of available cores + n_processes = int(np.floor(multiprocessing.cpu_count() * 0.8)) + _, dview, n_processes = cm.cluster.setup_cluster( + backend="multiprocessing", n_processes=n_processes ) - cm.stop_server(dview=dview) + try: + opts = params.CNMFParams(params_dict=parameters) + cnm = CNMF(n_processes, params=opts, dview=dview) + cnmf_output, mc_output = cnm.fit_file( + motion_correct=True, + indices=None, # Indices defined here restrict FOV for segmentation. `None` uses the full image for segmentation. + include_eval=True, + output_dir=output_dir, + return_mc=True, + ) + except Exception as e: + dview.terminate() + raise e + else: + cm.stop_server(dview=dview) + + if caiman_temp is not None: + os.environ["CAIMAN_TEMP"] = caiman_temp + else: + del os.environ["CAIMAN_TEMP"] cnmf_output_file = pathlib.Path(cnmf_output.mmap_file[:-4] + "hdf5") + cnmf_output_file = pathlib.Path(output_dir) / cnmf_output_file.name assert cnmf_output_file.exists() - assert cnmf_output_file.parent == pathlib.Path(output_dir) _save_mc(mc_output, cnmf_output_file.as_posix(), parameters["is3D"]) diff --git a/element_interface/suite2p_loader.py b/element_interface/suite2p_loader.py index e16fd3f..646c01b 100644 --- a/element_interface/suite2p_loader.py +++ b/element_interface/suite2p_loader.py @@ -153,7 +153,9 @@ def __init__(self, suite2p_plane_dir: str): @property def curation_time(self): - print("DeprecationWarning: 'curation_time' is deprecated, set to be the same as 'creation time', no longer reliable.") + print( + "DeprecationWarning: 'curation_time' is deprecated, set to be the same as 'creation time', no longer reliable." + ) return self.creation_time @property diff --git a/element_interface/utils.py b/element_interface/utils.py index 14d4eee..2fc8ca9 100644 --- a/element_interface/utils.py +++ b/element_interface/utils.py @@ -5,7 +5,9 @@ import pathlib import sys import uuid - +import json +import pickle +from datetime import datetime from datajoint.utils import to_camel_case logger = logging.getLogger("datajoint") @@ -187,3 +189,70 @@ def __exit__(self, *args): logger.setLevel(self.prev_log_level) sys.stdout.close() sys.stdout = self._original_stdout + + +def memoized_result(uniqueness_dict: dict, output_directory: str): + """ + This is a decorator factory designed to cache the results of a function based on its input parameters and the state of the output directory. + If the function is called with the same parameters and the output files in the directory remain unchanged, + it returns the cached results; otherwise, it executes the function and caches the new results along with metadata. 
+ + Args: + uniqueness_dict: a dictionary that would identify a unique function call + output_directory: directory location for the output files + + Returns: a decorator to enable a function call to memoize/cached the resulting files + + Conditions for robust usage: + - the "output_directory" is to store exclusively the resulting files generated by this function call only, not a shared space with other functions/processes + - the "parameters" passed to the decorator captures the true and uniqueness of the arguments to be used in the decorated function call + """ + + def decorator(func): + def wrapped(*args, **kwargs): + output_dir = _to_Path(output_directory) + input_hash = dict_to_uuid(uniqueness_dict) + input_hash_fp = output_dir / f".{input_hash}.json" + # check if results already exist (from previous identical run) + output_dir_files_hash = dict_to_uuid( + { + f.relative_to(output_dir).as_posix(): f.stat().st_size + for f in output_dir.rglob("*") + if f.name != f".{input_hash}.json" + } + ) + if input_hash_fp.exists(): + with open(input_hash_fp, "r") as f: + meta = json.load(f) + if str(output_dir_files_hash) == meta["output_dir_files_hash"]: + logger.info(f"Existing results found, skip '{func.__name__}'") + with open(output_dir / f".{input_hash}_results.pickle", "rb") as f: + results = pickle.load(f) + return results + # no results - trigger the run + logger.info(f"No existing results found, calling '{func.__name__}'") + start_time = datetime.utcnow() + results = func(*args, **kwargs) + + with open(output_dir / f".{input_hash}_results.pickle", "wb") as f: + pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL) + + meta = { + "output_dir_files_hash": dict_to_uuid( + { + f.relative_to(output_dir).as_posix(): f.stat().st_size + for f in output_dir.rglob("*") + if f.name != f".{input_hash}.json" + } + ), + "start_time": start_time, + "completion_time": datetime.utcnow(), + } + with open(input_hash_fp, "w") as f: + json.dump(meta, f, default=str) + + return results + + return wrapped + + return decorator diff --git a/element_interface/version.py b/element_interface/version.py index 0da8726..70aab85 100644 --- a/element_interface/version.py +++ b/element_interface/version.py @@ -1,3 +1,3 @@ """Package metadata""" -__version__ = "0.6.1" +__version__ = "0.7.0"
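A usage sketch for the extended PrairieView loader; the session directory, plane index, and channel value are hypothetical, and `write_single_bigtiff` returns a single path when `gb_per_file` is left unset:

from element_interface.prairie_view_loader import PrairieViewMeta

# hypothetical session folder holding the .xml metadata file and the .ome.tif files
pv_meta = PrairieViewMeta("/data/scan01/prairieview")
print(pv_meta.meta["num_planes"], pv_meta.meta["num_frames"], pv_meta.meta["frame_rate"])

# consolidate the .ome.tif files for one plane/channel combination into a single bigtiff
bigtiff_path = pv_meta.write_single_bigtiff(
    plane_idx=0,
    channel=0,
    output_dir="/data/scan01/processed",
    overwrite=False,
)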
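A hedged sketch of a CaImAn parameter dictionary exercising the new `use_cuda` and `indices` handling in `run_caiman`; all values are hypothetical:

# hypothetical parameter set passed to run_caiman
parameters = {
    "is3D": False,
    "use_cuda": None,                    # when None/absent, falls back to torch.cuda.is_available()
    "indices": [[10, 500], [10, 500]],   # FOV rows/columns kept for motion correction; converted
                                         # internally to (slice(10, 500), slice(10, 500)) and moved
                                         # into parameters["motion"]["indices"]
}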
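A usage sketch of the new `memoized_result` decorator; the parameter set, output directory, and decorated function are hypothetical:

import pathlib

from element_interface.utils import memoized_result

# hypothetical parameter set and an output directory reserved for this run only
params = {"method": "suite2p", "fs": 30.0}
output_dir = "/data/scan01/processed"
pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)

@memoized_result(uniqueness_dict=params, output_directory=output_dir)
def run_processing():
    # long-running step that writes its result files into output_dir
    (pathlib.Path(output_dir) / "results.npy").touch()
    return {"status": "complete"}

results = run_processing()  # first call: executes and caches the return value plus metadata
results = run_processing()  # same parameters, unchanged output files: cached result is returned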