Skip to content

Commit

Permalink
Add a mapping function in image_reader.py and image_writer.py (#7769)
Browse files Browse the repository at this point in the history
Add a function to create a JSON file that maps input and output paths.

Fixes #7557  .

### Description

A few sentences describing the changes proposed in this pull request.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: staydelight <[email protected]>
Co-authored-by: staydelight <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <[email protected]>
  • Loading branch information
4 people authored Aug 28, 2024
1 parent b62d1e1 commit b6d6d77
Show file tree
Hide file tree
Showing 7 changed files with 351 additions and 6 deletions.
12 changes: 12 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,12 @@ IO
:members:
:special-members: __call__

`WriteFileMapping`
""""""""""""""""""
.. autoclass:: WriteFileMapping
:members:
:special-members: __call__


NVIDIA Tool Extension (NVTX)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down Expand Up @@ -1642,6 +1648,12 @@ IO (Dict)
:members:
:special-members: __call__

`WriteFileMappingd`
"""""""""""""""""""
.. autoclass:: WriteFileMappingd
:members:
:special-members: __call__

Post-processing (Dict)
^^^^^^^^^^^^^^^^^^^^^^

Expand Down
14 changes: 12 additions & 2 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,18 @@
)
from .inverse import InvertibleTransform, TraceableTransform
from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict
from .io.array import SUPPORTED_READERS, LoadImage, SaveImage
from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict
from .io.array import SUPPORTED_READERS, LoadImage, SaveImage, WriteFileMapping
from .io.dictionary import (
LoadImaged,
LoadImageD,
LoadImageDict,
SaveImaged,
SaveImageD,
SaveImageDict,
WriteFileMappingd,
WriteFileMappingD,
WriteFileMappingDict,
)
from .lazy.array import ApplyPending
from .lazy.dictionary import ApplyPendingd, ApplyPendingD, ApplyPendingDict
from .lazy.functional import apply_pending
Expand Down
60 changes: 58 additions & 2 deletions monai/transforms/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

import inspect
import json
import logging
import sys
import traceback
Expand Down Expand Up @@ -45,11 +46,19 @@
from monai.transforms.utility.array import EnsureChannelFirst
from monai.utils import GridSamplePadMode
from monai.utils import ImageMetaKey as Key
from monai.utils import OptionalImportError, convert_to_dst_type, ensure_tuple, look_up_option, optional_import
from monai.utils import (
MetaKeys,
OptionalImportError,
convert_to_dst_type,
ensure_tuple,
look_up_option,
optional_import,
)

nib, _ = optional_import("nibabel")
Image, _ = optional_import("PIL.Image")
nrrd, _ = optional_import("nrrd")
FileLock, has_filelock = optional_import("filelock", name="FileLock")

__all__ = ["LoadImage", "SaveImage", "SUPPORTED_READERS"]

Expand Down Expand Up @@ -505,7 +514,7 @@ def __call__(
else:
self._data_index += 1
if self.savepath_in_metadict and meta_data is not None:
meta_data["saved_to"] = filename
meta_data[MetaKeys.SAVED_TO] = filename
return img
msg = "\n".join([f"{e}" for e in err])
raise RuntimeError(
Expand All @@ -514,3 +523,50 @@ def __call__(
" https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n"
f" The current registered writers for {self.output_ext}: {self.writers}.\n{msg}"
)


class WriteFileMapping(Transform):
"""
Writes a JSON file that logs the mapping between input image paths and their corresponding output paths.
This class uses FileLock to ensure safe writing to the JSON file in a multiprocess environment.
Args:
mapping_file_path (Path or str): Path to the JSON file where the mappings will be saved.
"""

def __init__(self, mapping_file_path: Path | str = "mapping.json"):
self.mapping_file_path = Path(mapping_file_path)

def __call__(self, img: NdarrayOrTensor):
"""
Args:
img: The input image with metadata.
"""
if isinstance(img, MetaTensor):
meta_data = img.meta

if MetaKeys.SAVED_TO not in meta_data:
raise KeyError(
"Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True."
)

input_path = meta_data[Key.FILENAME_OR_OBJ]
output_path = meta_data[MetaKeys.SAVED_TO]
log_data = {"input": input_path, "output": output_path}

if has_filelock:
with FileLock(str(self.mapping_file_path) + ".lock"):
self._write_to_file(log_data)
else:
self._write_to_file(log_data)
return img

def _write_to_file(self, log_data):
try:
with self.mapping_file_path.open("r") as f:
existing_log_data = json.load(f)
except (FileNotFoundError, json.JSONDecodeError):
existing_log_data = []
existing_log_data.append(log_data)
with self.mapping_file_path.open("w") as f:
json.dump(existing_log_data, f, indent=4)
31 changes: 29 additions & 2 deletions monai/transforms/io/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,17 @@

from __future__ import annotations

from collections.abc import Hashable, Mapping
from pathlib import Path
from typing import Callable

import numpy as np

import monai
from monai.config import DtypeLike, KeysCollection
from monai.config import DtypeLike, KeysCollection, NdarrayOrTensor
from monai.data import image_writer
from monai.data.image_reader import ImageReader
from monai.transforms.io.array import LoadImage, SaveImage
from monai.transforms.io.array import LoadImage, SaveImage, WriteFileMapping
from monai.transforms.transform import MapTransform, Transform
from monai.utils import GridSamplePadMode, ensure_tuple, ensure_tuple_rep
from monai.utils.enums import PostFix
Expand Down Expand Up @@ -320,5 +321,31 @@ def __call__(self, data):
return d


class WriteFileMappingd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.WriteFileMapping`.
Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
mapping_file_path: Path to the JSON file where the mappings will be saved.
Defaults to "mapping.json".
allow_missing_keys: don't raise exception if key is missing.
"""

def __init__(
self, keys: KeysCollection, mapping_file_path: Path | str = "mapping.json", allow_missing_keys: bool = False
) -> None:
super().__init__(keys, allow_missing_keys)
self.mapping = WriteFileMapping(mapping_file_path)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.mapping(d[key])
return d


LoadImageD = LoadImageDict = LoadImaged
SaveImageD = SaveImageDict = SaveImaged
WriteFileMappingD = WriteFileMappingDict = WriteFileMappingd
1 change: 1 addition & 0 deletions monai/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,7 @@ class MetaKeys(StrEnum):
SPATIAL_SHAPE = "spatial_shape" # optional key for the length in each spatial dimension
SPACE = "space" # possible values of space type are defined in `SpaceKeys`
ORIGINAL_CHANNEL_DIM = "original_channel_dim" # an integer or float("nan")
SAVED_TO = "saved_to"


class ColorOrder(StrEnum):
Expand Down
117 changes: 117 additions & 0 deletions tests/test_mapping_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import json
import os
import shutil
import tempfile
import unittest

import numpy as np
from parameterized import parameterized

from monai.data import DataLoader, Dataset
from monai.transforms import Compose, LoadImage, SaveImage, WriteFileMapping
from monai.utils import optional_import

nib, has_nib = optional_import("nibabel")


def create_input_file(temp_dir, name):
test_image = np.random.rand(128, 128, 128)
output_ext = ".nii.gz"
input_file = os.path.join(temp_dir, name + output_ext)
nib.save(nib.Nifti1Image(test_image, np.eye(4)), input_file)
return input_file


def create_transform(temp_dir, mapping_file_path, savepath_in_metadict=True):
return Compose(
[
LoadImage(image_only=True),
SaveImage(output_dir=temp_dir, output_ext=".nii.gz", savepath_in_metadict=savepath_in_metadict),
WriteFileMapping(mapping_file_path=mapping_file_path),
]
)


@unittest.skipUnless(has_nib, "nibabel required")
class TestWriteFileMapping(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.mkdtemp()

def tearDown(self):
shutil.rmtree(self.temp_dir)

@parameterized.expand([(True,), (False,)])
def test_mapping_file(self, savepath_in_metadict):
mapping_file_path = os.path.join(self.temp_dir, "mapping.json")
name = "test_image"
input_file = create_input_file(self.temp_dir, name)
output_file = os.path.join(self.temp_dir, name, name + "_trans.nii.gz")

transform = create_transform(self.temp_dir, mapping_file_path, savepath_in_metadict)

if savepath_in_metadict:
transform(input_file)
self.assertTrue(os.path.exists(mapping_file_path))
with open(mapping_file_path) as f:
mapping_data = json.load(f)
self.assertEqual(len(mapping_data), 1)
self.assertEqual(mapping_data[0]["input"], input_file)
self.assertEqual(mapping_data[0]["output"], output_file)
else:
with self.assertRaises(RuntimeError) as cm:
transform(input_file)
cause_exception = cm.exception.__cause__
self.assertIsInstance(cause_exception, KeyError)
self.assertIn(
"Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True.",
str(cause_exception),
)

def test_multiprocess_mapping_file(self):
num_images = 50

single_mapping_file = os.path.join(self.temp_dir, "single_mapping.json")
multi_mapping_file = os.path.join(self.temp_dir, "multi_mapping.json")

data = [create_input_file(self.temp_dir, f"test_image_{i}") for i in range(num_images)]

# single process
single_transform = create_transform(self.temp_dir, single_mapping_file)
single_dataset = Dataset(data=data, transform=single_transform)
single_loader = DataLoader(single_dataset, batch_size=1, num_workers=0, shuffle=True)
for _ in single_loader:
pass

# multiple processes
multi_transform = create_transform(self.temp_dir, multi_mapping_file)
multi_dataset = Dataset(data=data, transform=multi_transform)
multi_loader = DataLoader(multi_dataset, batch_size=4, num_workers=3, shuffle=True)
for _ in multi_loader:
pass

with open(single_mapping_file) as f:
single_mapping_data = json.load(f)
with open(multi_mapping_file) as f:
multi_mapping_data = json.load(f)

single_set = {(entry["input"], entry["output"]) for entry in single_mapping_data}
multi_set = {(entry["input"], entry["output"]) for entry in multi_mapping_data}

self.assertEqual(single_set, multi_set)


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit b6d6d77

Please sign in to comment.