class RandSimulateLowResolution(RandomizableTransform):
    """
    Random simulation of low resolution corresponding to nnU-Net's SimulateLowResolutionTransform
    (https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/resample_transforms.py#L23)
    First, the array/tensor is resampled at lower resolution as determined by the zoom_factor which is uniformly sampled
    from the `zoom_range`. Then, the array/tensor is resampled at the original resolution.
    """

    backend = Affine.backend

    def __init__(
        self,
        prob: float = 0.1,
        downsample_mode: InterpolateMode | str = InterpolateMode.NEAREST,
        upsample_mode: InterpolateMode | str = InterpolateMode.TRILINEAR,
        zoom_range: Sequence[float] = (0.5, 1.0),
        align_corners: Optional[bool] = False,
        device: Optional[torch.device] = None,
    ) -> None:
        """
        Args:
            prob: probability of performing this augmentation
            downsample_mode: interpolation mode for downsampling operation
            upsample_mode: interpolation mode for upsampling operation
            zoom_range: range from which the random zoom factor for the downsampling and upsampling operation is
                sampled. It determines the shape of the downsampled tensor.
            align_corners: forwarded to the upsampling ``Resize`` only (the downsampling ``Resize`` is built
                without it). This only has an effect when upsample_mode is 'linear', 'bilinear',
                'bicubic' or 'trilinear'. Default: False
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
            device: stored as ``self.device``; it is not referenced by the resampling code of this class.

        """
        RandomizableTransform.__init__(self, prob)

        self.downsample_mode = downsample_mode
        self.upsample_mode = upsample_mode
        self.zoom_range = zoom_range
        self.align_corners = align_corners
        self.device = device
        # neutral factor until `randomize` draws a value from `zoom_range`
        self.zoom_factor = 1.0

    def randomize(self, data: Optional[Any] = None) -> None:
        """
        Decide whether to apply the transform and draw a new zoom factor.

        The zoom factor is drawn on every call, whether or not the transform will be applied.

        Args:
            data: unused; accepted for signature compatibility with the base class.
        """
        super().randomize(None)
        self.zoom_factor = self.R.uniform(self.zoom_range[0], self.zoom_range[1])

    def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor:
        """
        Args:
            img: shape must be (num_channels, H, W[, D]),
            randomize: whether to execute `randomize()` function first, defaults to True.

        Returns:
            ``img`` unchanged when the transform is not applied; otherwise a ``MetaTensor`` holding the
            down-and-upsampled data with the metadata copied from ``img``.
        """
        if randomize:
            self.randomize()

        if not self._do_transform:
            return img

        input_shape = np.array(img.shape[1:])
        target_shape = np.round(input_shape * self.zoom_factor).astype(np.int_)

        resize_tfm_downsample = Resize(
            spatial_size=target_shape, size_mode="all", mode=self.downsample_mode, anti_aliasing=False
        )
        resize_tfm_upsample = Resize(
            spatial_size=input_shape,
            size_mode="all",
            mode=self.upsample_mode,
            anti_aliasing=False,
            align_corners=self.align_corners,
        )

        # temporarily disable metadata tracking, since we do not want to invert the two Resize functions during
        # post-processing
        original_track_meta_value = get_track_meta()
        set_track_meta(False)
        try:
            img_downsampled = resize_tfm_downsample(img)
            img_upsampled = resize_tfm_upsample(img_downsampled)
        finally:
            # the tracking flag is global state: restore it even if one of the Resize calls raises
            set_track_meta(original_track_meta_value)

        # copy metadata from original image to down-and-upsampled image
        img_upsampled = MetaTensor(img_upsampled)
        img_upsampled.copy_meta_from(img)

        return img_upsampled
class RandSimulateLowResolutiond(RandomizableTransform, MapTransform):
    """
    Dictionary-based wrapper of :py:class:`monai.transforms.RandSimulateLowResolution`.
    Random simulation of low resolution corresponding to nnU-Net's SimulateLowResolutionTransform
    (https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/resample_transforms.py#L23)
    First, the array/tensor is resampled at lower resolution as determined by the zoom_factor which is uniformly sampled
    from the `zoom_range`. Then, the array/tensor is resampled at the original resolution.
    """

    backend = RandAffine.backend

    def __init__(
        self,
        keys: KeysCollection,
        prob: float = 0.1,
        downsample_mode: InterpolateMode | str = InterpolateMode.NEAREST,
        upsample_mode: InterpolateMode | str = InterpolateMode.TRILINEAR,
        zoom_range=(0.5, 1.0),
        align_corners: bool = False,
        allow_missing_keys: bool = False,
        device: torch.device | None = None,
    ) -> None:
        """
        Args:
            keys: keys of the corresponding items to be transformed.
            prob: probability of performing this augmentation
            downsample_mode: interpolation mode for downsampling operation
            upsample_mode: interpolation mode for upsampling operation
            zoom_range: range from which the random zoom factor for the downsampling and upsampling operation is
                sampled. It determines the shape of the downsampled tensor.
            align_corners: This only has an effect when downsample_mode or upsample_mode is 'linear', 'bilinear',
                'bicubic' or 'trilinear'. Default: False
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
            allow_missing_keys: don't raise exception if key is missing.
            device: passed through to the wrapped ``RandSimulateLowResolution`` and stored as ``self.device``.

        See also:
            - :py:class:`monai.transforms.compose.MapTransform`

        """
        MapTransform.__init__(self, keys, allow_missing_keys)
        RandomizableTransform.__init__(self, prob)

        self.downsample_mode = downsample_mode
        self.upsample_mode = upsample_mode
        self.zoom_range = zoom_range
        self.align_corners = align_corners
        self.device = device

        self.sim_lowres_tfm = RandSimulateLowResolution(
            prob=1.0,  # probability is handled by dictionary class
            downsample_mode=self.downsample_mode,
            upsample_mode=self.upsample_mode,
            zoom_range=self.zoom_range,
            align_corners=self.align_corners,
            device=self.device,
        )

    def set_random_state(
        self, seed: int | None = None, state: np.random.RandomState | None = None
    ) -> "RandSimulateLowResolutiond":
        """
        Seed this transform and the wrapped array transform.

        The zoom factor is drawn by ``self.sim_lowres_tfm``, so its random state must be set as well;
        otherwise seeding the dictionary transform would not make the zoom factor reproducible.

        Args:
            seed: forwarded to both random states.
            state: forwarded to both random states.

        Returns:
            ``self``, to allow call chaining.
        """
        super().set_random_state(seed, state)
        self.sim_lowres_tfm.set_random_state(seed, state)
        return self

    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
        """
        Args:
            data: a dictionary containing the tensor-like data to be transformed. The ``keys`` specified
                in this dictionary must be tensor like arrays that are channel first and have at most
                three spatial dimensions

        Returns:
            a shallow copy of ``data`` in which every selected key is either resampled with one shared zoom
            factor (when the transform fires) or converted to a float32 tensor (when it does not).
        """
        d = dict(data)
        first_key: Hashable = self.first_key(d)
        if first_key == ():
            # none of the requested keys is present: only convert the container
            out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta())
            return out

        self.randomize(None)

        # all the keys share the same random zoom factor: draw it once here and
        # apply the wrapped transform with randomize=False below
        self.sim_lowres_tfm.randomize()

        for key in self.key_iterator(d):
            # do the transform
            if self._do_transform:
                d[key] = self.sim_lowres_tfm(d[key], randomize=False)  # type: ignore
            else:
                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32)
        return d


# public aliases following the d / D / Dict naming of the other dictionary transforms
RandSimulateLowResolutionD = RandSimulateLowResolutionDict = RandSimulateLowResolutiond
+ +from __future__ import annotations + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import RandSimulateLowResolution +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + dict(prob=1.0, zoom_range=(0.8, 0.81)), + p( + np.array( + [ + [ + [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]], + [[16, 17, 18, 19], [20, 21, 22, 23], [24, 25, 26, 27], [28, 29, 30, 31]], + [[32, 33, 34, 35], [36, 37, 38, 39], [40, 41, 42, 43], [44, 45, 46, 47]], + [[48, 49, 50, 51], [52, 53, 54, 55], [56, 57, 58, 59], [60, 61, 62, 63]], + ] + ] + ) + ), + np.array( + [ + [ + [ + [0.0000, 0.6250, 1.3750, 2.0000], + [2.5000, 3.1250, 3.8750, 4.5000], + [5.5000, 6.1250, 6.8750, 7.5000], + [8.0000, 8.6250, 9.3750, 10.0000], + ], + [ + [10.0000, 10.6250, 11.3750, 12.0000], + [12.5000, 13.1250, 13.8750, 14.5000], + [15.5000, 16.1250, 16.8750, 17.5000], + [18.0000, 18.6250, 19.3750, 20.0000], + ], + [ + [22.0000, 22.6250, 23.3750, 24.0000], + [24.5000, 25.1250, 25.8750, 26.5000], + [27.5000, 28.1250, 28.8750, 29.5000], + [30.0000, 30.6250, 31.3750, 32.0000], + ], + [ + [32.0000, 32.6250, 33.3750, 34.0000], + [34.5000, 35.1250, 35.8750, 36.5000], + [37.5000, 38.1250, 38.8750, 39.5000], + [40.0000, 40.6250, 41.3750, 42.0000], + ], + ] + ] + ), + ] + ) + + +class TestRandGaussianSmooth(unittest.TestCase): + @parameterized.expand(TESTS) + def test_value(self, arguments, image, expected_data): + randsimlowres = RandSimulateLowResolution(**arguments) + randsimlowres.set_random_state(seed=0) + result = randsimlowres(image) + assert_allclose(result, expected_data, rtol=1e-4, type_test="tensor") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rand_simulate_low_resolutiond.py b/tests/test_rand_simulate_low_resolutiond.py new file mode 100644 index 0000000000..f058ec3b2b --- /dev/null +++ b/tests/test_rand_simulate_low_resolutiond.py @@ -0,0 +1,73 
@@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import RandSimulateLowResolutiond +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + dict(keys=["img", "seg"], prob=1.0, zoom_range=(0.8, 0.81)), + {"img": p(np.arange(64).reshape(1, 4, 4, 4)), "seg": p(np.arange(64).reshape(1, 4, 4, 4))}, + np.array( + [ + [ + [ + [0.0000, 0.6250, 1.3750, 2.0000], + [2.5000, 3.1250, 3.8750, 4.5000], + [5.5000, 6.1250, 6.8750, 7.5000], + [8.0000, 8.6250, 9.3750, 10.0000], + ], + [ + [10.0000, 10.6250, 11.3750, 12.0000], + [12.5000, 13.1250, 13.8750, 14.5000], + [15.5000, 16.1250, 16.8750, 17.5000], + [18.0000, 18.6250, 19.3750, 20.0000], + ], + [ + [22.0000, 22.6250, 23.3750, 24.0000], + [24.5000, 25.1250, 25.8750, 26.5000], + [27.5000, 28.1250, 28.8750, 29.5000], + [30.0000, 30.6250, 31.3750, 32.0000], + ], + [ + [32.0000, 32.6250, 33.3750, 34.0000], + [34.5000, 35.1250, 35.8750, 36.5000], + [37.5000, 38.1250, 38.8750, 39.5000], + [40.0000, 40.6250, 41.3750, 42.0000], + ], + ] + ] + ), + ] + ) + + +class TestRandGaussianSmoothd(unittest.TestCase): + @parameterized.expand(TESTS) + def test_value(self, arguments, image, expected_data): + converter = RandSimulateLowResolutiond(**arguments) + converter.set_random_state(seed=0) + result = converter(image) + 
assert_allclose(result["img"], expected_data, rtol=1e-4, type_test=False) + assert_allclose(result["seg"], expected_data, rtol=1e-4, type_test=False) + + +if __name__ == "__main__": + unittest.main()