From 19a5f81318303dd8670704cab82cc0c4b5e48d53 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Mon, 15 May 2023 13:39:03 +0300 Subject: [PATCH 01/60] Improved performance: np.clip and one percentile call --- batchgenerators/augmentations/color_augmentations.py | 11 +++++------ batchgenerators/augmentations/normalizations.py | 12 ++++-------- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/batchgenerators/augmentations/color_augmentations.py b/batchgenerators/augmentations/color_augmentations.py index 05f483b..db30e17 100644 --- a/batchgenerators/augmentations/color_augmentations.py +++ b/batchgenerators/augmentations/color_augmentations.py @@ -41,11 +41,10 @@ def augment_contrast(data_sample: np.ndarray, minm = data_sample[c].min() maxm = data_sample[c].max() - data_sample[c] = (data_sample[c] - mn) * factor + mn + data_sample[c] = data_sample[c] * factor + mn * (1 - factor) if preserve_range: - data_sample[c][data_sample[c] < minm] = minm - data_sample[c][data_sample[c] > maxm] = maxm + np.clip(data_sample[c], minm, maxm, out=data_sample[c]) else: for c in range(data_sample.shape[0]): if np.random.uniform() < p_per_channel: @@ -62,11 +61,11 @@ def augment_contrast(data_sample: np.ndarray, minm = data_sample[c].min() maxm = data_sample[c].max() - data_sample[c] = (data_sample[c] - mn) * factor + mn + data_sample[c] = data_sample[c] * factor + mn * (1 - factor) if preserve_range: - data_sample[c][data_sample[c] < minm] = minm - data_sample[c][data_sample[c] > maxm] = maxm + np.clip(data_sample[c], minm, maxm, out=data_sample[c]) + return data_sample diff --git a/batchgenerators/augmentations/normalizations.py b/batchgenerators/augmentations/normalizations.py index 20a6d65..acc37d1 100644 --- a/batchgenerators/augmentations/normalizations.py +++ b/batchgenerators/augmentations/normalizations.py @@ -84,14 +84,10 @@ def mean_std_normalization(data, mean, std, per_channel=True): def cut_off_outliers(data, percentile_lower=0.2, percentile_upper=99.8, per_channel=False): for b in range(len(data)): if not per_channel: - cut_off_lower = np.percentile(data[b], percentile_lower) - cut_off_upper = np.percentile(data[b], percentile_upper) - data[b][data[b] < cut_off_lower] = cut_off_lower - data[b][data[b] > cut_off_upper] = cut_off_upper + cut_off_lower, cut_off_upper = np.percentile(data[b], (percentile_lower, percentile_upper)) + np.clip(data[b], cut_off_lower, cut_off_upper, out=data[b]) else: for c in range(data.shape[1]): - cut_off_lower = np.percentile(data[b, c], percentile_lower) - cut_off_upper = np.percentile(data[b, c], percentile_upper) - data[b, c][data[b, c] < cut_off_lower] = cut_off_lower - data[b, c][data[b, c] > cut_off_upper] = cut_off_upper + cut_off_lower, cut_off_upper = np.percentile(data[b, c], (percentile_lower, percentile_upper)) + np.clip(data[b, c], cut_off_lower, cut_off_upper, out=data[b, c]) return data From 57f21ccd02b0f09de9ff9a58e7aecc90b468f6d7 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Mon, 15 May 2023 13:39:59 +0300 Subject: [PATCH 02/60] Doing inplace operations --- batchgenerators/augmentations/color_augmentations.py | 12 ++++++------ batchgenerators/augmentations/noise_augmentations.py | 2 +- batchgenerators/augmentations/normalizations.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/batchgenerators/augmentations/color_augmentations.py b/batchgenerators/augmentations/color_augmentations.py index db30e17..ba06e99 100644 --- a/batchgenerators/augmentations/color_augmentations.py +++ 
b/batchgenerators/augmentations/color_augmentations.py @@ -121,9 +121,9 @@ def augment_gamma(data_sample, gamma_range=(0.5, 2), invert_image=False, epsilon rnge = data_sample.max() - minm data_sample = np.power(((data_sample - minm) / float(rnge + epsilon)), gamma) * rnge + minm if retain_stats_here: - data_sample = data_sample - data_sample.mean() - data_sample = data_sample / (data_sample.std() + 1e-8) * sd - data_sample = data_sample + mn + data_sample -= data_sample.mean() + data_sample *= sd / (data_sample.std() + 1e-8) + data_sample += mn else: for c in range(data_sample.shape[0]): retain_stats_here = retain_stats() if callable(retain_stats) else retain_stats @@ -138,9 +138,9 @@ def augment_gamma(data_sample, gamma_range=(0.5, 2), invert_image=False, epsilon rnge = data_sample[c].max() - minm data_sample[c] = np.power(((data_sample[c] - minm) / float(rnge + epsilon)), gamma) * float(rnge + epsilon) + minm if retain_stats_here: - data_sample[c] = data_sample[c] - data_sample[c].mean() - data_sample[c] = data_sample[c] / (data_sample[c].std() + 1e-8) * sd - data_sample[c] = data_sample[c] + mn + data_sample[c] -= data_sample[c].mean() + data_sample[c] *= sd / (data_sample[c].std() + 1e-8) + data_sample[c] += mn if invert_image: data_sample = - data_sample return data_sample diff --git a/batchgenerators/augmentations/noise_augmentations.py b/batchgenerators/augmentations/noise_augmentations.py index c97e395..bf8f5b1 100644 --- a/batchgenerators/augmentations/noise_augmentations.py +++ b/batchgenerators/augmentations/noise_augmentations.py @@ -44,7 +44,7 @@ def augment_gaussian_noise(data_sample: np.ndarray, noise_variance: Tuple[float, noise_variance[0] if noise_variance[0] == noise_variance[1] else \ random.uniform(noise_variance[0], noise_variance[1]) # bug fixed: https://github.com/MIC-DKFZ/batchgenerators/issues/86 - data_sample[c] = data_sample[c] + np.random.normal(0.0, variance_here, size=data_sample[c].shape) + data_sample[c] += np.random.normal(0.0, variance_here, size=data_sample[c].shape) return data_sample diff --git a/batchgenerators/augmentations/normalizations.py b/batchgenerators/augmentations/normalizations.py index acc37d1..a9c0703 100644 --- a/batchgenerators/augmentations/normalizations.py +++ b/batchgenerators/augmentations/normalizations.py @@ -74,7 +74,7 @@ def mean_std_normalization(data, mean, std, per_channel=True): for b in range(data_shape[0]): if per_channel: - for c in range(data_shape[1]): + for c in range(data_shape[1]): # TODO: do one loop data_normalized[b][c] = (data[b][c] - mean[c]) / std[c] else: data_normalized[b] = (data[b] - mean) / std From dbba889c3ca10c4c39f940707ef08bac220a38a6 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Mon, 15 May 2023 13:41:08 +0300 Subject: [PATCH 03/60] Misc + small improvements, preferring tuples to lists --- .../augmentations/color_augmentations.py | 2 +- .../augmentations/crop_and_pad_augmentations.py | 14 ++++++++------ batchgenerators/augmentations/normalizations.py | 4 ++-- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/batchgenerators/augmentations/color_augmentations.py b/batchgenerators/augmentations/color_augmentations.py index ba06e99..51acc43 100644 --- a/batchgenerators/augmentations/color_augmentations.py +++ b/batchgenerators/augmentations/color_augmentations.py @@ -93,8 +93,8 @@ def augment_brightness_additive(data_sample, mu:float, sigma:float , per_channel def augment_brightness_multiplicative(data_sample, multiplier_range=(0.5, 2), per_channel=True): - multiplier = 
np.random.uniform(multiplier_range[0], multiplier_range[1]) if not per_channel: + multiplier = np.random.uniform(multiplier_range[0], multiplier_range[1]) data_sample *= multiplier else: for c in range(data_sample.shape[0]): diff --git a/batchgenerators/augmentations/crop_and_pad_augmentations.py b/batchgenerators/augmentations/crop_and_pad_augmentations.py index 88a8cff..5930113 100644 --- a/batchgenerators/augmentations/crop_and_pad_augmentations.py +++ b/batchgenerators/augmentations/crop_and_pad_augmentations.py @@ -32,10 +32,12 @@ def get_lbs_for_random_crop(crop_size, data_shape, margins): """ lbs = [] for i in range(len(data_shape) - 2): - if data_shape[i+2] - crop_size[i] - margins[i] > margins[i]: - lbs.append(np.random.randint(margins[i], data_shape[i+2] - crop_size[i] - margins[i])) + new_shape = data_shape[i+2] - crop_size[i] + margin = margins[i] + if new_shape > 2 * margin: + lbs.append(np.random.randint(margin, new_shape - margin)) else: - lbs.append((data_shape[i+2] - crop_size[i]) // 2) + lbs.append(new_shape // 2) return lbs @@ -88,7 +90,7 @@ def crop(data, seg=None, crop_size=128, margins=(0, 0, 0), crop_type="center", (str(data_shape), str(seg_shape)) if type(crop_size) not in (tuple, list, np.ndarray): - crop_size = [crop_size] * dim + crop_size = (crop_size, ) * dim else: assert len(crop_size) == len( data_shape) - 2, "If you provide a list/tuple as center crop make sure it has the same dimension as your " \ @@ -97,9 +99,9 @@ def crop(data, seg=None, crop_size=128, margins=(0, 0, 0), crop_type="center", if not isinstance(margins, (np.ndarray, tuple, list)): margins = [margins] * dim - data_return = np.zeros([data_shape[0], data_shape[1]] + list(crop_size), dtype=data_dtype) + data_return = np.zeros((data_shape[0], data_shape[1]) + tuple(crop_size), dtype=data_dtype) if seg is not None: - seg_return = np.zeros([seg_shape[0], seg_shape[1]] + list(crop_size), dtype=seg_dtype) + seg_return = np.zeros((seg_shape[0], seg_shape[1]) + tuple(crop_size), dtype=seg_dtype) else: seg_return = None diff --git a/batchgenerators/augmentations/normalizations.py b/batchgenerators/augmentations/normalizations.py index a9c0703..567ae89 100644 --- a/batchgenerators/augmentations/normalizations.py +++ b/batchgenerators/augmentations/normalizations.py @@ -57,10 +57,10 @@ def zero_mean_unit_variance_normalization(data, per_channel=True, epsilon=1e-8): def mean_std_normalization(data, mean, std, per_channel=True): data_normalized = np.zeros(data.shape, dtype=data.dtype) if isinstance(data, np.ndarray): - data_shape = tuple(list(data.shape)) + data_shape = data.shape elif isinstance(data, (list, tuple)): assert len(data) > 0 and isinstance(data[0], np.ndarray) - data_shape = [len(data)] + list(data[0].shape) + data_shape = (len(data),) + data[0].shape else: raise TypeError("Data has to be either a numpy array or a list") From 68efcfdb6b3b3c18c13a0d2413edc318b77341e5 Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 16 May 2023 14:02:13 +0300 Subject: [PATCH 04/60] Improved speed of mean_std normalization + added tests --- .../augmentations/normalizations.py | 15 +++++----- tests/test_multithreaded_augmenter.py | 3 +- tests/test_normalizations.py | 29 ++++++++++++++++++- 3 files changed, 38 insertions(+), 9 deletions(-) diff --git a/batchgenerators/augmentations/normalizations.py b/batchgenerators/augmentations/normalizations.py index 567ae89..43f03d6 100644 --- a/batchgenerators/augmentations/normalizations.py +++ b/batchgenerators/augmentations/normalizations.py @@ -55,7 +55,6 @@ def 
zero_mean_unit_variance_normalization(data, per_channel=True, epsilon=1e-8): def mean_std_normalization(data, mean, std, per_channel=True): - data_normalized = np.zeros(data.shape, dtype=data.dtype) if isinstance(data, np.ndarray): data_shape = data.shape elif isinstance(data, (list, tuple)): @@ -72,12 +71,14 @@ def mean_std_normalization(data, mean, std, per_channel=True): elif per_channel and isinstance(std, (tuple, list, np.ndarray)): assert len(std) == data_shape[1] - for b in range(data_shape[0]): - if per_channel: - for c in range(data_shape[1]): # TODO: do one loop - data_normalized[b][c] = (data[b][c] - mean[c]) / std[c] - else: - data_normalized[b] = (data[b] - mean) / std + if per_channel: + mean = np.array(mean) + std = np.array(std) + data_normalized = np.zeros(data.shape, dtype=data.dtype) + for b in range(data_shape[0]): + data_normalized[b] = ((data[b].T - mean) / std).T + else: + data_normalized = (data - mean) / std return data_normalized diff --git a/tests/test_multithreaded_augmenter.py b/tests/test_multithreaded_augmenter.py index 35aaf14..09cb5a5 100644 --- a/tests/test_multithreaded_augmenter.py +++ b/tests/test_multithreaded_augmenter.py @@ -187,7 +187,8 @@ def test_image_pipeline_and_pin_memory(self): res = mt.next() assert isinstance(res['data'], torch.Tensor) - assert res['data'].is_pinned() + if torch.cuda.is_available(): + assert res['data'].is_pinned() # let mt finish caching, otherwise it's going to print an error (which is not a problem and will not prevent # the success of the test but it does not look pretty) diff --git a/tests/test_normalizations.py b/tests/test_normalizations.py index b8d1a1a..ae4896b 100644 --- a/tests/test_normalizations.py +++ b/tests/test_normalizations.py @@ -17,7 +17,7 @@ import numpy as np from batchgenerators.augmentations.normalizations import range_normalization, zero_mean_unit_variance_normalization, \ - cut_off_outliers + cut_off_outliers, mean_std_normalization class TestNormalization(unittest.TestCase): @@ -230,6 +230,33 @@ def test_cut_off_outliers_whole_image(self): print('Test test_cut_off_outliers_whole_image. [START]') + def test_mean_std_normalization_per_channel(self): + print('Test test_mean_std_normalization_per_channel. [START]') + data = np.random.random((32, 4, 64, 56, 48)) + + mean = [np.mean(data[:, i]) for i in range(4)] + std = [np.std(data[:, i]) for i in range(4)] + data_normalized = mean_std_normalization(data, mean, std, per_channel=True) + + for i in range(4): + self.assertAlmostEqual(data_normalized[:, i].mean(), 0.0) + self.assertAlmostEqual(data_normalized[:, i].std(), 1.0) + + print('Test test_mean_std_normalization_per_channel. [DONE]') + + def test_mean_std_normalization_whole_image(self): + print('Test test_mean_std_normalization_whole_image. [START]') + data = np.random.random((32, 4, 64, 56, 48)) + + mean = np.mean(data) + std = np.std(data) + data_normalized = mean_std_normalization(data, mean, std, per_channel=False) + + self.assertAlmostEqual(data_normalized.mean(), 0.0) + self.assertAlmostEqual(data_normalized.std(), 1.0) + + print('Test test_mean_std_normalization_whole_image. 
[DONE]') + if __name__ == '__main__': unittest.main() From 26ca086bc2843f61f6993554e0c240289ad3f3e7 Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 16 May 2023 14:41:39 +0300 Subject: [PATCH 05/60] Vectorized all normalizations per batch and per channel --- .../augmentations/normalizations.py | 67 +++++++++---------- 1 file changed, 33 insertions(+), 34 deletions(-) diff --git a/batchgenerators/augmentations/normalizations.py b/batchgenerators/augmentations/normalizations.py index 43f03d6..98d080e 100644 --- a/batchgenerators/augmentations/normalizations.py +++ b/batchgenerators/augmentations/normalizations.py @@ -17,40 +17,42 @@ def range_normalization(data, rnge=(0, 1), per_channel=True, eps=1e-8): - data_normalized = np.zeros(data.shape, dtype=data.dtype) - for b in range(data.shape[0]): - if per_channel: - for c in range(data.shape[1]): - data_normalized[b, c] = min_max_normalization(data[b, c], eps) - else: - data_normalized[b] = min_max_normalization(data[b], eps) + if per_channel: + axes = tuple(range(2, len(data.shape))) + else: + axes = tuple(range(1, len(data.shape))) + data_normalized = min_max_normalization_batched(data, eps, axes) data_normalized *= (rnge[1] - rnge[0]) data_normalized += rnge[0] return data_normalized +def min_max_normalization_batched(data, eps, axes): + mn = data.min(axis=axes) + mx = data.max(axis=axes) + old_range = mx - mn + eps + data_normalized = ((data.T - mn.T) / old_range.T).T + return data_normalized + + def min_max_normalization(data, eps): mn = data.min() mx = data.max() - data_normalized = data - mn old_range = mx - mn + eps - data_normalized /= old_range - + data_normalized = (data - mn) / old_range return data_normalized + def zero_mean_unit_variance_normalization(data, per_channel=True, epsilon=1e-8): - data_normalized = np.zeros(data.shape, dtype=data.dtype) - for b in range(data.shape[0]): - if per_channel: - for c in range(data.shape[1]): - mean = data[b, c].mean() - std = data[b, c].std() + epsilon - data_normalized[b, c] = (data[b, c] - mean) / std - else: - mean = data[b].mean() - std = data[b].std() + epsilon - data_normalized[b] = (data[b] - mean) / std + if per_channel: + axes = tuple(range(2, len(data.shape))) + else: + axes = tuple(range(1, len(data.shape))) + + mean = np.mean(data, axis=axes) + std = np.std(data, axis=axes) + epsilon + data_normalized = ((data.T - mean.T) / std.T).T return data_normalized @@ -72,23 +74,20 @@ def mean_std_normalization(data, mean, std, per_channel=True): assert len(std) == data_shape[1] if per_channel: - mean = np.array(mean) - std = np.array(std) - data_normalized = np.zeros(data.shape, dtype=data.dtype) - for b in range(data_shape[0]): - data_normalized[b] = ((data[b].T - mean) / std).T + mean = np.broadcast_to(mean, (len(data), len(mean))) + std = np.broadcast_to(std, (len(data), len(std))) + data_normalized = ((data.T - mean.T) / std.T).T else: data_normalized = (data - mean) / std return data_normalized def cut_off_outliers(data, percentile_lower=0.2, percentile_upper=99.8, per_channel=False): - for b in range(len(data)): - if not per_channel: - cut_off_lower, cut_off_upper = np.percentile(data[b], (percentile_lower, percentile_upper)) - np.clip(data[b], cut_off_lower, cut_off_upper, out=data[b]) - else: - for c in range(data.shape[1]): - cut_off_lower, cut_off_upper = np.percentile(data[b, c], (percentile_lower, percentile_upper)) - np.clip(data[b, c], cut_off_lower, cut_off_upper, out=data[b, c]) + if per_channel: + axes = tuple(range(2, len(data.shape))) + else: + axes = tuple(range(1, 
len(data.shape))) + + cut_off_lower, cut_off_upper = np.percentile(data, (percentile_lower, percentile_upper), axis=axes) + np.clip(data.T, cut_off_lower.T, cut_off_upper.T, out=data.T) return data From 7c6f37755ad5c29672f850a890c09aba714e8448 Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 16 May 2023 15:13:51 +0300 Subject: [PATCH 06/60] Gaussian noise and mean_std refactoring --- .../augmentations/noise_augmentations.py | 14 +++++------ .../augmentations/normalizations.py | 24 +++++++------------ 2 files changed, 15 insertions(+), 23 deletions(-) diff --git a/batchgenerators/augmentations/noise_augmentations.py b/batchgenerators/augmentations/noise_augmentations.py index bf8f5b1..6ba97a3 100644 --- a/batchgenerators/augmentations/noise_augmentations.py +++ b/batchgenerators/augmentations/noise_augmentations.py @@ -32,19 +32,19 @@ def augment_rician_noise(data_sample, noise_variance=(0, 0.1)): def augment_gaussian_noise(data_sample: np.ndarray, noise_variance: Tuple[float, float] = (0, 0.1), p_per_channel: float = 1, per_channel: bool = False) -> np.ndarray: + size = data_sample.shape[0] if not per_channel: variance = noise_variance[0] if noise_variance[0] == noise_variance[1] else \ random.uniform(noise_variance[0], noise_variance[1]) + variance = np.repeat(variance, size) else: - variance = None - for c in range(data_sample.shape[0]): + variance = np.repeat(noise_variance[0], size) if noise_variance[0] == noise_variance[1] else \ + np.random.uniform(noise_variance[0], noise_variance[1], size=size) + + for c in range(size): if np.random.uniform() < p_per_channel: - # lol good luck reading this - variance_here = variance if variance is not None else \ - noise_variance[0] if noise_variance[0] == noise_variance[1] else \ - random.uniform(noise_variance[0], noise_variance[1]) # bug fixed: https://github.com/MIC-DKFZ/batchgenerators/issues/86 - data_sample[c] += np.random.normal(0.0, variance_here, size=data_sample[c].shape) + data_sample[c] += np.random.normal(0.0, variance[c], size=data_sample[c].shape) return data_sample diff --git a/batchgenerators/augmentations/normalizations.py b/batchgenerators/augmentations/normalizations.py index 98d080e..8ed157c 100644 --- a/batchgenerators/augmentations/normalizations.py +++ b/batchgenerators/augmentations/normalizations.py @@ -57,23 +57,15 @@ def zero_mean_unit_variance_normalization(data, per_channel=True, epsilon=1e-8): def mean_std_normalization(data, mean, std, per_channel=True): - if isinstance(data, np.ndarray): - data_shape = data.shape - elif isinstance(data, (list, tuple)): - assert len(data) > 0 and isinstance(data[0], np.ndarray) - data_shape = (len(data),) + data[0].shape - else: - raise TypeError("Data has to be either a numpy array or a list") - - if per_channel and isinstance(mean, float) and isinstance(std, float): - mean = [mean] * data_shape[1] - std = [std] * data_shape[1] - elif per_channel and isinstance(mean, (tuple, list, np.ndarray)): - assert len(mean) == data_shape[1] - elif per_channel and isinstance(std, (tuple, list, np.ndarray)): - assert len(std) == data_shape[1] - if per_channel: + channel_dimension = data[0].shape[0] + if isinstance(mean, float) and isinstance(std, float): + mean = [mean] * channel_dimension + std = [std] * channel_dimension + else: + assert len(mean) == channel_dimension + assert len(std) == channel_dimension + mean = np.broadcast_to(mean, (len(data), len(mean))) std = np.broadcast_to(std, (len(data), len(std))) data_normalized = ((data.T - mean.T) / std.T).T From 
659e1fbddc05c78ef881417382a59f1b3e3d3b28 Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 17 May 2023 11:27:20 +0300 Subject: [PATCH 07/60] Vectorized augment_contrast and augment_brightness_additive --- .../augmentations/color_augmentations.py | 79 +++++++++---------- tests/test_color_augmentations.py | 2 + 2 files changed, 38 insertions(+), 43 deletions(-) diff --git a/batchgenerators/augmentations/color_augmentations.py b/batchgenerators/augmentations/color_augmentations.py index 51acc43..dd4b9b1 100644 --- a/batchgenerators/augmentations/color_augmentations.py +++ b/batchgenerators/augmentations/color_augmentations.py @@ -25,7 +25,19 @@ def augment_contrast(data_sample: np.ndarray, preserve_range: bool = True, per_channel: bool = True, p_per_channel: float = 1) -> np.ndarray: - if not per_channel: + size = data_sample.shape[0] + if per_channel: + if callable(contrast_range): + factor = [contrast_range() for _ in range(size)] + else: + factor = [] + for _ in range(size): + if np.random.random() < 0.5 and contrast_range[0] < 1: + factor.append(np.random.uniform(contrast_range[0], 1)) + else: + factor.append(np.random.uniform(max(contrast_range[0], 1), contrast_range[1])) + factor = np.array(factor) + else: if callable(contrast_range): factor = contrast_range() else: @@ -34,37 +46,19 @@ def augment_contrast(data_sample: np.ndarray, else: factor = np.random.uniform(max(contrast_range[0], 1), contrast_range[1]) - for c in range(data_sample.shape[0]): - if np.random.uniform() < p_per_channel: - mn = data_sample[c].mean() - if preserve_range: - minm = data_sample[c].min() - maxm = data_sample[c].max() + mask = np.random.uniform(size=size) < p_per_channel + workon = data_sample[mask] + if len(workon) > 0: + axes = tuple(range(1, len(data_sample.shape))) + mean = workon.mean(axis=axes) + if preserve_range: + minm = workon.min(axis=axes) + maxm = workon.max(axis=axes) - data_sample[c] = data_sample[c] * factor + mn * (1 - factor) + data_sample[mask] = (workon.T * factor + mean * (1 - factor)).T # writing directly in data_sample - if preserve_range: - np.clip(data_sample[c], minm, maxm, out=data_sample[c]) - else: - for c in range(data_sample.shape[0]): - if np.random.uniform() < p_per_channel: - if callable(contrast_range): - factor = contrast_range() - else: - if np.random.random() < 0.5 and contrast_range[0] < 1: - factor = np.random.uniform(contrast_range[0], 1) - else: - factor = np.random.uniform(max(contrast_range[0], 1), contrast_range[1]) - - mn = data_sample[c].mean() - if preserve_range: - minm = data_sample[c].min() - maxm = data_sample[c].max() - - data_sample[c] = data_sample[c] * factor + mn * (1 - factor) - - if preserve_range: - np.clip(data_sample[c], minm, maxm, out=data_sample[c]) + if preserve_range: + np.clip(data_sample[mask], minm, maxm, out=data_sample[mask]) return data_sample @@ -79,27 +73,26 @@ def augment_brightness_additive(data_sample, mu:float, sigma:float , per_channel :param p_per_channel: :return: """ - if not per_channel: - rnd_nb = np.random.normal(mu, sigma) - for c in range(data_sample.shape[0]): - if np.random.uniform() <= p_per_channel: - data_sample[c] += rnd_nb + size = data_sample.shape[0] + if per_channel: + rnd_nb = np.random.normal(mu, sigma, size=size) else: - for c in range(data_sample.shape[0]): - if np.random.uniform() <= p_per_channel: - rnd_nb = np.random.normal(mu, sigma) - data_sample[c] += rnd_nb + rnd_nb = np.repeat(np.random.normal(mu, sigma), size) + rnd_nb[np.random.uniform(size=size) > p_per_channel] = 0.0 + axes = 
tuple(range(len(data_sample.shape) - 1)) + data_sample += np.expand_dims(rnd_nb, axis=axes).T # Broadcasting rules require this return data_sample def augment_brightness_multiplicative(data_sample, multiplier_range=(0.5, 2), per_channel=True): if not per_channel: multiplier = np.random.uniform(multiplier_range[0], multiplier_range[1]) - data_sample *= multiplier else: - for c in range(data_sample.shape[0]): - multiplier = np.random.uniform(multiplier_range[0], multiplier_range[1]) - data_sample[c] *= multiplier + axes = [1 for _ in range(len(data_sample.shape))] + axes[0] = data_sample.shape[0] + multiplier = np.random.uniform(multiplier_range[0], multiplier_range[1], size=axes) + + data_sample *= multiplier return data_sample diff --git a/tests/test_color_augmentations.py b/tests/test_color_augmentations.py index 5ea5802..64f17f5 100644 --- a/tests/test_color_augmentations.py +++ b/tests/test_color_augmentations.py @@ -27,7 +27,9 @@ def setUp(self): self.data_2D = np.random.random((2, 64, 56)) self.factor = (0.75, 1.25) + self.d_3D = augment_contrast(self.data_3D, contrast_range=self.factor, preserve_range=False, per_channel=True) self.d_3D = augment_contrast(self.data_3D, contrast_range=self.factor, preserve_range=False, per_channel=False) + self.d_2D = augment_contrast(self.data_2D, contrast_range=self.factor, preserve_range=False, per_channel=True) self.d_2D = augment_contrast(self.data_2D, contrast_range=self.factor, preserve_range=False, per_channel=False) def test_augment_contrast_3D(self): From 0a1dd793e8b2bde222ceb3862a0efbcbf7115fde Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 17 May 2023 12:37:24 +0300 Subject: [PATCH 08/60] Making convert_seg_image_to_one_hot_encoding batched and using it --- batchgenerators/augmentations/utils.py | 8 ++++---- batchgenerators/transforms/utility_transforms.py | 16 +++++++--------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/batchgenerators/augmentations/utils.py b/batchgenerators/augmentations/utils.py index 021321f..eaa0b67 100755 --- a/batchgenerators/augmentations/utils.py +++ b/batchgenerators/augmentations/utils.py @@ -47,6 +47,7 @@ def convert_seg_image_to_one_hot_encoding(image, classes=None): image must be either (x, y, z) or (x, y) Takes as input an nd array of a label map (any dimension). Outputs a one hot encoding of the label map. Example (3D): if input is of shape (x, y, z), the output will ne of shape (n_classes, x, y, z) + Prefer convert_seg_image_to_one_hot_encoding_batched. ''' if classes is None: classes = np.unique(image) @@ -62,11 +63,10 @@ def convert_seg_image_to_one_hot_encoding_batched(image, classes=None): ''' if classes is None: classes = np.unique(image) - output_shape = [image.shape[0]] + [len(classes)] + list(image.shape[1:]) + output_shape = (image.shape[0], len(classes)) + image.shape[1:] out_image = np.zeros(output_shape, dtype=image.dtype) - for b in range(image.shape[0]): - for i, c in enumerate(classes): - out_image[b, i][image[b] == c] = 1 + for i, c in enumerate(classes): + out_image[:, i][image == c] = 1 return out_image diff --git a/batchgenerators/transforms/utility_transforms.py b/batchgenerators/transforms/utility_transforms.py index 6ff2ae4..f12dd8e 100644 --- a/batchgenerators/transforms/utility_transforms.py +++ b/batchgenerators/transforms/utility_transforms.py @@ -14,12 +14,12 @@ # limitations under the License. 
import copy -from typing import List, Type, Union, Tuple +from typing import List, Union, Tuple import numpy as np -from batchgenerators.augmentations.utils import convert_seg_image_to_one_hot_encoding, \ - convert_seg_to_bounding_box_coordinates, transpose_channels +from batchgenerators.augmentations.utils import convert_seg_to_bounding_box_coordinates, transpose_channels, \ + convert_seg_image_to_one_hot_encoding_batched from batchgenerators.transforms.abstract_transforms import AbstractTransform @@ -118,9 +118,7 @@ def __init__(self, classes, seg_channel=0, output_key="seg"): def __call__(self, **data_dict): seg = data_dict.get("seg") if seg is not None: - new_seg = np.zeros([seg.shape[0], len(self.classes)] + list(seg.shape[2:]), dtype=seg.dtype) - for b in range(seg.shape[0]): - new_seg[b] = convert_seg_image_to_one_hot_encoding(seg[b, self.seg_channel], self.classes) + new_seg = convert_seg_image_to_one_hot_encoding_batched(seg[:, self.seg_channel], self.classes) data_dict[self.output_key] = new_seg else: from warnings import warn @@ -139,9 +137,9 @@ def __call__(self, **data_dict): seg = data_dict.get("seg") if seg is not None: new_seg = np.zeros([seg.shape[0], len(self.classes) * seg.shape[1]] + list(seg.shape[2:]), dtype=seg.dtype) - for b in range(seg.shape[0]): - for c in range(seg.shape[1]): - new_seg[b, c*len(self.classes):(c+1)*len(self.classes)] = convert_seg_image_to_one_hot_encoding(seg[b, c], self.classes) + for c in range(seg.shape[1]): + new_seg[:, c * len(self.classes):(c + 1) * len(self.classes)] = \ + convert_seg_image_to_one_hot_encoding_batched(seg[:, c], self.classes) data_dict["seg"] = new_seg else: from warnings import warn From b9c9161f69cae7b11d040a7bc3ad52dbc98b59d6 Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 17 May 2023 12:39:43 +0300 Subject: [PATCH 09/60] Various small improvements caching, optimizing conditionals, using tuples instead of lists, doing operations inplace --- .../crop_and_pad_augmentations.py | 5 +- .../augmentations/spatial_transformations.py | 2 +- batchgenerators/augmentations/utils.py | 64 +++++++++---------- 3 files changed, 36 insertions(+), 35 deletions(-) diff --git a/batchgenerators/augmentations/crop_and_pad_augmentations.py b/batchgenerators/augmentations/crop_and_pad_augmentations.py index 5930113..0728890 100644 --- a/batchgenerators/augmentations/crop_and_pad_augmentations.py +++ b/batchgenerators/augmentations/crop_and_pad_augmentations.py @@ -95,13 +95,14 @@ def crop(data, seg=None, crop_size=128, margins=(0, 0, 0), crop_type="center", assert len(crop_size) == len( data_shape) - 2, "If you provide a list/tuple as center crop make sure it has the same dimension as your " \ "data (2d/3d)" + crop_size = tuple(crop_size) if not isinstance(margins, (np.ndarray, tuple, list)): margins = [margins] * dim - data_return = np.zeros((data_shape[0], data_shape[1]) + tuple(crop_size), dtype=data_dtype) + data_return = np.zeros((data_shape[0], data_shape[1]) + crop_size, dtype=data_dtype) if seg is not None: - seg_return = np.zeros((seg_shape[0], seg_shape[1]) + tuple(crop_size), dtype=seg_dtype) + seg_return = np.zeros((seg_shape[0], seg_shape[1]) + crop_size, dtype=seg_dtype) else: seg_return = None diff --git a/batchgenerators/augmentations/spatial_transformations.py b/batchgenerators/augmentations/spatial_transformations.py index 0b41565..fe0677b 100644 --- a/batchgenerators/augmentations/spatial_transformations.py +++ b/batchgenerators/augmentations/spatial_transformations.py @@ -36,7 +36,7 @@ def augment_rot90(sample_data, 
sample_seg, num_rot=(1, 2, 3), axes=(0, 1, 2)): num_rot = np.random.choice(num_rot) axes = np.random.choice(axes, size=2, replace=False) axes.sort() - axes = [i + 1 for i in axes] + axes += 1 sample_data = np.rot90(sample_data, num_rot, axes) if sample_seg is not None: sample_seg = np.rot90(sample_seg, num_rot, axes) diff --git a/batchgenerators/augmentations/utils.py b/batchgenerators/augmentations/utils.py index eaa0b67..34cd345 100755 --- a/batchgenerators/augmentations/utils.py +++ b/batchgenerators/augmentations/utils.py @@ -37,8 +37,9 @@ def generate_elastic_transform_coordinates(shape, alpha, sigma): def create_zero_centered_coordinate_mesh(shape): tmp = tuple([np.arange(i) for i in shape]) coords = np.array(np.meshgrid(*tmp, indexing='ij')).astype(float) + to_add = ((np.array(shape).astype(float) - 1) / 2.) for d in range(len(shape)): - coords[d] -= ((np.array(shape).astype(float) - 1) / 2.)[d] + coords[d] -= to_add[d] return coords @@ -100,9 +101,10 @@ def elastic_deform_coordinates_2(coordinates, sigmas, magnitudes): random_values_ = np.fft.fftn(random_values) deformation_field = fourier_gaussian(random_values_, sigmas) deformation_field = np.fft.ifftn(deformation_field).real + mx = np.max(np.abs(deformation_field)) + deformation_field *= (magnitudes[d] + 1e-8) / mx offsets.append(deformation_field) - mx = np.max(np.abs(offsets[-1])) - offsets[-1] = offsets[-1] / (mx / (magnitudes[d] + 1e-8)) + offsets = np.array(offsets) indices = offsets + coordinates return indices @@ -134,10 +136,10 @@ def scale_coords(coords, scale): def uncenter_coords(coords): - shp = coords.shape[1:] + shp = (coords.shape[1:] - 1) / 2. coords = deepcopy(coords) for d in range(coords.shape[0]): - coords[d] += (shp[d] - 1) / 2. + coords[d] += shp[d] return coords @@ -145,7 +147,7 @@ def interpolate_img(img, coords, order=3, mode='nearest', cval=0.0, is_seg=False if is_seg and order != 0: unique_labels = np.unique(img) result = np.zeros(coords.shape[1:], img.dtype) - for i, c in enumerate(unique_labels): + for c in unique_labels: res_new = map_coordinates((img == c).astype(float), coords, order=order, mode=mode, cval=cval) result[res_new >= 0.5] = c return result @@ -160,10 +162,9 @@ def generate_noise(shape, alpha, sigma): def find_entries_in_array(entries, myarray): - entries = np.array(entries) - values = np.arange(np.max(myarray) + 1) - lut = np.zeros(len(values), 'bool') - lut[entries.astype("int")] = True + entries = np.array(entries, dtype=int) + lut = np.zeros(np.max(myarray) + 1, 'bool') + lut[entries] = True return np.take(lut, myarray.astype(int)) @@ -330,8 +331,8 @@ def random_crop_2D_image_batched(img, crop_size): def resize_image_by_padding(image, new_shape, pad_value=None): - shape = tuple(list(image.shape)) - new_shape = tuple(np.max(np.concatenate((shape, new_shape)).reshape((2, len(shape))), axis=0)) + shape = image.shape + new_shape = np.max(np.concatenate((shape, new_shape)).reshape((2, len(shape))), axis=0) if pad_value is None: if len(shape) == 2: pad_value = image[0, 0] @@ -339,8 +340,8 @@ def resize_image_by_padding(image, new_shape, pad_value=None): pad_value = image[0, 0, 0] else: raise ValueError("Image must be either 2 or 3 dimensional") - res = np.ones(list(new_shape), dtype=image.dtype) * pad_value - start = np.array(new_shape) / 2. - np.array(shape) / 2. + res = np.ones(new_shape, dtype=image.dtype) * pad_value + start = new_shape / 2. - np.array(shape) / 2. 
if len(shape) == 2: res[int(start[0]):int(start[0]) + int(shape[0]), int(start[1]):int(start[1]) + int(shape[1])] = image elif len(shape) == 3: @@ -350,8 +351,8 @@ def resize_image_by_padding(image, new_shape, pad_value=None): def resize_image_by_padding_batched(image, new_shape, pad_value=None): - shape = tuple(list(image.shape[2:])) - new_shape = tuple(np.max(np.concatenate((shape, new_shape)).reshape((2, len(shape))), axis=0)) + shape = image.shape[2:] + new_shape = np.max(np.concatenate((shape, new_shape)).reshape((2, len(shape))), axis=0) if pad_value is None: if len(shape) == 2: pad_value = image[0, 0] @@ -359,7 +360,7 @@ def resize_image_by_padding_batched(image, new_shape, pad_value=None): pad_value = image[0, 0, 0] else: raise ValueError("Image must be either 2 or 3 dimensional") - start = np.array(new_shape) / 2. - np.array(shape) / 2. + start = new_shape / 2. - np.array(shape) / 2. if len(shape) == 2: res = np.ones((image.shape[0], image.shape[1], new_shape[0], new_shape[1]), dtype=image.dtype) * pad_value res[:, :, int(start[0]):int(start[0]) + int(shape[0]), int(start[1]):int(start[1]) + int(shape[1])] = image[:, @@ -476,16 +477,18 @@ def general_cc_var_num_channels(img, diff_order=0, mink_norm=1, sigma=1, mask_im for c in range(img_internal.shape[0]): white_colors.append(np.max(img_internal[c][mask_im != 1])) - som = np.sqrt(np.sum([i ** 2 for i in white_colors])) + white_colors = np.array(white_colors) + som = np.sqrt(np.sum(np.power(white_colors, 2))) - white_colors = [i / som for i in white_colors] + white_colors /= som + white_colors *= np.sqrt(3.) for c in range(output_img.shape[0]): - output_img[c] /= (white_colors[c] * np.sqrt(3.)) + output_img[c] /= white_colors[c] if clip_range: - output_img[output_img < minm] = minm - output_img[output_img > maxm] = maxm + np.clip(output_img, minm, maxm, out= output_img) + return white_colors, output_img @@ -598,7 +601,7 @@ def resize_segmentation(segmentation, new_shape, order=3): else: reshaped = np.zeros(new_shape, dtype=segmentation.dtype) - for i, c in enumerate(unique_labels): + for c in unique_labels: mask = segmentation == c reshaped_multihot = resize(mask.astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False) reshaped[reshaped_multihot >= 0.5] = c @@ -693,9 +696,7 @@ def pad_nd_image(image, new_shape=None, mode="constant", kwargs=None, return_sli num_axes_nopad = len(image.shape) - len(new_shape) new_shape = [max(new_shape[i], old_shape[i]) for i in range(len(new_shape))] - - if not isinstance(new_shape, np.ndarray): - new_shape = np.array(new_shape) + new_shape = np.array(new_shape) if shape_must_be_divisible_by is not None: if not isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray)): @@ -704,17 +705,16 @@ def pad_nd_image(image, new_shape=None, mode="constant", kwargs=None, return_sli assert len(shape_must_be_divisible_by) == len(new_shape) for i in range(len(new_shape)): - if new_shape[i] % shape_must_be_divisible_by[i] == 0: - new_shape[i] -= shape_must_be_divisible_by[i] - - new_shape = np.array([new_shape[i] + shape_must_be_divisible_by[i] - new_shape[i] % shape_must_be_divisible_by[i] for i in range(len(new_shape))]) + modulo = new_shape[i] % shape_must_be_divisible_by[i] + if modulo != 0: + new_shape[i] += shape_must_be_divisible_by[i] - modulo difference = new_shape - old_shape pad_below = difference // 2 - pad_above = difference // 2 + difference % 2 + pad_above = pad_below + difference % 2 pad_list = [[0, 0]]*num_axes_nopad + list([list(i) for i in zip(pad_below, pad_above)]) - 
if not ((all([i == 0 for i in pad_below])) and (all([i == 0 for i in pad_above]))): + if np.any(pad_below) or np.any(pad_above): res = np.pad(image, pad_list, mode, **kwargs) else: res = image From 58e324b3aab759376d7fac2dde36bfac476a53fe Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 17 May 2023 13:45:04 +0300 Subject: [PATCH 10/60] Improving dataloader and numpy to tensor --- batchgenerators/augmentations/utils.py | 3 +-- batchgenerators/dataloading/data_loader.py | 12 +++++---- .../channel_selection_transforms.py | 2 +- .../transforms/local_transforms.py | 8 +++--- .../transforms/utility_transforms.py | 26 +++++++++---------- 5 files changed, 26 insertions(+), 25 deletions(-) diff --git a/batchgenerators/augmentations/utils.py b/batchgenerators/augmentations/utils.py index 34cd345..226d53b 100755 --- a/batchgenerators/augmentations/utils.py +++ b/batchgenerators/augmentations/utils.py @@ -480,8 +480,7 @@ def general_cc_var_num_channels(img, diff_order=0, mink_norm=1, sigma=1, mask_im white_colors = np.array(white_colors) som = np.sqrt(np.sum(np.power(white_colors, 2))) - white_colors /= som - white_colors *= np.sqrt(3.) + white_colors *= np.sqrt(3.) / som for c in range(output_img.shape[0]): output_img[c] /= white_colors[c] diff --git a/batchgenerators/dataloading/data_loader.py b/batchgenerators/dataloading/data_loader.py index 28d0df1..fc9f84f 100644 --- a/batchgenerators/dataloading/data_loader.py +++ b/batchgenerators/dataloading/data_loader.py @@ -169,6 +169,10 @@ def __init__(self, data, batch_size, num_threads_in_multithreaded=1, seed_for_sh # when you derive, make sure to set this! We can't set it here because we don't know what data will be like self.indices = None + if self.infinite: + # Use separate get indices method + self.get_indices = self.get_indices_infinite + def reset(self): assert self.indices is not None @@ -182,11 +186,10 @@ def reset(self): self.last_reached = False - def get_indices(self): - # if self.infinite, this is easy - if self.infinite: - return np.random.choice(self.indices, self.batch_size, replace=True, p=self.sampling_probabilities) + def get_indices_infinite(self): + return np.random.choice(self.indices, self.batch_size, replace=True, p=self.sampling_probabilities) + def get_indices(self): if self.last_reached: self.reset() raise StopIteration @@ -199,7 +202,6 @@ def get_indices(self): for b in range(self.batch_size): if self.current_position < len(self.indices): indices.append(self.indices[self.current_position]) - self.current_position += 1 else: self.last_reached = True diff --git a/batchgenerators/transforms/channel_selection_transforms.py b/batchgenerators/transforms/channel_selection_transforms.py index ec89cfe..601a375 100644 --- a/batchgenerators/transforms/channel_selection_transforms.py +++ b/batchgenerators/transforms/channel_selection_transforms.py @@ -118,7 +118,7 @@ def __call__(self, **data_dict): random_number = np.random.rand() if random_number < self.swap_probability: seg[:, [self.axis1, self.axis2]] = seg[:, [self.axis2, self.axis1]] - data_dict[self.label_key] = seg + data_dict[self.label_key] = seg return data_dict diff --git a/batchgenerators/transforms/local_transforms.py b/batchgenerators/transforms/local_transforms.py index 8cb9ca4..42363aa 100644 --- a/batchgenerators/transforms/local_transforms.py +++ b/batchgenerators/transforms/local_transforms.py @@ -61,8 +61,8 @@ def _generate_kernel(self, img_shp: Tuple[int, ...]) -> np.ndarray: kernel_image = kernel_2d # normalize to [0, 1] - kernel_image = kernel_image - 
kernel_image.min() - kernel_image = kernel_image / max(1e-8, kernel_image.max()) + kernel_image -= kernel_image.min() + kernel_image /= max(1e-8, kernel_image.max()) return kernel_image def _generate_multiple_kernel_image(self, img_shp: Tuple[int, ...], num_kernels: int) -> np.ndarray: @@ -167,7 +167,7 @@ def __call__(self, **data_dict): # now rescale so that the maximum value of the kernel is max_strength strength = sample_scalar(self.max_strength, data[bi, ci], kernel) if callable( self.max_strength) else strength - kernel_scaled = np.copy(kernel) / mx * strength + kernel_scaled = kernel / mx * strength data[bi, ci] += kernel_scaled else: for ci in range(c): @@ -177,7 +177,7 @@ def __call__(self, **data_dict): kernel -= kernel.mean() mx = max(np.max(np.abs(kernel)), 1e-8) strength = sample_scalar(self.max_strength, data[bi, ci], kernel) - kernel = kernel / mx * strength + kernel *= strength / mx data[bi, ci] += kernel return data_dict diff --git a/batchgenerators/transforms/utility_transforms.py b/batchgenerators/transforms/utility_transforms.py index f12dd8e..593332c 100644 --- a/batchgenerators/transforms/utility_transforms.py +++ b/batchgenerators/transforms/utility_transforms.py @@ -17,6 +17,7 @@ from typing import List, Union, Tuple import numpy as np +import torch from batchgenerators.augmentations.utils import convert_seg_to_bounding_box_coordinates, transpose_channels, \ convert_seg_image_to_one_hot_encoding_batched @@ -34,25 +35,24 @@ def __init__(self, keys=None, cast_to=None): if keys is not None and not isinstance(keys, (list, tuple)): keys = [keys] self.keys = keys - self.cast_to = cast_to + if cast_to is not None: + if cast_to == 'half': + self.cast_to = torch.half + elif cast_to == 'float': + self.cast_to = torch.float + elif cast_to == 'long': + self.cast_to = torch.long + elif cast_to == 'bool': + self.cast_to = torch.bool + else: + raise ValueError(f'Unknown value for cast_to: {self.cast_to}') def cast(self, tensor): if self.cast_to is not None: - if self.cast_to == 'half': - tensor = tensor.half() - elif self.cast_to == 'float': - tensor = tensor.float() - elif self.cast_to == 'long': - tensor = tensor.long() - elif self.cast_to == 'bool': - tensor = tensor.bool() - else: - raise ValueError('Unknown value for cast_to: %s' % self.cast_to) + tensor = tensor.to(self.cast_to) return tensor def __call__(self, **data_dict): - import torch - if self.keys is None: for key, val in data_dict.items(): if isinstance(val, np.ndarray): From f42847d4b9fa7a5a36ab4c660d56d915d262ab43 Mon Sep 17 00:00:00 2001 From: ancestor-mithril <58839912+ancestor-mithril@users.noreply.github.com> Date: Sat, 20 May 2023 13:53:14 +0300 Subject: [PATCH 11/60] Setting cast to None in NumpyToTensor --- batchgenerators/transforms/utility_transforms.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/batchgenerators/transforms/utility_transforms.py b/batchgenerators/transforms/utility_transforms.py index 593332c..a31e3e8 100644 --- a/batchgenerators/transforms/utility_transforms.py +++ b/batchgenerators/transforms/utility_transforms.py @@ -46,6 +46,8 @@ def __init__(self, keys=None, cast_to=None): self.cast_to = torch.bool else: raise ValueError(f'Unknown value for cast_to: {self.cast_to}') + else: + self.cast_to = None def cast(self, tensor): if self.cast_to is not None: From 0637882248611aa1e8213019f39d4e4d4c9c0e25 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Wed, 16 Aug 2023 13:01:55 +0300 Subject: [PATCH 12/60] Optimizing rotations --- batchgenerators/augmentations/utils.py | 45 
+++++++++++++------------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/batchgenerators/augmentations/utils.py b/batchgenerators/augmentations/utils.py index 226d53b..831aac2 100755 --- a/batchgenerators/augmentations/utils.py +++ b/batchgenerators/augmentations/utils.py @@ -375,39 +375,38 @@ def resize_image_by_padding_batched(image, new_shape, pad_value=None): return res -def create_matrix_rotation_x_3d(angle, matrix=None): - rotation_x = np.array([[1, 0, 0], - [0, np.cos(angle), -np.sin(angle)], - [0, np.sin(angle), np.cos(angle)]]) - if matrix is None: - return rotation_x - +def create_matrix_rotation_x_3d(angle, matrix: np.ndarray): + cos_a = np.cos(angle) + sin_a = np.sin(angle) + rotation_x = np.array(((1, 0, 0), + (0, cos_a, -sin_a), + (0, sin_a, cos_a))) return np.dot(matrix, rotation_x) -def create_matrix_rotation_y_3d(angle, matrix=None): - rotation_y = np.array([[np.cos(angle), 0, np.sin(angle)], - [0, 1, 0], - [-np.sin(angle), 0, np.cos(angle)]]) - if matrix is None: - return rotation_y - +def create_matrix_rotation_y_3d(angle, matrix: np.ndarray): + cos_a = np.cos(angle) + sin_a = np.sin(angle) + rotation_y = np.array(((cos_a, 0, sin_a), + (0, 1, 0), + (-sin_a, 0, cos_a))) return np.dot(matrix, rotation_y) -def create_matrix_rotation_z_3d(angle, matrix=None): - rotation_z = np.array([[np.cos(angle), -np.sin(angle), 0], - [np.sin(angle), np.cos(angle), 0], - [0, 0, 1]]) - if matrix is None: - return rotation_z - +def create_matrix_rotation_z_3d(angle, matrix: np.ndarray): + cos_a = np.cos(angle) + sin_a = np.sin(angle) + rotation_z = np.array(((cos_a, -sin_a, 0), + (sin_a, cos_a, 0), + (0, 0, 1))) return np.dot(matrix, rotation_z) def create_matrix_rotation_2d(angle, matrix=None): - rotation = np.array([[np.cos(angle), -np.sin(angle)], - [np.sin(angle), np.cos(angle)]]) + cos_a = np.cos(angle) + sin_a = np.sin(angle) + rotation = np.array(((cos_a, -sin_a), + (sin_a, cos_a))) if matrix is None: return rotation From 9b9c4465d3494e2cf6e7f989de891a2f661ca2c8 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Wed, 16 Aug 2023 13:02:33 +0300 Subject: [PATCH 13/60] Doing batched augment_contrast --- .../augmentations/color_augmentations.py | 18 ++++++++---- .../transforms/color_transforms.py | 14 +++++----- tests/test_color_augmentations.py | 28 +++++++++++++++++-- 3 files changed, 45 insertions(+), 15 deletions(-) diff --git a/batchgenerators/augmentations/color_augmentations.py b/batchgenerators/augmentations/color_augmentations.py index dd4b9b1..e5cac10 100644 --- a/batchgenerators/augmentations/color_augmentations.py +++ b/batchgenerators/augmentations/color_augmentations.py @@ -24,32 +24,38 @@ def augment_contrast(data_sample: np.ndarray, contrast_range: Union[Tuple[float, float], Callable[[], float]] = (0.75, 1.25), preserve_range: bool = True, per_channel: bool = True, - p_per_channel: float = 1) -> np.ndarray: - size = data_sample.shape[0] + p_per_channel: float = 1, + batched=False) -> np.ndarray: + size = data_sample.shape[1 if batched else 0] if per_channel: if callable(contrast_range): factor = [contrast_range() for _ in range(size)] else: factor = [] for _ in range(size): - if np.random.random() < 0.5 and contrast_range[0] < 1: + if contrast_range[0] < 1 and np.random.random() < 0.5: factor.append(np.random.uniform(contrast_range[0], 1)) else: factor.append(np.random.uniform(max(contrast_range[0], 1), contrast_range[1])) + factor = np.array(factor) + if batched: + factor = factor.repeat(data_sample.shape[0]) else: if callable(contrast_range): factor 
= contrast_range() else: - if np.random.random() < 0.5 and contrast_range[0] < 1: + if contrast_range[0] < 1 and np.random.random() < 0.5: factor = np.random.uniform(contrast_range[0], 1) else: factor = np.random.uniform(max(contrast_range[0], 1), contrast_range[1]) mask = np.random.uniform(size=size) < p_per_channel + if batched: + mask = np.atleast_2d(mask).repeat(data_sample.shape[0], axis=0) workon = data_sample[mask] if len(workon) > 0: - axes = tuple(range(1, len(data_sample.shape))) + axes = tuple(range(1, len(workon.shape))) mean = workon.mean(axis=axes) if preserve_range: minm = workon.min(axis=axes) @@ -58,7 +64,7 @@ def augment_contrast(data_sample: np.ndarray, data_sample[mask] = (workon.T * factor + mean * (1 - factor)).T # writing directly in data_sample if preserve_range: - np.clip(data_sample[mask], minm, maxm, out=data_sample[mask]) + np.clip(data_sample[mask].T, minm, maxm, out=data_sample[mask].T) return data_sample diff --git a/batchgenerators/transforms/color_transforms.py b/batchgenerators/transforms/color_transforms.py index e5b735d..8f4972e 100644 --- a/batchgenerators/transforms/color_transforms.py +++ b/batchgenerators/transforms/color_transforms.py @@ -52,13 +52,13 @@ def __init__(self, self.p_per_channel = p_per_channel def __call__(self, **data_dict): - for b in range(len(data_dict[self.data_key])): - if np.random.uniform() < self.p_per_sample: - data_dict[self.data_key][b] = augment_contrast(data_dict[self.data_key][b], - contrast_range=self.contrast_range, - preserve_range=self.preserve_range, - per_channel=self.per_channel, - p_per_channel=self.p_per_channel) + mask = np.random.uniform(size=len(data_dict[self.data_key])) < self.p_per_sample + data_dict[self.data_key][mask] = augment_contrast(data_dict[self.data_key][mask], + contrast_range=self.contrast_range, + preserve_range=self.preserve_range, + per_channel=self.per_channel, + p_per_channel=self.p_per_channel, + batched=True) return data_dict diff --git a/tests/test_color_augmentations.py b/tests/test_color_augmentations.py index 64f17f5..7cefda6 100644 --- a/tests/test_color_augmentations.py +++ b/tests/test_color_augmentations.py @@ -26,12 +26,36 @@ def setUp(self): self.data_3D = np.random.random((2, 64, 56, 48)) self.data_2D = np.random.random((2, 64, 56)) self.factor = (0.75, 1.25) + self.data_4D = np.random.random((9, 2, 64, 56, 48)) - self.d_3D = augment_contrast(self.data_3D, contrast_range=self.factor, preserve_range=False, per_channel=True) + self.d_4D = augment_contrast(self.data_4D, contrast_range=self.factor, preserve_range=True, per_channel=True, batched=True) + self.d_4D = augment_contrast(self.data_4D, contrast_range=self.factor, preserve_range=False, per_channel=False, batched=True) + self.d_3D = augment_contrast(self.data_3D, contrast_range=self.factor, preserve_range=True, per_channel=True) self.d_3D = augment_contrast(self.data_3D, contrast_range=self.factor, preserve_range=False, per_channel=False) - self.d_2D = augment_contrast(self.data_2D, contrast_range=self.factor, preserve_range=False, per_channel=True) + self.d_2D = augment_contrast(self.data_2D, contrast_range=self.factor, preserve_range=True, per_channel=True) self.d_2D = augment_contrast(self.data_2D, contrast_range=self.factor, preserve_range=False, per_channel=False) + def test_augment_contrast_4D(self): + data = self.data_4D[0] + mean = np.mean(data) + + idx0 = np.where(data < mean) # where the data is lower than mean value + idx1 = np.where(data > mean) # where the data is greater than mean value + + 
contrast_lower_limit_0 = self.factor[1] * (data[idx0] - mean) + mean + contrast_lower_limit_1 = self.factor[0] * (data[idx1] - mean) + mean + contrast_upper_limit_0 = self.factor[0] * (data[idx0] - mean) + mean + contrast_upper_limit_1 = self.factor[1] * (data[idx1] - mean) + mean + + # augmented values lower than mean should be lower than lower limit and greater than upper limit + self.assertTrue(np.all(np.logical_and(self.d_4D[0][idx0] >= contrast_lower_limit_0, + self.d_4D[0][idx0] <= contrast_upper_limit_0)), + "Augmented contrast below mean value not within range") + # augmented values greater than mean should be lower than upper limit and greater than lower limit + self.assertTrue(np.all(np.logical_and(self.d_4D[0][idx1] >= contrast_lower_limit_1, + self.d_4D[0][idx1] <= contrast_upper_limit_1)), + "Augmented contrast above mean not within range") + def test_augment_contrast_3D(self): mean = np.mean(self.data_3D) From ff83845b6aebafd0d6b9c687a67fdd2e5d648573 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Wed, 16 Aug 2023 13:11:28 +0300 Subject: [PATCH 14/60] Augment gamma changes --- batchgenerators/augmentations/color_augmentations.py | 4 ++-- batchgenerators/transforms/color_transforms.py | 1 + tests/test_color_augmentations.py | 1 + 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/batchgenerators/augmentations/color_augmentations.py b/batchgenerators/augmentations/color_augmentations.py index e5cac10..6ec94a0 100644 --- a/batchgenerators/augmentations/color_augmentations.py +++ b/batchgenerators/augmentations/color_augmentations.py @@ -112,7 +112,7 @@ def augment_gamma(data_sample, gamma_range=(0.5, 2), invert_image=False, epsilon if retain_stats_here: mn = data_sample.mean() sd = data_sample.std() - if np.random.random() < 0.5 and gamma_range[0] < 1: + if gamma_range[0] < 1 and np.random.random() < 0.5: gamma = np.random.uniform(gamma_range[0], 1) else: gamma = np.random.uniform(max(gamma_range[0], 1), gamma_range[1]) @@ -129,7 +129,7 @@ def augment_gamma(data_sample, gamma_range=(0.5, 2), invert_image=False, epsilon if retain_stats_here: mn = data_sample[c].mean() sd = data_sample[c].std() - if np.random.random() < 0.5 and gamma_range[0] < 1: + if gamma_range[0] < 1 and np.random.random() < 0.5: gamma = np.random.uniform(gamma_range[0], 1) else: gamma = np.random.uniform(max(gamma_range[0], 1), gamma_range[1]) diff --git a/batchgenerators/transforms/color_transforms.py b/batchgenerators/transforms/color_transforms.py index 8f4972e..daa3db3 100644 --- a/batchgenerators/transforms/color_transforms.py +++ b/batchgenerators/transforms/color_transforms.py @@ -155,6 +155,7 @@ def __init__(self, gamma_range=(0.5, 2), invert_image=False, per_channel=False, self.invert_image = invert_image def __call__(self, **data_dict): + # TODO: augment_gamma can be vectorized twice (per channel and per sample) for b in range(len(data_dict[self.data_key])): if np.random.uniform() < self.p_per_sample: data_dict[self.data_key][b] = augment_gamma(data_dict[self.data_key][b], self.gamma_range, diff --git a/tests/test_color_augmentations.py b/tests/test_color_augmentations.py index 7cefda6..901dd6d 100644 --- a/tests/test_color_augmentations.py +++ b/tests/test_color_augmentations.py @@ -172,6 +172,7 @@ def setUp(self): self.data_input_3D = np.random.random((2, 64, 56, 48)) self.data_input_2D = np.random.random((2, 64, 56)) + self.d_3D = augment_gamma(np.copy(self.data_input_2D), gamma_range=(0.2, 1.2), per_channel=True) self.d_3D = augment_gamma(np.copy(self.data_input_2D), 
gamma_range=(0.2, 1.2), per_channel=False) def test_augment_gamma_3D(self): From b29d74a3b621ad166b6e102ec81a633814a598e1 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Wed, 16 Aug 2023 14:03:16 +0300 Subject: [PATCH 15/60] improving per channel augment gamma --- .../augmentations/color_augmentations.py | 42 ++++++++++++------- tests/test_color_augmentations.py | 4 +- 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/batchgenerators/augmentations/color_augmentations.py b/batchgenerators/augmentations/color_augmentations.py index 6ec94a0..ae3d869 100644 --- a/batchgenerators/augmentations/color_augmentations.py +++ b/batchgenerators/augmentations/color_augmentations.py @@ -124,22 +124,36 @@ def augment_gamma(data_sample, gamma_range=(0.5, 2), invert_image=False, epsilon data_sample *= sd / (data_sample.std() + 1e-8) data_sample += mn else: - for c in range(data_sample.shape[0]): - retain_stats_here = retain_stats() if callable(retain_stats) else retain_stats - if retain_stats_here: - mn = data_sample[c].mean() - sd = data_sample[c].std() + shape_0 = data_sample.shape[0] + if callable(retain_stats): + retain_stats_here = np.array(retain_stats() for _ in range(shape_0)) + else: + retain_stats_here = np.array([retain_stats]).repeat(shape_0) + gamma = [] + for i in range(shape_0): if gamma_range[0] < 1 and np.random.random() < 0.5: - gamma = np.random.uniform(gamma_range[0], 1) + gamma.append(np.random.uniform(gamma_range[0], 1)) else: - gamma = np.random.uniform(max(gamma_range[0], 1), gamma_range[1]) - minm = data_sample[c].min() - rnge = data_sample[c].max() - minm - data_sample[c] = np.power(((data_sample[c] - minm) / float(rnge + epsilon)), gamma) * float(rnge + epsilon) + minm - if retain_stats_here: - data_sample[c] -= data_sample[c].mean() - data_sample[c] *= sd / (data_sample[c].std() + 1e-8) - data_sample[c] += mn + gamma.append(np.random.uniform(max(gamma_range[0], 1), gamma_range[1])) + gamma = np.array(gamma) + + axes = tuple(range(1, len(data_sample.shape))) + + retain_any_stats = np.any(retain_stats_here) + if retain_any_stats: + mn = data_sample[retain_stats_here].mean(axis=axes) + sd = data_sample[retain_stats_here].mean(axis=axes) + + minm = data_sample.min(axis=axes) + rnge = data_sample.max(axis=axes) - minm + epsilon + + data_sample = (np.power(((data_sample.T - minm) / rnge), gamma) * rnge + minm).T + + if retain_any_stats: + data_sample[retain_stats_here] = (( + data_sample[retain_stats_here].T - data_sample[retain_stats_here].mean(axis=axes)) * sd / + (data_sample[retain_stats_here].std(axis=axes) + 1e-8) + mn).T + if invert_image: data_sample = - data_sample return data_sample diff --git a/tests/test_color_augmentations.py b/tests/test_color_augmentations.py index 901dd6d..82ab171 100644 --- a/tests/test_color_augmentations.py +++ b/tests/test_color_augmentations.py @@ -172,8 +172,8 @@ def setUp(self): self.data_input_3D = np.random.random((2, 64, 56, 48)) self.data_input_2D = np.random.random((2, 64, 56)) - self.d_3D = augment_gamma(np.copy(self.data_input_2D), gamma_range=(0.2, 1.2), per_channel=True) - self.d_3D = augment_gamma(np.copy(self.data_input_2D), gamma_range=(0.2, 1.2), per_channel=False) + self.d_3D = augment_gamma(np.copy(self.data_input_3D), gamma_range=(0.2, 1.2), per_channel=True, retain_stats=True) + self.d_3D = augment_gamma(np.copy(self.data_input_3D), gamma_range=(0.2, 1.2), per_channel=False, retain_stats=False) def test_augment_gamma_3D(self): self.assertTrue(self.d_3D.min().round(decimals=3) == 
self.data_input_3D.min().round(decimals=3) and From 26eb25c87525d2286ca63a0b8105bfd262230d11 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Wed, 16 Aug 2023 14:53:21 +0300 Subject: [PATCH 16/60] Added batched brightness multiplicative transform --- .../augmentations/color_augmentations.py | 16 ++++-- .../transforms/color_transforms.py | 20 +++---- tests/test_color_augmentations.py | 56 +++++++++++++------ 3 files changed, 58 insertions(+), 34 deletions(-) diff --git a/batchgenerators/augmentations/color_augmentations.py b/batchgenerators/augmentations/color_augmentations.py index ae3d869..9621d42 100644 --- a/batchgenerators/augmentations/color_augmentations.py +++ b/batchgenerators/augmentations/color_augmentations.py @@ -90,15 +90,19 @@ def augment_brightness_additive(data_sample, mu:float, sigma:float , per_channel return data_sample -def augment_brightness_multiplicative(data_sample, multiplier_range=(0.5, 2), per_channel=True): +def augment_brightness_multiplicative(data_sample, multiplier_range=(0.5, 2), per_channel=True, batched=False): if not per_channel: - multiplier = np.random.uniform(multiplier_range[0], multiplier_range[1]) + size = data_sample.shape[0] if batched else 1 + axes = tuple(range(1, len(data_sample.shape))) else: - axes = [1 for _ in range(len(data_sample.shape))] - axes[0] = data_sample.shape[0] - multiplier = np.random.uniform(multiplier_range[0], multiplier_range[1], size=axes) + if batched: + size = data_sample.shape[:2] + axes = tuple(range(2, len(data_sample.shape))) + else: + size = data_sample.shape[0] + axes = tuple(range(1, len(data_sample.shape))) - data_sample *= multiplier + data_sample *= np.expand_dims(np.random.uniform(multiplier_range[0], multiplier_range[1], size=size), axis=axes) return data_sample diff --git a/batchgenerators/transforms/color_transforms.py b/batchgenerators/transforms/color_transforms.py index daa3db3..30839c1 100644 --- a/batchgenerators/transforms/color_transforms.py +++ b/batchgenerators/transforms/color_transforms.py @@ -54,11 +54,11 @@ def __init__(self, def __call__(self, **data_dict): mask = np.random.uniform(size=len(data_dict[self.data_key])) < self.p_per_sample data_dict[self.data_key][mask] = augment_contrast(data_dict[self.data_key][mask], - contrast_range=self.contrast_range, - preserve_range=self.preserve_range, - per_channel=self.per_channel, - p_per_channel=self.p_per_channel, - batched=True) + contrast_range=self.contrast_range, + preserve_range=self.preserve_range, + per_channel=self.per_channel, + p_per_channel=self.p_per_channel, + batched=True) return data_dict @@ -121,11 +121,11 @@ def __init__(self, multiplier_range=(0.5, 2), per_channel=True, data_key="data", self.per_channel = per_channel def __call__(self, **data_dict): - for b in range(len(data_dict[self.data_key])): - if np.random.uniform() < self.p_per_sample: - data_dict[self.data_key][b] = augment_brightness_multiplicative(data_dict[self.data_key][b], - self.multiplier_range, - self.per_channel) + mask = np.random.uniform(size=len(data_dict[self.data_key])) < self.p_per_sample + data_dict[self.data_key][mask] = augment_brightness_multiplicative(data_dict[self.data_key][mask], + self.multiplier_range, + self.per_channel, + batched=True) return data_dict diff --git a/tests/test_color_augmentations.py b/tests/test_color_augmentations.py index 82ab171..19000e7 100644 --- a/tests/test_color_augmentations.py +++ b/tests/test_color_augmentations.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from 
batchgenerators.augmentations.color_augmentations import augment_contrast, augment_brightness_additive,\ +from batchgenerators.augmentations.color_augmentations import augment_contrast, augment_brightness_additive, \ augment_brightness_multiplicative, augment_gamma @@ -28,8 +28,10 @@ def setUp(self): self.factor = (0.75, 1.25) self.data_4D = np.random.random((9, 2, 64, 56, 48)) - self.d_4D = augment_contrast(self.data_4D, contrast_range=self.factor, preserve_range=True, per_channel=True, batched=True) - self.d_4D = augment_contrast(self.data_4D, contrast_range=self.factor, preserve_range=False, per_channel=False, batched=True) + self.d_4D = augment_contrast(self.data_4D, contrast_range=self.factor, preserve_range=True, per_channel=True, + batched=True) + self.d_4D = augment_contrast(self.data_4D, contrast_range=self.factor, preserve_range=False, per_channel=False, + batched=True) self.d_3D = augment_contrast(self.data_3D, contrast_range=self.factor, preserve_range=True, per_channel=True) self.d_3D = augment_contrast(self.data_3D, contrast_range=self.factor, preserve_range=False, per_channel=False) self.d_2D = augment_contrast(self.data_2D, contrast_range=self.factor, preserve_range=True, per_channel=True) @@ -57,7 +59,6 @@ def test_augment_contrast_4D(self): "Augmented contrast above mean not within range") def test_augment_contrast_3D(self): - mean = np.mean(self.data_3D) idx0 = np.where(self.data_3D < mean) # where the data is lower than mean value @@ -78,7 +79,6 @@ def test_augment_contrast_3D(self): "Augmented contrast above mean not within range") def test_augment_contrast_2D(self): - mean = np.mean(self.data_2D) idx0 = np.where(self.data_2D < mean) # where the data is lower than mean value @@ -106,7 +106,7 @@ def setUp(self): self.data_input_3D = np.random.random((2, 64, 56, 48)) self.data_input_2D = np.random.random((2, 64, 56)) self.factor = (0.75, 1.25) - self.multiplier_range = [2,4] + self.multiplier_range = [2, 4] self.d_3D_per_channel = augment_brightness_additive(np.copy(self.data_input_3D), mu=100, sigma=10, per_channel=True) @@ -129,8 +129,8 @@ def setUp(self): multiplier_range=self.multiplier_range, per_channel=False) def test_augment_brightness_additive_3D(self): - add_factor = self.d_3D-self.data_input_3D - self.assertTrue(len(np.unique(add_factor.round(decimals=8)))==1, + add_factor = self.d_3D - self.data_input_3D + self.assertTrue(len(np.unique(add_factor.round(decimals=8))) == 1, "Added brightness factor is not equal for all channels") add_factor = self.d_3D_per_channel - self.data_input_3D @@ -138,8 +138,8 @@ def test_augment_brightness_additive_3D(self): "Added brightness factor is not different for each channels") def test_augment_brightness_additive_2D(self): - add_factor = self.d_2D-self.data_input_2D - self.assertTrue(len(np.unique(add_factor.round(decimals=8)))==1, + add_factor = self.d_2D - self.data_input_2D + self.assertTrue(len(np.unique(add_factor.round(decimals=8))) == 1, "Added brightness factor is not equal for all channels") add_factor = self.d_2D_per_channel - self.data_input_2D @@ -147,23 +147,41 @@ def test_augment_brightness_additive_2D(self): "Added brightness factor is not different for each channels") def test_augment_brightness_multiplicative_3D(self): - mult_factor = self.d_3D_mult/self.data_input_3D - self.assertTrue(len(np.unique(mult_factor.round(decimals=6)))==1, + mult_factor = self.d_3D_mult / self.data_input_3D + self.assertTrue(len(np.unique(mult_factor.round(decimals=6))) == 1, "Multiplied brightness factor is not equal for all 
channels") - mult_factor = self.d_3D_per_channel_mult/self.data_input_3D + mult_factor = self.d_3D_per_channel_mult / self.data_input_3D self.assertTrue(len(np.unique(mult_factor.round(decimals=6))) == self.data_input_3D.shape[0], "Multiplied brightness factor is not different for each channels") def test_augment_brightness_multiplicative_2D(self): - mult_factor = self.d_2D_mult/self.data_input_2D - self.assertTrue(len(np.unique(mult_factor.round(decimals=6)))==1, + mult_factor = self.d_2D_mult / self.data_input_2D + self.assertTrue(len(np.unique(mult_factor.round(decimals=6))) == 1, "Multiplied brightness factor is not equal for all channels") - mult_factor = self.d_2D_per_channel_mult/self.data_input_2D + mult_factor = self.d_2D_per_channel_mult / self.data_input_2D self.assertTrue(len(np.unique(mult_factor.round(decimals=6))) == self.data_input_2D.shape[0], "Multiplied brightness factor is not different for each channels") + def test_batched_augment_brightness_multiplicative(self): + data = np.random.random((9, 2, 64, 56, 48)) + result_1 = augment_brightness_multiplicative(np.copy(data), + multiplier_range=self.multiplier_range, + per_channel=False, + batched=True) + result_2 = augment_brightness_multiplicative(np.copy(data), + multiplier_range=self.multiplier_range, + per_channel=True, + batched=True) + + mult_factor = result_1 / data + self.assertEqual(len(np.unique(mult_factor.round(decimals=6))), data.shape[0], + "Multiplied brightness factor per sample is not equal for all channels") + + mult_factor = result_2 / data + self.assertEqual(len(np.unique(mult_factor.round(decimals=6))), data.shape[0] * data.shape[1], + "Multiplied brightness factor per sample is not different for all channels") class TestAugmentGamma(unittest.TestCase): @@ -172,8 +190,10 @@ def setUp(self): self.data_input_3D = np.random.random((2, 64, 56, 48)) self.data_input_2D = np.random.random((2, 64, 56)) - self.d_3D = augment_gamma(np.copy(self.data_input_3D), gamma_range=(0.2, 1.2), per_channel=True, retain_stats=True) - self.d_3D = augment_gamma(np.copy(self.data_input_3D), gamma_range=(0.2, 1.2), per_channel=False, retain_stats=False) + self.d_3D = augment_gamma(np.copy(self.data_input_3D), gamma_range=(0.2, 1.2), per_channel=True, + retain_stats=True) + self.d_3D = augment_gamma(np.copy(self.data_input_3D), gamma_range=(0.2, 1.2), per_channel=False, + retain_stats=False) def test_augment_gamma_3D(self): self.assertTrue(self.d_3D.min().round(decimals=3) == self.data_input_3D.min().round(decimals=3) and From 93c08d1dd4790e595caf0a25cd3b9aa14dc1286b Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Wed, 16 Aug 2023 15:07:03 +0300 Subject: [PATCH 17/60] Doing batched augmentation only if batches are not empty --- .../transforms/color_transforms.py | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/batchgenerators/transforms/color_transforms.py b/batchgenerators/transforms/color_transforms.py index 30839c1..4ad3247 100644 --- a/batchgenerators/transforms/color_transforms.py +++ b/batchgenerators/transforms/color_transforms.py @@ -53,12 +53,13 @@ def __init__(self, def __call__(self, **data_dict): mask = np.random.uniform(size=len(data_dict[self.data_key])) < self.p_per_sample - data_dict[self.data_key][mask] = augment_contrast(data_dict[self.data_key][mask], - contrast_range=self.contrast_range, - preserve_range=self.preserve_range, - per_channel=self.per_channel, - p_per_channel=self.p_per_channel, - batched=True) + if np.any(mask): + data_dict[self.data_key][mask] = 
augment_contrast(data_dict[self.data_key][mask], + contrast_range=self.contrast_range, + preserve_range=self.preserve_range, + per_channel=self.per_channel, + p_per_channel=self.p_per_channel, + batched=True) return data_dict @@ -122,10 +123,11 @@ def __init__(self, multiplier_range=(0.5, 2), per_channel=True, data_key="data", def __call__(self, **data_dict): mask = np.random.uniform(size=len(data_dict[self.data_key])) < self.p_per_sample - data_dict[self.data_key][mask] = augment_brightness_multiplicative(data_dict[self.data_key][mask], - self.multiplier_range, - self.per_channel, - batched=True) + if np.any(mask): + data_dict[self.data_key][mask] = augment_brightness_multiplicative(data_dict[self.data_key][mask], + self.multiplier_range, + self.per_channel, + batched=True) return data_dict From acbd7e9538346713670735884630e3396de682ec Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Wed, 16 Aug 2023 15:27:07 +0300 Subject: [PATCH 18/60] Factored out the setup for multiplicative brightness Using lru_cache for caching tuple creation --- .../augmentations/color_augmentations.py | 39 ++++++++++++------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/batchgenerators/augmentations/color_augmentations.py b/batchgenerators/augmentations/color_augmentations.py index 9621d42..708a7bc 100644 --- a/batchgenerators/augmentations/color_augmentations.py +++ b/batchgenerators/augmentations/color_augmentations.py @@ -14,6 +14,7 @@ # limitations under the License. from builtins import range +from functools import lru_cache from typing import Tuple, Union, Callable import numpy as np @@ -69,7 +70,8 @@ def augment_contrast(data_sample: np.ndarray, return data_sample -def augment_brightness_additive(data_sample, mu:float, sigma:float , per_channel:bool=True, p_per_channel:float=1.): +def augment_brightness_additive(data_sample, mu: float, sigma: float, per_channel: bool = True, + p_per_channel: float = 1.): """ data_sample must have shape (c, x, y(, z))) :param data_sample: @@ -90,18 +92,28 @@ def augment_brightness_additive(data_sample, mu:float, sigma:float , per_channel return data_sample -def augment_brightness_multiplicative(data_sample, multiplier_range=(0.5, 2), per_channel=True, batched=False): - if not per_channel: - size = data_sample.shape[0] if batched else 1 - axes = tuple(range(1, len(data_sample.shape))) - else: - if batched: - size = data_sample.shape[:2] - axes = tuple(range(2, len(data_sample.shape))) +def setup_augment_brightness_multiplicative(per_channel: bool, batched: bool, shape: Tuple[int]): + def get_size(per_channel, batched, shape): + if per_channel: + if batched: + return shape[:2] + return shape[0] else: - size = data_sample.shape[0] - axes = tuple(range(1, len(data_sample.shape))) + if batched: + return shape[0] + return 1 + @lru_cache(maxsize=2) # axes are expected to remain the same + def get_axes(per_channel, batched, n): + if per_channel and batched: + return tuple(range(2, n)) + return tuple(range(1, n)) + + return get_size(per_channel, batched, shape), get_axes(per_channel, batched, len(shape)) + + +def augment_brightness_multiplicative(data_sample, multiplier_range=(0.5, 2), per_channel=True, batched=False): + size, axes = setup_augment_brightness_multiplicative(per_channel, batched, data_sample.shape) data_sample *= np.expand_dims(np.random.uniform(multiplier_range[0], multiplier_range[1], size=size), axis=axes) return data_sample @@ -155,8 +167,9 @@ def augment_gamma(data_sample, gamma_range=(0.5, 2), invert_image=False, epsilon if 
retain_any_stats: data_sample[retain_stats_here] = (( - data_sample[retain_stats_here].T - data_sample[retain_stats_here].mean(axis=axes)) * sd / - (data_sample[retain_stats_here].std(axis=axes) + 1e-8) + mn).T + data_sample[retain_stats_here].T - data_sample[ + retain_stats_here].mean(axis=axes)) * sd / + (data_sample[retain_stats_here].std(axis=axes) + 1e-8) + mn).T if invert_image: data_sample = - data_sample From bf174e595e760e0fd6cc7abeb7dd210c028af845 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Wed, 16 Aug 2023 16:14:52 +0300 Subject: [PATCH 19/60] Added batched implementation for Gaussian Noise Transform --- .../augmentations/noise_augmentations.py | 23 +++++++++++++------ .../transforms/noise_transforms.py | 8 +++---- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/batchgenerators/augmentations/noise_augmentations.py b/batchgenerators/augmentations/noise_augmentations.py index 6ba97a3..b63e333 100644 --- a/batchgenerators/augmentations/noise_augmentations.py +++ b/batchgenerators/augmentations/noise_augmentations.py @@ -30,9 +30,7 @@ def augment_rician_noise(data_sample, noise_variance=(0, 0.1)): return data_sample -def augment_gaussian_noise(data_sample: np.ndarray, noise_variance: Tuple[float, float] = (0, 0.1), - p_per_channel: float = 1, per_channel: bool = False) -> np.ndarray: - size = data_sample.shape[0] +def setup_augment_gaussian_noise(noise_variance: Tuple[float, float], per_channel: bool, size: int): if not per_channel: variance = noise_variance[0] if noise_variance[0] == noise_variance[1] else \ random.uniform(noise_variance[0], noise_variance[1]) @@ -40,11 +38,22 @@ def augment_gaussian_noise(data_sample: np.ndarray, noise_variance: Tuple[float, else: variance = np.repeat(noise_variance[0], size) if noise_variance[0] == noise_variance[1] else \ np.random.uniform(noise_variance[0], noise_variance[1], size=size) + return variance + + +def augment_gaussian_noise(data_sample: np.ndarray, noise_variance: Tuple[float, float] = (0, 0.1), + p_per_channel: float = 1, per_channel: bool = False, batched: bool = False) -> np.ndarray: + mask = np.random.uniform(size=data_sample.shape[1 if batched else 0]) < p_per_channel + size = np.count_nonzero(mask) + if size: + if batched: + num_samples = data_sample.shape[0] + mask = np.atleast_2d(mask).repeat(num_samples, axis=0) + size *= num_samples + + variance = setup_augment_gaussian_noise(noise_variance, per_channel, size) + data_sample[mask] += np.random.normal(0.0, variance, data_sample[mask].T.shape).T - for c in range(size): - if np.random.uniform() < p_per_channel: - # bug fixed: https://github.com/MIC-DKFZ/batchgenerators/issues/86 - data_sample[c] += np.random.normal(0.0, variance[c], size=data_sample[c].shape) return data_sample diff --git a/batchgenerators/transforms/noise_transforms.py b/batchgenerators/transforms/noise_transforms.py index 6a30767..be07f84 100644 --- a/batchgenerators/transforms/noise_transforms.py +++ b/batchgenerators/transforms/noise_transforms.py @@ -71,10 +71,10 @@ def __init__(self, noise_variance=(0, 0.1), p_per_sample=1, p_per_channel: float self.per_channel = per_channel def __call__(self, **data_dict): - for b in range(len(data_dict[self.data_key])): - if np.random.uniform() < self.p_per_sample: - data_dict[self.data_key][b] = augment_gaussian_noise(data_dict[self.data_key][b], self.noise_variance, - self.p_per_channel, self.per_channel) + mask = np.random.uniform(size=len(data_dict[self.data_key])) < self.p_per_sample + if np.any(mask): + data_dict[self.data_key][mask] = 
augment_gaussian_noise(data_dict[self.data_key][mask], self.noise_variance, + self.p_per_channel, self.per_channel, batched=True) return data_dict From 7369a8a808d502ad22cf1b388439743a5f741c83 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Wed, 16 Aug 2023 17:16:47 +0300 Subject: [PATCH 20/60] Makeup for Gaussian Blur Transform --- .../augmentations/noise_augmentations.py | 24 ++++++++++++------- .../transforms/noise_transforms.py | 1 + 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/batchgenerators/augmentations/noise_augmentations.py b/batchgenerators/augmentations/noise_augmentations.py index b63e333..f82d59a 100644 --- a/batchgenerators/augmentations/noise_augmentations.py +++ b/batchgenerators/augmentations/noise_augmentations.py @@ -60,22 +60,30 @@ def augment_gaussian_noise(data_sample: np.ndarray, noise_variance: Tuple[float, def augment_gaussian_blur(data_sample: np.ndarray, sigma_range: Tuple[float, float], per_channel: bool = True, p_per_channel: float = 1, different_sigma_per_axis: bool = False, p_isotropic: float = 0) -> np.ndarray: + # TODO: Vectorize per channel (gaussian_filter accepts axes) if not per_channel: # Godzilla Had a Stroke Trying to Read This and F***ing Died # https://i.kym-cdn.com/entries/icons/original/000/034/623/Untitled-3.png - sigma = get_range_val(sigma_range) if ((not different_sigma_per_axis) or - ((np.random.uniform() < p_isotropic) and - different_sigma_per_axis)) \ - else [get_range_val(sigma_range) for _ in data_sample.shape[1:]] + # sigma = get_range_val(sigma_range) if ((not different_sigma_per_axis) or + # ((np.random.uniform() < p_isotropic) and + # different_sigma_per_axis)) \ + # else [get_range_val(sigma_range) for _ in data_sample.shape[1:]] + + # Godzilla revived + if not different_sigma_per_axis or np.random.uniform() < p_isotropic: + sigma = get_range_val(sigma_range) + else: + sigma = [get_range_val(sigma_range) for _ in data_sample.shape[1:]] else: sigma = None for c in range(data_sample.shape[0]): if np.random.uniform() <= p_per_channel: if per_channel: - sigma = get_range_val(sigma_range) if ((not different_sigma_per_axis) or - ((np.random.uniform() < p_isotropic) and - different_sigma_per_axis)) \ - else [get_range_val(sigma_range) for _ in data_sample.shape[1:]] + if not different_sigma_per_axis or np.random.uniform() < p_isotropic: + sigma = get_range_val(sigma_range) + else: + sigma = [get_range_val(sigma_range) for _ in data_sample.shape[1:]] + data_sample[c] = gaussian_filter(data_sample[c], sigma, order=0) return data_sample diff --git a/batchgenerators/transforms/noise_transforms.py b/batchgenerators/transforms/noise_transforms.py index be07f84..d3de83b 100644 --- a/batchgenerators/transforms/noise_transforms.py +++ b/batchgenerators/transforms/noise_transforms.py @@ -102,6 +102,7 @@ def __init__(self, blur_sigma: Tuple[float, float] = (1, 5), different_sigma_per self.p_isotropic = p_isotropic def __call__(self, **data_dict): + # TODO: Do batched gaussian blur for b in range(len(data_dict[self.data_key])): if np.random.uniform() < self.p_per_sample: data_dict[self.data_key][b] = augment_gaussian_blur(data_dict[self.data_key][b], self.blur_sigma, From 3020ec54d08025dfc857618cf3d97a918928d668 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Thu, 17 Aug 2023 12:04:55 +0300 Subject: [PATCH 21/60] Removed unittest2 dependency *unittest2 also has errors --- requirements.txt | 3 +-- setup.py | 1 - tests/test_axis_mirroring.py | 3 +-- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git 
a/requirements.txt b/requirements.txt index f4e40a2..4b72370 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,5 +4,4 @@ scikit-learn numpy>=1.10.2 scipy scikit-image -scikit-learn -unittest2 \ No newline at end of file +scikit-learn \ No newline at end of file diff --git a/setup.py b/setup.py index a62dc17..437b8b3 100755 --- a/setup.py +++ b/setup.py @@ -16,7 +16,6 @@ "scikit-image", "scikit-learn", "future", - "unittest2", "threadpoolctl" ], keywords=['data augmentation', 'deep learning', 'image segmentation', 'image classification', diff --git a/tests/test_axis_mirroring.py b/tests/test_axis_mirroring.py index 78839e4..8710de7 100644 --- a/tests/test_axis_mirroring.py +++ b/tests/test_axis_mirroring.py @@ -14,7 +14,6 @@ # limitations under the License. import unittest -import unittest2 import numpy as np from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter from skimage import data @@ -23,7 +22,7 @@ from batchgenerators.transforms.spatial_transforms import MirrorTransform -class TestMirrorAxis(unittest2.TestCase): +class TestMirrorAxis(unittest.TestCase): def setUp(self): self.seed = 1234 From 7912a621ed9c975174d8a6d2e3b29e6d94ba6bc2 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Thu, 17 Aug 2023 12:12:12 +0300 Subject: [PATCH 22/60] Improved crop and pad augmentation and spatial transforms --- .../crop_and_pad_augmentations.py | 81 +++++++++---------- .../augmentations/spatial_transformations.py | 16 +--- .../transforms/spatial_transforms.py | 2 +- 3 files changed, 41 insertions(+), 58 deletions(-) diff --git a/batchgenerators/augmentations/crop_and_pad_augmentations.py b/batchgenerators/augmentations/crop_and_pad_augmentations.py index 0728890..16abdf8 100644 --- a/batchgenerators/augmentations/crop_and_pad_augmentations.py +++ b/batchgenerators/augmentations/crop_and_pad_augmentations.py @@ -16,6 +16,7 @@ from builtins import range import numpy as np from batchgenerators.augmentations.utils import pad_nd_image +from typing import Union, Sequence def center_crop(data, crop_size, seg=None): @@ -30,15 +31,11 @@ def get_lbs_for_random_crop(crop_size, data_shape, margins): :param margins: :return: """ - lbs = [] - for i in range(len(data_shape) - 2): - new_shape = data_shape[i+2] - crop_size[i] - margin = margins[i] - if new_shape > 2 * margin: - lbs.append(np.random.randint(margin, new_shape - margin)) - else: - lbs.append(new_shape // 2) - return lbs + new_shape = data_shape - crop_size + mask = new_shape > 2 * margins + new_shape[mask] = np.random.randint(margins[mask], new_shape[mask] - margins[mask]) + new_shape[~mask] //= 2 + return new_shape def get_lbs_for_center_crop(crop_size, data_shape): @@ -47,13 +44,11 @@ def get_lbs_for_center_crop(crop_size, data_shape): :param data_shape: (b,c,x,y(,z)) must be the whole thing! 
:return: """ - lbs = [] - for i in range(len(data_shape) - 2): - lbs.append((data_shape[i + 2] - crop_size[i]) // 2) - return lbs + return (data_shape - crop_size) // 2 -def crop(data, seg=None, crop_size=128, margins=(0, 0, 0), crop_type="center", +def crop(data: Union[Sequence[np.ndarray], np.ndarray], seg: Union[Sequence[np.ndarray], np.ndarray] = None, + crop_size=128, margins=(0, 0, 0), crop_type="center", pad_mode='constant', pad_kwargs={'constant_values': 0}, pad_mode_seg='constant', pad_kwargs_seg={'constant_values': 0}): """ @@ -71,45 +66,40 @@ def crop(data, seg=None, crop_size=128, margins=(0, 0, 0), crop_type="center", :param crop_type: random or center :return: """ - if not isinstance(data, (list, tuple, np.ndarray)): - raise TypeError("data has to be either a numpy array or a list") - - data_shape = tuple([len(data)] + list(data[0].shape)) + data_shape = (len(data), *data[0].shape) data_dtype = data[0].dtype dim = len(data_shape) - 2 if seg is not None: - seg_shape = tuple([len(seg)] + list(seg[0].shape)) + seg_shape = (len(seg), *seg[0].shape) seg_dtype = seg[0].dtype - if not isinstance(seg, (list, tuple, np.ndarray)): - raise TypeError("data has to be either a numpy array or a list") - - assert all([i == j for i, j in zip(seg_shape[2:], data_shape[2:])]), "data and seg must have the same spatial " \ - "dimensions. Data: %s, seg: %s" % \ - (str(data_shape), str(seg_shape)) + assert np.array_equal(seg_shape[2:], data_shape[2:]), "data and seg must have the same spatial dimensions. " \ + f"Data: {data_shape}, seg: {seg_shape}" if type(crop_size) not in (tuple, list, np.ndarray): - crop_size = (crop_size, ) * dim + crop_size = (crop_size,) * dim else: - assert len(crop_size) == len( - data_shape) - 2, "If you provide a list/tuple as center crop make sure it has the same dimension as your " \ - "data (2d/3d)" - crop_size = tuple(crop_size) + assert len(crop_size) == dim, ("If you provide a list/tuple as center crop make sure it has the same dimension " + "as your data (2d/3d)") + crop_size = np.array(crop_size) if not isinstance(margins, (np.ndarray, tuple, list)): - margins = [margins] * dim + margins = (margins,) * dim + margins = np.array(margins) - data_return = np.zeros((data_shape[0], data_shape[1]) + crop_size, dtype=data_dtype) + data_return = np.zeros((data_shape[0], data_shape[1], *crop_size), dtype=data_dtype) if seg is not None: - seg_return = np.zeros((seg_shape[0], seg_shape[1]) + crop_size, dtype=seg_dtype) + seg_return = np.zeros((seg_shape[0], seg_shape[1], *crop_size), dtype=seg_dtype) else: seg_return = None + for b in range(data_shape[0]): - data_shape_here = [data_shape[0]] + list(data[b].shape) + data_first_dim = data[b].shape[0] + data_shape_here = np.array(data[b].shape[1:]) if seg is not None: - seg_shape_here = [seg_shape[0]] + list(seg[b].shape) + seg_first_dim = seg[b].shape[0] if crop_type == "center": lbs = get_lbs_for_center_crop(crop_size, data_shape_here) @@ -118,22 +108,23 @@ def crop(data, seg=None, crop_size=128, margins=(0, 0, 0), crop_type="center", else: raise NotImplementedError("crop_type must be either center or random") - need_to_pad = [[0, 0]] + [[abs(min(0, lbs[d])), - abs(min(0, data_shape_here[d + 2] - (lbs[d] + crop_size[d])))] - for d in range(dim)] + zero = np.zeros(dim, dtype=int) + temp1 = np.abs(np.minimum(lbs, zero)) + temp2 = np.abs(np.minimum(zero, data_shape_here - lbs - crop_size)) + need_to_pad = np.array(((0, 0), *zip(temp1, temp2))) # we should crop first, then pad -> reduces i/o for memmaps, reduces RAM usage and 
improves speed - ubs = [min(lbs[d] + crop_size[d], data_shape_here[d+2]) for d in range(dim)] - lbs = [max(0, lbs[d]) for d in range(dim)] + ubs = np.minimum(data_shape_here, lbs + crop_size) + lbs = np.maximum(zero, lbs) - slicer_data = [slice(0, data_shape_here[1])] + [slice(lbs[d], ubs[d]) for d in range(dim)] - data_cropped = data[b][tuple(slicer_data)] + slicer_data = (slice(0, data_first_dim), *(slice(lbs[d], ubs[d]) for d in range(dim))) + data_cropped = data[b][slicer_data] if seg_return is not None: - slicer_seg = [slice(0, seg_shape_here[1])] + [slice(lbs[d], ubs[d]) for d in range(dim)] - seg_cropped = seg[b][tuple(slicer_seg)] + slicer_data = (slice(0, seg_first_dim), *(slice(lbs[d], ubs[d]) for d in range(dim))) + seg_cropped = seg[b][slicer_data] - if any([i > 0 for j in need_to_pad for i in j]): + if np.any(need_to_pad): data_return[b] = np.pad(data_cropped, need_to_pad, pad_mode, **pad_kwargs) if seg_return is not None: seg_return[b] = np.pad(seg_cropped, need_to_pad, pad_mode_seg, **pad_kwargs_seg) diff --git a/batchgenerators/augmentations/spatial_transformations.py b/batchgenerators/augmentations/spatial_transformations.py index fe0677b..77b1e3b 100644 --- a/batchgenerators/augmentations/spatial_transformations.py +++ b/batchgenerators/augmentations/spatial_transformations.py @@ -194,17 +194,9 @@ def augment_spatial(data, seg, patch_size, patch_center_dist_from_border=30, dim = len(patch_size) seg_result = None if seg is not None: - if dim == 2: - seg_result = np.zeros((seg.shape[0], seg.shape[1], patch_size[0], patch_size[1]), dtype=np.float32) - else: - seg_result = np.zeros((seg.shape[0], seg.shape[1], patch_size[0], patch_size[1], patch_size[2]), - dtype=np.float32) + seg_result = np.zeros((seg.shape[0], seg.shape[1], *patch_size), dtype=np.float32) - if dim == 2: - data_result = np.zeros((data.shape[0], data.shape[1], patch_size[0], patch_size[1]), dtype=np.float32) - else: - data_result = np.zeros((data.shape[0], data.shape[1], patch_size[0], patch_size[1], patch_size[2]), - dtype=np.float32) + data_result = np.zeros((data.shape[0], data.shape[1], *patch_size), dtype=np.float32) if not isinstance(patch_center_dist_from_border, (list, tuple, np.ndarray)): patch_center_dist_from_border = dim * [patch_center_dist_from_border] @@ -246,12 +238,12 @@ def augment_spatial(data, seg, patch_size, patch_center_dist_from_border=30, if independent_scale_for_each_axis and np.random.uniform() < p_independent_scale_per_axis: sc = [] for _ in range(dim): - if np.random.random() < 0.5 and scale[0] < 1: + if scale[0] < 1 and np.random.random() < 0.5: sc.append(np.random.uniform(scale[0], 1)) else: sc.append(np.random.uniform(max(scale[0], 1), scale[1])) else: - if np.random.random() < 0.5 and scale[0] < 1: + if scale[0] < 1 and np.random.random() < 0.5: sc = np.random.uniform(scale[0], 1) else: sc = np.random.uniform(max(scale[0], 1), scale[1]) diff --git a/batchgenerators/transforms/spatial_transforms.py b/batchgenerators/transforms/spatial_transforms.py index 5d0b208..ed27501 100644 --- a/batchgenerators/transforms/spatial_transforms.py +++ b/batchgenerators/transforms/spatial_transforms.py @@ -308,7 +308,7 @@ def __init__(self, patch_size, patch_center_dist_from_border=30, self.p_el_per_sample = p_el_per_sample self.data_key = data_key self.label_key = label_key - self.patch_size = patch_size + self.patch_size = tuple(patch_size) self.patch_center_dist_from_border = patch_center_dist_from_border self.do_elastic_deform = do_elastic_deform self.alpha = alpha From 
a821dee1dae33cdc47ed131cab61252cf542a872 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Thu, 17 Aug 2023 12:16:57 +0300 Subject: [PATCH 23/60] Improving resample augmentation and resample transform * also adding minor improvements to utils functions (reformatting file, using lru_cache where possible) --- .../augmentations/resample_augmentations.py | 8 +- batchgenerators/augmentations/utils.py | 179 +++++++++--------- .../transforms/resample_transforms.py | 3 + 3 files changed, 95 insertions(+), 95 deletions(-) diff --git a/batchgenerators/augmentations/resample_augmentations.py b/batchgenerators/augmentations/resample_augmentations.py index 4b0d8a0..ce13132 100644 --- a/batchgenerators/augmentations/resample_augmentations.py +++ b/batchgenerators/augmentations/resample_augmentations.py @@ -50,9 +50,6 @@ def augment_linear_downsampling_scipy(data_sample, zoom_range=(0.5, 1), per_chan ignore_axes: tuple/list ''' - if not isinstance(zoom_range, (list, tuple, np.ndarray)): - zoom_range = [zoom_range] - shp = np.array(data_sample.shape[1:]) dim = len(shp) @@ -70,7 +67,7 @@ def augment_linear_downsampling_scipy(data_sample, zoom_range=(0.5, 1), per_chan target_shape[i] = shp[i] if channels is None: - channels = list(range(data_sample.shape[0])) + channels = range(data_sample.shape[0]) for c in channels: if np.random.uniform() < p_per_channel: @@ -88,8 +85,7 @@ def augment_linear_downsampling_scipy(data_sample, zoom_range=(0.5, 1), per_chan downsampled = resize(data_sample[c].astype(float), target_shape, order=order_downsample, mode='edge', anti_aliasing=False) - data_sample[c] = resize(downsampled, shp, order=order_upsample, mode='edge', - anti_aliasing=False) + data_sample[c] = resize(downsampled, shp, order=order_upsample, mode='edge', anti_aliasing=False) return data_sample diff --git a/batchgenerators/augmentations/utils.py b/batchgenerators/augmentations/utils.py index 831aac2..66c71f2 100755 --- a/batchgenerators/augmentations/utils.py +++ b/batchgenerators/augmentations/utils.py @@ -14,6 +14,8 @@ # limitations under the License. import random +from functools import lru_cache + import numpy as np from copy import deepcopy from scipy.ndimage import map_coordinates, fourier_gaussian @@ -34,10 +36,10 @@ def generate_elastic_transform_coordinates(shape, alpha, sigma): return indices +@lru_cache(maxsize=None) # we should have only 1 hit def create_zero_centered_coordinate_mesh(shape): - tmp = tuple([np.arange(i) for i in shape]) - coords = np.array(np.meshgrid(*tmp, indexing='ij')).astype(float) - to_add = ((np.array(shape).astype(float) - 1) / 2.) + coords = np.array(np.meshgrid(*(np.arange(i) for i in shape), indexing='ij'), dtype=float) + to_add = (np.array(shape, dtype=float) - 1) / 2. 
for d in range(len(shape)): coords[d] -= to_add[d] return coords @@ -52,7 +54,7 @@ def convert_seg_image_to_one_hot_encoding(image, classes=None): ''' if classes is None: classes = np.unique(image) - out_image = np.zeros([len(classes)]+list(image.shape), dtype=image.dtype) + out_image = np.zeros([len(classes)] + list(image.shape), dtype=image.dtype) for i, c in enumerate(classes): out_image[i][image == c] = 1 return out_image @@ -72,9 +74,8 @@ def convert_seg_image_to_one_hot_encoding_batched(image, classes=None): def elastic_deform_coordinates(coordinates, alpha, sigma): - n_dim = len(coordinates) offsets = [] - for _ in range(n_dim): + for _ in range(len(coordinates)): offsets.append( gaussian_filter((np.random.random(coordinates.shape[1:]) * 2 - 1), sigma, mode="constant", cval=0) * alpha) offsets = np.array(offsets) @@ -125,11 +126,9 @@ def rotate_coords_2d(coords, angle): return coords -def scale_coords(coords, scale): +def scale_coords(coords: np.ndarray, scale): if isinstance(scale, (tuple, list, np.ndarray)): - assert len(scale) == len(coords) - for i in range(len(scale)): - coords[i] *= scale[i] + coords = (coords.T * scale).T else: coords *= scale return coords @@ -485,91 +484,90 @@ def general_cc_var_num_channels(img, diff_order=0, mink_norm=1, sigma=1, mask_im output_img[c] /= white_colors[c] if clip_range: - np.clip(output_img, minm, maxm, out= output_img) + np.clip(output_img, minm, maxm, out=output_img) return white_colors, output_img -def convert_seg_to_bounding_box_coordinates(data_dict, dim, get_rois_from_seg_flag=False, class_specific_seg_flag=False): - - ''' - This function generates bounding box annotations from given pixel-wise annotations. - :param data_dict: Input data dictionary as returned by the batch generator. - :param dim: Dimension in which the model operates (2 or 3). - :param get_rois_from_seg: Flag specifying one of the following scenarios: - 1. A label map with individual ROIs identified by increasing label values, accompanied by a vector containing - in each position the class target for the lesion with the corresponding label (set flag to False) - 2. A binary label map. There is only one foreground class and single lesions are not identified. - All lesions have the same class target (foreground). In this case the Dataloader runs a Connected Component - Labelling algorithm to create processable lesion - class target pairs on the fly (set flag to True). - :param class_specific_seg_flag: if True, returns the pixelwise-annotations in class specific manner, - e.g. a multi-class label map. If False, returns a binary annotation map (only foreground vs. background). - :return: data_dict: same as input, with additional keys: - - 'bb_target': bounding box coordinates (b, n_boxes, (y1, x1, y2, x2, (z1), (z2))) - - 'roi_labels': corresponding class labels for each box (b, n_boxes, class_label) - - 'roi_masks': corresponding binary segmentation mask for each lesion (box). Only used in Mask RCNN. 
(b, n_boxes, y, x, (z)) - - 'seg': now label map (see class_specific_seg_flag) - ''' - - bb_target = [] - roi_masks = [] - roi_labels = [] - out_seg = np.copy(data_dict['seg']) - for b in range(data_dict['seg'].shape[0]): - - p_coords_list = [] - p_roi_masks_list = [] - p_roi_labels_list = [] - - if np.sum(data_dict['seg'][b]!=0) > 0: - if get_rois_from_seg_flag: - clusters, n_cands = lb(data_dict['seg'][b]) - data_dict['class_target'][b] = [data_dict['class_target'][b]] * n_cands - else: - n_cands = int(np.max(data_dict['seg'][b])) - clusters = data_dict['seg'][b] - - rois = np.array([(clusters == ii) * 1 for ii in range(1, n_cands + 1)]) # separate clusters and concat - for rix, r in enumerate(rois): - if np.sum(r !=0) > 0: #check if the lesion survived data augmentation - seg_ixs = np.argwhere(r != 0) - coord_list = [np.min(seg_ixs[:, 1])-1, np.min(seg_ixs[:, 2])-1, np.max(seg_ixs[:, 1])+1, - np.max(seg_ixs[:, 2])+1] - if dim == 3: - - coord_list.extend([np.min(seg_ixs[:, 3])-1, np.max(seg_ixs[:, 3])+1]) - - p_coords_list.append(coord_list) - p_roi_masks_list.append(r) - # add background class = 0. rix is a patient wide index of lesions. since 'class_target' is - # also patient wide, this assignment is not dependent on patch occurrances. - p_roi_labels_list.append(data_dict['class_target'][b][rix] + 1) - - if class_specific_seg_flag: - out_seg[b][data_dict['seg'][b] == rix + 1] = data_dict['class_target'][b][rix] + 1 - - if not class_specific_seg_flag: - out_seg[b][data_dict['seg'][b] > 0] = 1 - - bb_target.append(np.array(p_coords_list)) - roi_masks.append(np.array(p_roi_masks_list).astype('uint8')) - roi_labels.append(np.array(p_roi_labels_list)) +def convert_seg_to_bounding_box_coordinates(data_dict, dim, get_rois_from_seg_flag=False, + class_specific_seg_flag=False): + ''' + This function generates bounding box annotations from given pixel-wise annotations. + :param data_dict: Input data dictionary as returned by the batch generator. + :param dim: Dimension in which the model operates (2 or 3). + :param get_rois_from_seg: Flag specifying one of the following scenarios: + 1. A label map with individual ROIs identified by increasing label values, accompanied by a vector containing + in each position the class target for the lesion with the corresponding label (set flag to False) + 2. A binary label map. There is only one foreground class and single lesions are not identified. + All lesions have the same class target (foreground). In this case the Dataloader runs a Connected Component + Labelling algorithm to create processable lesion - class target pairs on the fly (set flag to True). + :param class_specific_seg_flag: if True, returns the pixelwise-annotations in class specific manner, + e.g. a multi-class label map. If False, returns a binary annotation map (only foreground vs. background). + :return: data_dict: same as input, with additional keys: + - 'bb_target': bounding box coordinates (b, n_boxes, (y1, x1, y2, x2, (z1), (z2))) + - 'roi_labels': corresponding class labels for each box (b, n_boxes, class_label) + - 'roi_masks': corresponding binary segmentation mask for each lesion (box). Only used in Mask RCNN. 
(b, n_boxes, y, x, (z)) + - 'seg': now label map (see class_specific_seg_flag) + ''' + bb_target = [] + roi_masks = [] + roi_labels = [] + out_seg = np.copy(data_dict['seg']) + for b in range(data_dict['seg'].shape[0]): + p_coords_list = [] + p_roi_masks_list = [] + p_roi_labels_list = [] + + if np.sum(data_dict['seg'][b] != 0) > 0: + if get_rois_from_seg_flag: + clusters, n_cands = lb(data_dict['seg'][b]) + data_dict['class_target'][b] = [data_dict['class_target'][b]] * n_cands else: - bb_target.append([]) - roi_masks.append(np.zeros_like(data_dict['seg'][b])[None]) - roi_labels.append(np.array([-1])) + n_cands = int(np.max(data_dict['seg'][b])) + clusters = data_dict['seg'][b] + + rois = np.array([(clusters == ii) * 1 for ii in range(1, n_cands + 1)]) # separate clusters and concat + for rix, r in enumerate(rois): + if np.sum(r != 0) > 0: # check if the lesion survived data augmentation + seg_ixs = np.argwhere(r != 0) + coord_list = [np.min(seg_ixs[:, 1]) - 1, np.min(seg_ixs[:, 2]) - 1, np.max(seg_ixs[:, 1]) + 1, + np.max(seg_ixs[:, 2]) + 1] + if dim == 3: + coord_list.extend([np.min(seg_ixs[:, 3]) - 1, np.max(seg_ixs[:, 3]) + 1]) + + p_coords_list.append(coord_list) + p_roi_masks_list.append(r) + # add background class = 0. rix is a patient wide index of lesions. since 'class_target' is + # also patient wide, this assignment is not dependent on patch occurrances. + p_roi_labels_list.append(data_dict['class_target'][b][rix] + 1) + + if class_specific_seg_flag: + out_seg[b][data_dict['seg'][b] == rix + 1] = data_dict['class_target'][b][rix] + 1 + + if not class_specific_seg_flag: + out_seg[b][data_dict['seg'][b] > 0] = 1 + + bb_target.append(np.array(p_coords_list)) + roi_masks.append(np.array(p_roi_masks_list).astype('uint8')) + roi_labels.append(np.array(p_roi_labels_list)) + + + else: + bb_target.append([]) + roi_masks.append(np.zeros_like(data_dict['seg'][b])[None]) + roi_labels.append(np.array([-1])) - if get_rois_from_seg_flag: - data_dict.pop('class_target', None) + if get_rois_from_seg_flag: + data_dict.pop('class_target', None) - data_dict['bb_target'] = np.array(bb_target) - data_dict['roi_masks'] = np.array(roi_masks) - data_dict['class_target'] = np.array(roi_labels) - data_dict['seg'] = out_seg + data_dict['bb_target'] = np.array(bb_target) + data_dict['roi_masks'] = np.array(roi_masks) + data_dict['class_target'] = np.array(roi_labels) + data_dict['seg'] = out_seg - return data_dict + return data_dict def transpose_channels(batch): @@ -595,13 +593,15 @@ def resize_segmentation(segmentation, new_shape, order=3): unique_labels = np.unique(segmentation) assert len(segmentation.shape) == len(new_shape), "new shape must have same dimensionality as segmentation" if order == 0: - return resize(segmentation.astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False).astype(tpe) + return resize(segmentation.astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False).astype( + tpe) else: reshaped = np.zeros(new_shape, dtype=segmentation.dtype) for c in unique_labels: mask = segmentation == c - reshaped_multihot = resize(mask.astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False) + reshaped_multihot = resize(mask.astype(float), new_shape, order, mode="edge", clip=True, + anti_aliasing=False) reshaped[reshaped_multihot >= 0.5] = c return reshaped @@ -660,7 +660,8 @@ def uniform(low, high, size=None): return np.random.uniform(low, high, size) -def pad_nd_image(image, new_shape=None, mode="constant", kwargs=None, 
return_slicer=False, shape_must_be_divisible_by=None): +def pad_nd_image(image, new_shape=None, mode="constant", kwargs=None, return_slicer=False, + shape_must_be_divisible_by=None): """ one padder to pad them all. Documentation? Well okay. A little bit @@ -710,7 +711,7 @@ def pad_nd_image(image, new_shape=None, mode="constant", kwargs=None, return_sli difference = new_shape - old_shape pad_below = difference // 2 pad_above = pad_below + difference % 2 - pad_list = [[0, 0]]*num_axes_nopad + list([list(i) for i in zip(pad_below, pad_above)]) + pad_list = [[0, 0]] * num_axes_nopad + list([list(i) for i in zip(pad_below, pad_above)]) if np.any(pad_below) or np.any(pad_above): res = np.pad(image, pad_list, mode, **kwargs) diff --git a/batchgenerators/transforms/resample_transforms.py b/batchgenerators/transforms/resample_transforms.py index 5b00c0c..bc6cf09 100644 --- a/batchgenerators/transforms/resample_transforms.py +++ b/batchgenerators/transforms/resample_transforms.py @@ -57,6 +57,9 @@ def __init__(self, zoom_range=(0.5, 1), per_channel=False, p_per_channel=1, self.p_per_channel = p_per_channel self.p_per_sample = p_per_sample self.data_key = data_key + assert isinstance(zoom_range, (tuple, list, np.ndarray)) + assert (len(zoom_range) == 2 or isinstance(zoom_range[0], (tuple, list, np.ndarray)) and + all(len(zoom) == 2 for zoom in zoom_range)) self.zoom_range = zoom_range self.ignore_axes = ignore_axes From af70e892f1161fcc2e96f0fd8120b49b15d31a99 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Thu, 17 Aug 2023 12:39:19 +0300 Subject: [PATCH 24/60] Misc changes for lru_cache to take effect --- .../augmentations/color_augmentations.py | 30 ++++++++++--------- batchgenerators/augmentations/utils.py | 4 +-- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/batchgenerators/augmentations/color_augmentations.py b/batchgenerators/augmentations/color_augmentations.py index 708a7bc..d061059 100644 --- a/batchgenerators/augmentations/color_augmentations.py +++ b/batchgenerators/augmentations/color_augmentations.py @@ -92,23 +92,25 @@ def augment_brightness_additive(data_sample, mu: float, sigma: float, per_channe return data_sample -def setup_augment_brightness_multiplicative(per_channel: bool, batched: bool, shape: Tuple[int]): - def get_size(per_channel, batched, shape): - if per_channel: - if batched: - return shape[:2] +def get_size(per_channel, batched, shape): + if per_channel: + if batched: + return shape[:2] + return shape[0] + else: + if batched: return shape[0] - else: - if batched: - return shape[0] - return 1 + return 1 + - @lru_cache(maxsize=2) # axes are expected to remain the same - def get_axes(per_channel, batched, n): - if per_channel and batched: - return tuple(range(2, n)) - return tuple(range(1, n)) +@lru_cache(maxsize=None) # There will be only 1 miss, using maxsize None to remove locking and checks. 
+def get_axes(per_channel, batched, n): + if per_channel and batched: + return tuple(range(2, n)) + return tuple(range(1, n)) + +def setup_augment_brightness_multiplicative(per_channel: bool, batched: bool, shape: Tuple[int]): return get_size(per_channel, batched, shape), get_axes(per_channel, batched, len(shape)) diff --git a/batchgenerators/augmentations/utils.py b/batchgenerators/augmentations/utils.py index 66c71f2..3570aa5 100755 --- a/batchgenerators/augmentations/utils.py +++ b/batchgenerators/augmentations/utils.py @@ -36,7 +36,7 @@ def generate_elastic_transform_coordinates(shape, alpha, sigma): return indices -@lru_cache(maxsize=None) # we should have only 1 hit +@lru_cache(maxsize=None) # There will be only 1 miss, using maxsize None to remove locking and checks. def create_zero_centered_coordinate_mesh(shape): coords = np.array(np.meshgrid(*(np.arange(i) for i in shape), indexing='ij'), dtype=float) to_add = (np.array(shape, dtype=float) - 1) / 2. @@ -66,7 +66,7 @@ def convert_seg_image_to_one_hot_encoding_batched(image, classes=None): ''' if classes is None: classes = np.unique(image) - output_shape = (image.shape[0], len(classes)) + image.shape[1:] + output_shape = (image.shape[0], len(classes), *image.shape[1:]) out_image = np.zeros(output_shape, dtype=image.dtype) for i, c in enumerate(classes): out_image[:, i][image == c] = 1 From 8d7e7ac8d0c4f37b9b6ea502b7f533f9986b91cd Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Thu, 17 Aug 2023 13:11:16 +0300 Subject: [PATCH 25/60] Misc improvements to spatial transform --- .../augmentations/spatial_transformations.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/batchgenerators/augmentations/spatial_transformations.py b/batchgenerators/augmentations/spatial_transformations.py index 77b1e3b..a857004 100644 --- a/batchgenerators/augmentations/spatial_transformations.py +++ b/batchgenerators/augmentations/spatial_transformations.py @@ -199,7 +199,8 @@ def augment_spatial(data, seg, patch_size, patch_center_dist_from_border=30, data_result = np.zeros((data.shape[0], data.shape[1], *patch_size), dtype=np.float32) if not isinstance(patch_center_dist_from_border, (list, tuple, np.ndarray)): - patch_center_dist_from_border = dim * [patch_center_dist_from_border] + patch_center_dist_from_border = (patch_center_dist_from_border,) * dim + patch_center_dist_from_border = np.array(patch_center_dist_from_border) for sample_id in range(data.shape[0]): coords = create_zero_centered_coordinate_mesh(patch_size) @@ -253,13 +254,14 @@ def augment_spatial(data, seg, patch_size, patch_center_dist_from_border=30, # now find a nice center location if modified_coords: + data_shape_here = np.array(data.shape[2:]) + if random_crop: + ctr = np.random.uniform(patch_center_dist_from_border, data_shape_here - patch_center_dist_from_border) + else: + ctr = data_shape_here / 2. - 0.5 for d in range(dim): - if random_crop: - ctr = np.random.uniform(patch_center_dist_from_border[d], - data.shape[d + 2] - patch_center_dist_from_border[d]) - else: - ctr = data.shape[d + 2] / 2. 
- 0.5 - coords[d] += ctr + coords[d] += ctr[d] + for channel_id in range(data.shape[1]): data_result[sample_id, channel_id] = interpolate_img(data[sample_id, channel_id], coords, order_data, border_mode_data, cval=border_cval_data) @@ -274,7 +276,7 @@ def augment_spatial(data, seg, patch_size, patch_center_dist_from_border=30, else: s = seg[sample_id:sample_id + 1] if random_crop: - margin = [patch_center_dist_from_border[d] - patch_size[d] // 2 for d in range(dim)] + margin = patch_center_dist_from_border - np.array(patch_size) // 2 d, s = random_crop_aug(data[sample_id:sample_id + 1], s, patch_size, margin) else: d, s = center_crop_aug(data[sample_id:sample_id + 1], patch_size, s) From 5c0794f6d5a2a9e52ed4e3b75a5532104c50ef5d Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Thu, 17 Aug 2023 16:44:52 +0300 Subject: [PATCH 26/60] Improving vectorized computation by broadcasting lower dimensional operands instead of transposing the higher dimensional ones --- .../augmentations/color_augmentations.py | 34 ++++++++++++------- .../crop_and_pad_augmentations.py | 9 ++--- .../augmentations/normalizations.py | 25 +++++++++----- .../augmentations/spatial_transformations.py | 5 ++- batchgenerators/augmentations/utils.py | 30 ++++++++++++---- 5 files changed, 72 insertions(+), 31 deletions(-) diff --git a/batchgenerators/augmentations/color_augmentations.py b/batchgenerators/augmentations/color_augmentations.py index d061059..1a563b9 100644 --- a/batchgenerators/augmentations/color_augmentations.py +++ b/batchgenerators/augmentations/color_augmentations.py @@ -18,7 +18,8 @@ from typing import Tuple, Union, Callable import numpy as np -from batchgenerators.augmentations.utils import general_cc_var_num_channels, illumination_jitter +from batchgenerators.augmentations.utils import general_cc_var_num_channels, illumination_jitter, get_broadcast_axes, \ + reverse_broadcast def augment_contrast(data_sample: np.ndarray, @@ -52,10 +53,11 @@ def augment_contrast(data_sample: np.ndarray, factor = np.random.uniform(max(contrast_range[0], 1), contrast_range[1]) mask = np.random.uniform(size=size) < p_per_channel - if batched: - mask = np.atleast_2d(mask).repeat(data_sample.shape[0], axis=0) - workon = data_sample[mask] - if len(workon) > 0: + if np.any(mask): + if batched: + mask = np.atleast_2d(mask).repeat(data_sample.shape[0], axis=0) + + workon = data_sample[mask] axes = tuple(range(1, len(workon.shape))) mean = workon.mean(axis=axes) if preserve_range: @@ -65,7 +67,10 @@ def augment_contrast(data_sample: np.ndarray, data_sample[mask] = (workon.T * factor + mean * (1 - factor)).T # writing directly in data_sample if preserve_range: - np.clip(data_sample[mask].T, minm, maxm, out=data_sample[mask].T) + broadcast_axes = get_broadcast_axes(len(workon.shape)) + minm = reverse_broadcast(minm, broadcast_axes) + maxm = reverse_broadcast(maxm, broadcast_axes) + np.clip(data_sample[mask], minm, maxm, out=data_sample[mask]) return data_sample @@ -103,7 +108,6 @@ def get_size(per_channel, batched, shape): return 1 -@lru_cache(maxsize=None) # There will be only 1 miss, using maxsize None to remove locking and checks. 
def get_axes(per_channel, batched, n): if per_channel and batched: return tuple(range(2, n)) @@ -165,13 +169,19 @@ def augment_gamma(data_sample, gamma_range=(0.5, 2), invert_image=False, epsilon minm = data_sample.min(axis=axes) rnge = data_sample.max(axis=axes) - minm + epsilon - data_sample = (np.power(((data_sample.T - minm) / rnge), gamma) * rnge + minm).T + # aux = (np.power(((data_sample.T - minm) / rnge), gamma) * rnge + minm).T # This is slower + broadcast_axes = get_broadcast_axes(len(data_sample.shape)) + minm = reverse_broadcast(minm, broadcast_axes) + rnge = reverse_broadcast(rnge, broadcast_axes) + gamma = reverse_broadcast(gamma, broadcast_axes) + data_sample = np.power((data_sample - minm) / rnge, gamma) * rnge + minm if retain_any_stats: - data_sample[retain_stats_here] = (( - data_sample[retain_stats_here].T - data_sample[ - retain_stats_here].mean(axis=axes)) * sd / - (data_sample[retain_stats_here].std(axis=axes) + 1e-8) + mn).T + data_sample[retain_stats_here] -= reverse_broadcast( + data_sample[retain_stats_here].mean(axis=axes), broadcast_axes) + data_sample[retain_stats_here] *= reverse_broadcast( + sd / (data_sample[retain_stats_here].std(axis=axes) + 1e-8), broadcast_axes) + data_sample[retain_stats_here] += reverse_broadcast(mn, broadcast_axes) if invert_image: data_sample = - data_sample diff --git a/batchgenerators/augmentations/crop_and_pad_augmentations.py b/batchgenerators/augmentations/crop_and_pad_augmentations.py index 16abdf8..2ee9c15 100644 --- a/batchgenerators/augmentations/crop_and_pad_augmentations.py +++ b/batchgenerators/augmentations/crop_and_pad_augmentations.py @@ -41,7 +41,7 @@ def get_lbs_for_random_crop(crop_size, data_shape, margins): def get_lbs_for_center_crop(crop_size, data_shape): """ :param crop_size: - :param data_shape: (b,c,x,y(,z)) must be the whole thing! + :param data_shape: (b,c,x,y(,z)) must be the only x,y(,z)! 
:return: """ return (data_shape - crop_size) // 2 @@ -110,18 +110,19 @@ def crop(data: Union[Sequence[np.ndarray], np.ndarray], seg: Union[Sequence[np.n zero = np.zeros(dim, dtype=int) temp1 = np.abs(np.minimum(lbs, zero)) - temp2 = np.abs(np.minimum(zero, data_shape_here - lbs - crop_size)) + lbs_plus_crop_size = lbs + crop_size + temp2 = np.abs(np.minimum(zero, data_shape_here - lbs_plus_crop_size)) need_to_pad = np.array(((0, 0), *zip(temp1, temp2))) # we should crop first, then pad -> reduces i/o for memmaps, reduces RAM usage and improves speed - ubs = np.minimum(data_shape_here, lbs + crop_size) + ubs = np.minimum(data_shape_here, lbs_plus_crop_size) lbs = np.maximum(zero, lbs) slicer_data = (slice(0, data_first_dim), *(slice(lbs[d], ubs[d]) for d in range(dim))) data_cropped = data[b][slicer_data] if seg_return is not None: - slicer_data = (slice(0, seg_first_dim), *(slice(lbs[d], ubs[d]) for d in range(dim))) + slicer_data = (slice(0, seg_first_dim), *slicer_data[1:]) seg_cropped = seg[b][slicer_data] if np.any(need_to_pad): diff --git a/batchgenerators/augmentations/normalizations.py b/batchgenerators/augmentations/normalizations.py index 8ed157c..029a288 100644 --- a/batchgenerators/augmentations/normalizations.py +++ b/batchgenerators/augmentations/normalizations.py @@ -15,6 +15,8 @@ import numpy as np +from batchgenerators.augmentations.utils import get_broadcast_axes, reverse_broadcast + def range_normalization(data, rnge=(0, 1), per_channel=True, eps=1e-8): if per_channel: @@ -32,7 +34,12 @@ def min_max_normalization_batched(data, eps, axes): mn = data.min(axis=axes) mx = data.max(axis=axes) old_range = mx - mn + eps + data_normalized = ((data.T - mn.T) / old_range.T).T + # broadcast_axes = get_broadcast_axes(len(data.shape)) + # mn = reverse_broadcast(mn, broadcast_axes) + # old_range = reverse_broadcast(old_range, broadcast_axes) + # data_normalized = (data - mn) / old_range return data_normalized @@ -60,17 +67,17 @@ def mean_std_normalization(data, mean, std, per_channel=True): if per_channel: channel_dimension = data[0].shape[0] if isinstance(mean, float) and isinstance(std, float): - mean = [mean] * channel_dimension - std = [std] * channel_dimension + mean = (mean,) * channel_dimension + std = (std,) * channel_dimension else: assert len(mean) == channel_dimension assert len(std) == channel_dimension - mean = np.broadcast_to(mean, (len(data), len(mean))) - std = np.broadcast_to(std, (len(data), len(std))) - data_normalized = ((data.T - mean.T) / std.T).T - else: - data_normalized = (data - mean) / std + broadcast_axes = tuple(range(2, len(data.shape))) + mean = np.expand_dims(np.broadcast_to(mean, (len(data), len(mean))), axis=broadcast_axes) + std = np.expand_dims(np.broadcast_to(std, (len(data), len(std))), axis=broadcast_axes) + + data_normalized = (data - mean) / std return data_normalized @@ -81,5 +88,7 @@ def cut_off_outliers(data, percentile_lower=0.2, percentile_upper=99.8, per_chan axes = tuple(range(1, len(data.shape))) cut_off_lower, cut_off_upper = np.percentile(data, (percentile_lower, percentile_upper), axis=axes) - np.clip(data.T, cut_off_lower.T, cut_off_upper.T, out=data.T) + cut_off_lower = np.expand_dims(cut_off_lower, axis=axes) + cut_off_upper = np.expand_dims(cut_off_upper, axis=axes) + np.clip(data, cut_off_lower, cut_off_upper, out=data) return data diff --git a/batchgenerators/augmentations/spatial_transformations.py b/batchgenerators/augmentations/spatial_transformations.py index a857004..c05fc91 100644 --- 
a/batchgenerators/augmentations/spatial_transformations.py +++ b/batchgenerators/augmentations/spatial_transformations.py @@ -19,7 +19,7 @@ from batchgenerators.augmentations.utils import create_zero_centered_coordinate_mesh, elastic_deform_coordinates, \ interpolate_img, \ rotate_coords_2d, rotate_coords_3d, scale_coords, resize_segmentation, resize_multichannel_image, \ - elastic_deform_coordinates_2 + elastic_deform_coordinates_2, get_broadcast_axes, reverse_broadcast from batchgenerators.augmentations.crop_and_pad_augmentations import random_crop as random_crop_aug from batchgenerators.augmentations.crop_and_pad_augmentations import center_crop as center_crop_aug @@ -259,8 +259,11 @@ def augment_spatial(data, seg, patch_size, patch_center_dist_from_border=30, ctr = np.random.uniform(patch_center_dist_from_border, data_shape_here - patch_center_dist_from_border) else: ctr = data_shape_here / 2. - 0.5 + for d in range(dim): coords[d] += ctr[d] + # vectorized version, seems a bit slower + # coords += reverse_broadcast(ctr, get_broadcast_axes(len(coords.shape))) for channel_id in range(data.shape[1]): data_result[sample_id, channel_id] = interpolate_img(data[sample_id, channel_id], coords, order_data, diff --git a/batchgenerators/augmentations/utils.py b/batchgenerators/augmentations/utils.py index 3570aa5..f9a62b2 100755 --- a/batchgenerators/augmentations/utils.py +++ b/batchgenerators/augmentations/utils.py @@ -15,6 +15,7 @@ import random from functools import lru_cache +from typing import Tuple import numpy as np from copy import deepcopy @@ -36,6 +37,25 @@ def generate_elastic_transform_coordinates(shape, alpha, sigma): return indices +def get_broadcast_axes(n: int) -> Tuple[int]: + """ + Args: + n: len(array.shape), where array is the array for which we want to broadcast to. + Returns: broadcast axes, (0, 1, ...) + """ + return tuple(range(n - 1)) + + +def reverse_broadcast(a: np.ndarray, axes: Tuple[int]) -> np.ndarray: + """ + Args: + a: array which we want to broadcast for batched operations + axes: (0, 1, ...) + Returns: array of shape (len(a), 1, 1, ...) + """ + return np.expand_dims(a, axis=axes).T + + @lru_cache(maxsize=None) # There will be only 1 miss, using maxsize None to remove locking and checks. 
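# A minimal usage sketch for the two helpers above (illustrative only; `x` and its shape are
# assumed here, not taken from the patch). For a batched array `x` of shape (b, c, *spatial)
# holding one statistic per sample, expand_dims over axes (0, 1, ...) followed by .T yields
# shape (b, 1, 1, ...), which broadcasts against `x` without transposing the image data itself:
#     per_sample_mean = x.mean(axis=tuple(range(1, x.ndim)))                            # (b,)
#     per_sample_mean = reverse_broadcast(per_sample_mean, get_broadcast_axes(x.ndim))  # (b, 1, ...)
#     x -= per_sample_mean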
def create_zero_centered_coordinate_mesh(shape): coords = np.array(np.meshgrid(*(np.arange(i) for i in shape), indexing='ij'), dtype=float) @@ -54,7 +74,7 @@ def convert_seg_image_to_one_hot_encoding(image, classes=None): ''' if classes is None: classes = np.unique(image) - out_image = np.zeros([len(classes)] + list(image.shape), dtype=image.dtype) + out_image = np.zeros((len(classes), *image.shape), dtype=image.dtype) for i, c in enumerate(classes): out_image[i][image == c] = 1 return out_image @@ -66,8 +86,7 @@ def convert_seg_image_to_one_hot_encoding_batched(image, classes=None): ''' if classes is None: classes = np.unique(image) - output_shape = (image.shape[0], len(classes), *image.shape[1:]) - out_image = np.zeros(output_shape, dtype=image.dtype) + out_image = np.zeros((image.shape[0], len(classes), *image.shape[1:]), dtype=image.dtype) for i, c in enumerate(classes): out_image[:, i][image == c] = 1 return out_image @@ -128,9 +147,8 @@ def rotate_coords_2d(coords, angle): def scale_coords(coords: np.ndarray, scale): if isinstance(scale, (tuple, list, np.ndarray)): - coords = (coords.T * scale).T - else: - coords *= scale + scale = reverse_broadcast(scale, get_broadcast_axes(len(coords.shape))) + coords *= scale return coords From 3b89bf57effc96e2fae1951ac9e3a47eb175c737 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Tue, 22 Aug 2023 13:43:04 +0300 Subject: [PATCH 27/60] Minor changes --- .../augmentations/color_augmentations.py | 31 ++++++------------- .../crop_and_pad_augmentations.py | 7 +++-- 2 files changed, 14 insertions(+), 24 deletions(-) diff --git a/batchgenerators/augmentations/color_augmentations.py b/batchgenerators/augmentations/color_augmentations.py index 1a563b9..bc0d842 100644 --- a/batchgenerators/augmentations/color_augmentations.py +++ b/batchgenerators/augmentations/color_augmentations.py @@ -92,30 +92,18 @@ def augment_brightness_additive(data_sample, mu: float, sigma: float, per_channe else: rnd_nb = np.repeat(np.random.normal(mu, sigma), size) rnd_nb[np.random.uniform(size=size) > p_per_channel] = 0.0 - axes = tuple(range(len(data_sample.shape) - 1)) - data_sample += np.expand_dims(rnd_nb, axis=axes).T # Broadcasting rules require this + data_sample += reverse_broadcast(rnd_nb, get_broadcast_axes(len(data_sample.shape))) return data_sample -def get_size(per_channel, batched, shape): +def setup_augment_brightness_multiplicative(per_channel: bool, batched: bool, shape: Tuple[int]): if per_channel: if batched: - return shape[:2] - return shape[0] - else: - if batched: - return shape[0] - return 1 - - -def get_axes(per_channel, batched, n): - if per_channel and batched: - return tuple(range(2, n)) - return tuple(range(1, n)) - - -def setup_augment_brightness_multiplicative(per_channel: bool, batched: bool, shape: Tuple[int]): - return get_size(per_channel, batched, shape), get_axes(per_channel, batched, len(shape)) + return shape[:2], tuple(range(2, len(shape))) + return shape[0], tuple(range(1, len(shape))) + if batched: + return shape[0], tuple(range(1, len(shape))) + return 1, tuple(range(1, len(shape))) def augment_brightness_multiplicative(data_sample, multiplier_range=(0.5, 2), per_channel=True, batched=False): @@ -148,9 +136,10 @@ def augment_gamma(data_sample, gamma_range=(0.5, 2), invert_image=False, epsilon else: shape_0 = data_sample.shape[0] if callable(retain_stats): - retain_stats_here = np.array(retain_stats() for _ in range(shape_0)) + retain_stats_here = [retain_stats() for _ in range(shape_0)] else: - retain_stats_here = 
np.array([retain_stats]).repeat(shape_0)
+            retain_stats_here = (retain_stats,) * shape_0
+        retain_stats_here = np.array(retain_stats_here)
         gamma = []
         for i in range(shape_0):
             if gamma_range[0] < 1 and np.random.random() < 0.5:
diff --git a/batchgenerators/augmentations/crop_and_pad_augmentations.py b/batchgenerators/augmentations/crop_and_pad_augmentations.py
index 2ee9c15..c12013f 100644
--- a/batchgenerators/augmentations/crop_and_pad_augmentations.py
+++ b/batchgenerators/augmentations/crop_and_pad_augmentations.py
@@ -66,12 +66,12 @@ def crop(data: Union[Sequence[np.ndarray], np.ndarray], seg: Union[Sequence[np.n
     :param crop_type: random or center
     :return:
     """
-    data_shape = (len(data), *data[0].shape)
+    data_shape = (len(data),) + data[0].shape
     data_dtype = data[0].dtype
     dim = len(data_shape) - 2
     if seg is not None:
-        seg_shape = (len(seg), *seg[0].shape)
+        seg_shape = (len(seg),) + seg[0].shape
         seg_dtype = seg[0].dtype
         assert np.array_equal(seg_shape[2:], data_shape[2:]), "data and seg must have the same spatial dimensions. " \
@@ -112,7 +112,8 @@ def crop(data: Union[Sequence[np.ndarray], np.ndarray], seg: Union[Sequence[np.n
         temp1 = np.abs(np.minimum(lbs, zero))
         lbs_plus_crop_size = lbs + crop_size
         temp2 = np.abs(np.minimum(zero, data_shape_here - lbs_plus_crop_size))
-        need_to_pad = np.array(((0, 0), *zip(temp1, temp2)))
+        need_to_pad = ((0, 0),) + tuple(zip(temp1, temp2))
+        need_to_pad = np.array(need_to_pad)
         # we should crop first, then pad -> reduces i/o for memmaps, reduces RAM usage and improves speed
         ubs = np.minimum(data_shape_here, lbs_plus_crop_size)

From f9b71d47b817d3692c9b1ffc7c54df794f08767b Mon Sep 17 00:00:00 2001
From: ancestor-mithril
Date: Tue, 22 Aug 2023 13:43:50 +0300
Subject: [PATCH 28/60] Solved bug with single threaded augmenter due to usage of a nonexistent private method in nnUNetTrainer

---
 batchgenerators/dataloading/single_threaded_augmenter.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/batchgenerators/dataloading/single_threaded_augmenter.py b/batchgenerators/dataloading/single_threaded_augmenter.py
index 5637c0e..f27cce2 100755
--- a/batchgenerators/dataloading/single_threaded_augmenter.py
+++ b/batchgenerators/dataloading/single_threaded_augmenter.py
@@ -40,3 +40,6 @@ def __next__(self):
     def next(self):
         return self.__next__()
+
+    def _finish(self):
+        pass

From b6db109d5da39e74b98e64f54f59f3556820f81e Mon Sep 17 00:00:00 2001
From: ancestor-mithril
Date: Tue, 22 Aug 2023 13:46:02 +0300
Subject: [PATCH 29/60] Further improving NumpyToTensor transform

---
 .../transforms/utility_transforms.py | 29 ++++++++-----------
 1 file changed, 12 insertions(+), 17 deletions(-)

diff --git a/batchgenerators/transforms/utility_transforms.py b/batchgenerators/transforms/utility_transforms.py
index a31e3e8..5cf1d88 100644
--- a/batchgenerators/transforms/utility_transforms.py
+++ b/batchgenerators/transforms/utility_transforms.py
@@ -35,24 +35,19 @@ def __init__(self, keys=None, cast_to=None):
         if keys is not None and not isinstance(keys, (list, tuple)):
             keys = [keys]
         self.keys = keys
-        if cast_to is not None:
-            if cast_to == 'half':
-                self.cast_to = torch.half
-            elif cast_to == 'float':
-                self.cast_to = torch.float
-            elif cast_to == 'long':
-                self.cast_to = torch.long
-            elif cast_to == 'bool':
-                self.cast_to = torch.bool
-            else:
-                raise ValueError(f'Unknown value for cast_to: {self.cast_to}')
-        else:
-            self.cast_to = None
-    def cast(self, tensor):
-        if self.cast_to is not None:
-            tensor = tensor.to(self.cast_to)
-        return tensor
+        if cast_to is None:
+            self.cast 
= lambda x: x + elif cast_to == 'half': + self.cast = lambda x: x.to(torch.half) + elif cast_to == 'float': + self.cast = lambda x: x.to(torch.float) + elif cast_to == 'long': + self.cast = lambda x: x.to(torch.long) + elif cast_to == 'bool': + self.cast = lambda x: x.to(torch.bool) + else: + raise ValueError(f'Unknown value for cast_to: {cast_to}') def __call__(self, **data_dict): if self.keys is None: From d9b203a62214cd33ab5a9a29bbf08f54635d5ccd Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Tue, 22 Aug 2023 14:20:33 +0300 Subject: [PATCH 30/60] Using pandas unique instead of np unique pandas unique is faster because it uses hashtable --- batchgenerators/augmentations/utils.py | 15 ++++++++------- .../transforms/channel_selection_transforms.py | 7 ++++++- requirements.txt | 3 ++- setup.py | 3 ++- 4 files changed, 18 insertions(+), 10 deletions(-) diff --git a/batchgenerators/augmentations/utils.py b/batchgenerators/augmentations/utils.py index f9a62b2..c2563b8 100755 --- a/batchgenerators/augmentations/utils.py +++ b/batchgenerators/augmentations/utils.py @@ -18,6 +18,7 @@ from typing import Tuple import numpy as np +import pandas as pd from copy import deepcopy from scipy.ndimage import map_coordinates, fourier_gaussian from scipy.ndimage.filters import gaussian_filter, gaussian_gradient_magnitude @@ -73,7 +74,7 @@ def convert_seg_image_to_one_hot_encoding(image, classes=None): Prefer convert_seg_image_to_one_hot_encoding_batched. ''' if classes is None: - classes = np.unique(image) + classes = pd.unique(image.reshape(-1)) out_image = np.zeros((len(classes), *image.shape), dtype=image.dtype) for i, c in enumerate(classes): out_image[i][image == c] = 1 @@ -85,7 +86,7 @@ def convert_seg_image_to_one_hot_encoding_batched(image, classes=None): same as convert_seg_image_to_one_hot_encoding, but expects image to be (b, x, y, z) or (b, x, y) ''' if classes is None: - classes = np.unique(image) + classes = pd.unique(image.reshape(-1)) out_image = np.zeros((image.shape[0], len(classes), *image.shape[1:]), dtype=image.dtype) for i, c in enumerate(classes): out_image[:, i][image == c] = 1 @@ -162,7 +163,7 @@ def uncenter_coords(coords): def interpolate_img(img, coords, order=3, mode='nearest', cval=0.0, is_seg=False): if is_seg and order != 0: - unique_labels = np.unique(img) + unique_labels = pd.unique(img.reshape(-1)) result = np.zeros(coords.shape[1:], img.dtype) for c in unique_labels: res_new = map_coordinates((img == c).astype(float), coords, order=order, mode=mode, cval=cval) @@ -349,7 +350,7 @@ def random_crop_2D_image_batched(img, crop_size): def resize_image_by_padding(image, new_shape, pad_value=None): shape = image.shape - new_shape = np.max(np.concatenate((shape, new_shape)).reshape((2, len(shape))), axis=0) + new_shape = np.maximum(shape, new_shape) if pad_value is None: if len(shape) == 2: pad_value = image[0, 0] @@ -368,8 +369,8 @@ def resize_image_by_padding(image, new_shape, pad_value=None): def resize_image_by_padding_batched(image, new_shape, pad_value=None): - shape = image.shape[2:] - new_shape = np.max(np.concatenate((shape, new_shape)).reshape((2, len(shape))), axis=0) + shape = image.shape[1:] + new_shape = np.maximum(shape, new_shape) if pad_value is None: if len(shape) == 2: pad_value = image[0, 0] @@ -608,7 +609,7 @@ def resize_segmentation(segmentation, new_shape, order=3): :return: ''' tpe = segmentation.dtype - unique_labels = np.unique(segmentation) + unique_labels = pd.unique(segmentation.reshape(-1)) assert len(segmentation.shape) == len(new_shape), "new 
shape must have same dimensionality as segmentation" if order == 0: return resize(segmentation.astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False).astype( diff --git a/batchgenerators/transforms/channel_selection_transforms.py b/batchgenerators/transforms/channel_selection_transforms.py index 601a375..5bd1ae1 100644 --- a/batchgenerators/transforms/channel_selection_transforms.py +++ b/batchgenerators/transforms/channel_selection_transforms.py @@ -15,6 +15,9 @@ import numpy as np from warnings import warn + +import pandas as pd + from batchgenerators.transforms.abstract_transforms import AbstractTransform @@ -167,6 +170,7 @@ def __init__(self, label, label_key="seg"): self.label = [label] else: self.label = sorted(label) + self.label = set(self.label) def __call__(self, **data_dict): seg = data_dict.get(self.label_key) @@ -175,7 +179,8 @@ def __call__(self, **data_dict): warn("You used SegLabelSelectionBinarizeTransform but there is no 'seg' key in your data_dict, returning " "data_dict unmodified", Warning) else: - discard_labels = set(np.unique(seg)) - set(self.label) - set([0]) + + discard_labels = set(pd.unique(seg.reshape(-1))) - self.label - {0} for label in discard_labels: seg[seg == label] = 0 for label in self.label: diff --git a/requirements.txt b/requirements.txt index 4b72370..741b18d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ scikit-learn numpy>=1.10.2 scipy scikit-image -scikit-learn \ No newline at end of file +scikit-learn +pandas \ No newline at end of file diff --git a/setup.py b/setup.py index 437b8b3..e13e845 100755 --- a/setup.py +++ b/setup.py @@ -16,7 +16,8 @@ "scikit-image", "scikit-learn", "future", - "threadpoolctl" + "threadpoolctl", + "pandas" ], keywords=['data augmentation', 'deep learning', 'image segmentation', 'image classification', 'medical image analysis', 'medical image segmentation'], From e5b342edac8f2d807d66bddb760836bd145cbf45 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Tue, 22 Aug 2023 16:27:25 +0300 Subject: [PATCH 31/60] Fixing batched operations (new random for each sample) --- .../augmentations/color_augmentations.py | 34 +++++++++++-------- .../augmentations/noise_augmentations.py | 7 +--- .../augmentations/spatial_transformations.py | 3 +- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/batchgenerators/augmentations/color_augmentations.py b/batchgenerators/augmentations/color_augmentations.py index bc0d842..c3c1a00 100644 --- a/batchgenerators/augmentations/color_augmentations.py +++ b/batchgenerators/augmentations/color_augmentations.py @@ -22,27 +22,23 @@ reverse_broadcast -def augment_contrast(data_sample: np.ndarray, - contrast_range: Union[Tuple[float, float], Callable[[], float]] = (0.75, 1.25), - preserve_range: bool = True, - per_channel: bool = True, - p_per_channel: float = 1, - batched=False) -> np.ndarray: - size = data_sample.shape[1 if batched else 0] +def get_augment_contrast_factor(contrast_range: Union[Tuple[float, float], Callable[[], float]], + per_channel: bool, + size: int): + # TODO: callable contrast_range is not used. Remove this feature. 
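    # How the factor is sampled below (a sketch; it assumes contrast_range = (low, high) with
    # low < 1 <= high, which is how the transform is typically configured): with probability 0.5
    # the factor is drawn from [low, 1) (contrast reduction), otherwise from [max(low, 1), high]
    # (contrast amplification), so both directions stay equally likely even when the range is
    # asymmetric around 1. Per channel this amounts to:
    #     factor = np.random.uniform(low, 1) if np.random.random() < 0.5 else np.random.uniform(max(low, 1), high)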
if per_channel: if callable(contrast_range): factor = [contrast_range() for _ in range(size)] else: factor = [] + contrast_l = max(contrast_range[0], 1) for _ in range(size): if contrast_range[0] < 1 and np.random.random() < 0.5: factor.append(np.random.uniform(contrast_range[0], 1)) else: - factor.append(np.random.uniform(max(contrast_range[0], 1), contrast_range[1])) + factor.append(np.random.uniform(contrast_l, contrast_range[1])) factor = np.array(factor) - if batched: - factor = factor.repeat(data_sample.shape[0]) else: if callable(contrast_range): factor = contrast_range() @@ -52,12 +48,19 @@ def augment_contrast(data_sample: np.ndarray, else: factor = np.random.uniform(max(contrast_range[0], 1), contrast_range[1]) - mask = np.random.uniform(size=size) < p_per_channel - if np.any(mask): - if batched: - mask = np.atleast_2d(mask).repeat(data_sample.shape[0], axis=0) + return factor + +def augment_contrast(data_sample: np.ndarray, + contrast_range: Union[Tuple[float, float], Callable[[], float]] = (0.75, 1.25), + preserve_range: bool = True, + per_channel: bool = True, + p_per_channel: float = 1, + batched=False) -> np.ndarray: + mask = np.random.uniform(size=data_sample.shape[:2] if batched else data_sample.shape[0]) < p_per_channel + if np.any(mask): workon = data_sample[mask] + factor = get_augment_contrast_factor(contrast_range, per_channel, len(workon)) axes = tuple(range(1, len(workon.shape))) mean = workon.mean(axis=axes) if preserve_range: @@ -141,11 +144,12 @@ def augment_gamma(data_sample, gamma_range=(0.5, 2), invert_image=False, epsilon retain_stats_here = (retain_stats,) * shape_0 retain_stats_here = np.array(retain_stats_here) gamma = [] + gamma_l = max(gamma_range[0], 1) for i in range(shape_0): if gamma_range[0] < 1 and np.random.random() < 0.5: gamma.append(np.random.uniform(gamma_range[0], 1)) else: - gamma.append(np.random.uniform(max(gamma_range[0], 1), gamma_range[1])) + gamma.append(np.random.uniform(gamma_l, gamma_range[1])) gamma = np.array(gamma) axes = tuple(range(1, len(data_sample.shape))) diff --git a/batchgenerators/augmentations/noise_augmentations.py b/batchgenerators/augmentations/noise_augmentations.py index f82d59a..347b36b 100644 --- a/batchgenerators/augmentations/noise_augmentations.py +++ b/batchgenerators/augmentations/noise_augmentations.py @@ -43,14 +43,9 @@ def setup_augment_gaussian_noise(noise_variance: Tuple[float, float], per_channe def augment_gaussian_noise(data_sample: np.ndarray, noise_variance: Tuple[float, float] = (0, 0.1), p_per_channel: float = 1, per_channel: bool = False, batched: bool = False) -> np.ndarray: - mask = np.random.uniform(size=data_sample.shape[1 if batched else 0]) < p_per_channel + mask = np.random.uniform(size=data_sample.shape[:2] if batched else data_sample.shape[0]) < p_per_channel size = np.count_nonzero(mask) if size: - if batched: - num_samples = data_sample.shape[0] - mask = np.atleast_2d(mask).repeat(num_samples, axis=0) - size *= num_samples - variance = setup_augment_gaussian_noise(noise_variance, per_channel, size) data_sample[mask] += np.random.normal(0.0, variance, data_sample[mask].T.shape).T diff --git a/batchgenerators/augmentations/spatial_transformations.py b/batchgenerators/augmentations/spatial_transformations.py index c05fc91..3a3c7d3 100644 --- a/batchgenerators/augmentations/spatial_transformations.py +++ b/batchgenerators/augmentations/spatial_transformations.py @@ -238,11 +238,12 @@ def augment_spatial(data, seg, patch_size, patch_center_dist_from_border=30, if do_scale and 
np.random.uniform() < p_scale_per_sample: if independent_scale_for_each_axis and np.random.uniform() < p_independent_scale_per_axis: sc = [] + scale_l = max(scale[0], 1) for _ in range(dim): if scale[0] < 1 and np.random.random() < 0.5: sc.append(np.random.uniform(scale[0], 1)) else: - sc.append(np.random.uniform(max(scale[0], 1), scale[1])) + sc.append(np.random.uniform(scale_l, scale[1])) else: if scale[0] < 1 and np.random.random() < 0.5: sc = np.random.uniform(scale[0], 1) From 0821b661ef01ddd45b388750b54fe50b2caf8d0c Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Tue, 22 Aug 2023 17:06:49 +0300 Subject: [PATCH 32/60] Implementing batched mirror transform --- .../augmentations/spatial_transformations.py | 37 ++++++++++--------- .../transforms/spatial_transforms.py | 23 +++++------- 2 files changed, 29 insertions(+), 31 deletions(-) diff --git a/batchgenerators/augmentations/spatial_transformations.py b/batchgenerators/augmentations/spatial_transformations.py index 3a3c7d3..b7093e0 100644 --- a/batchgenerators/augmentations/spatial_transformations.py +++ b/batchgenerators/augmentations/spatial_transformations.py @@ -112,25 +112,28 @@ def augment_zoom(sample_data, sample_seg, zoom_factors, order=3, order_seg=1): return sample_data, target_seg -def augment_mirroring(sample_data, sample_seg=None, axes=(0, 1, 2)): - if (len(sample_data.shape) != 3) and (len(sample_data.shape) != 4): - raise Exception( - "Invalid dimension for sample_data and sample_seg. sample_data and sample_seg should be either " - "[channels, x, y] or [channels, x, y, z]") +def augment_mirroring_batched(sample_data, sample_seg=None, axes=(0, 1, 2)): + assert len(sample_data.shape) == 5 or len(sample_data.shape) == 4, \ + "Invalid dimension for sample_data and sample_seg. sample_data and sample_seg should be either " \ + "[batch, channels, x, y] or [batch, channels, x, y, z]" + workon = np.expand_dims(sample_data, 0) if sample_seg is None else np.stack((sample_data, sample_seg)) if 0 in axes and np.random.uniform() < 0.5: - sample_data[:, :] = sample_data[:, ::-1] - if sample_seg is not None: - sample_seg[:, :] = sample_seg[:, ::-1] + workon[:, :, :, :] = workon[:, :, :, ::-1] if 1 in axes and np.random.uniform() < 0.5: - sample_data[:, :, :] = sample_data[:, :, ::-1] - if sample_seg is not None: - sample_seg[:, :, :] = sample_seg[:, :, ::-1] - if 2 in axes and len(sample_data.shape) == 4: - if np.random.uniform() < 0.5: - sample_data[:, :, :, :] = sample_data[:, :, :, ::-1] - if sample_seg is not None: - sample_seg[:, :, :, :] = sample_seg[:, :, :, ::-1] - return sample_data, sample_seg + workon[:, :, :, :, :] = workon[:, :, :, :, ::-1] + if 2 in axes and len(sample_data.shape) == 6 and np.random.uniform() < 0.5: + workon[:, :, :, :, :, :] = workon[:, :, :, :, :, ::-1] + if sample_seg is None: + return workon[0], None + return workon + + +def augment_mirroring(sample_data, sample_seg=None, axes=(0, 1, 2)): + sample_data = np.expand_dims(sample_data, 0) + if sample_seg is not None: + sample_seg = np.expand_dims(sample_seg, 0) + sample_data, sample_seg = augment_mirroring_batched(sample_data, sample_seg, axes) + return sample_data[0], sample_seg[0] if sample_seg is not None else None def augment_channel_translation(data, const_channel=0, max_shifts=None): diff --git a/batchgenerators/transforms/spatial_transforms.py b/batchgenerators/transforms/spatial_transforms.py index ed27501..ea998b7 100644 --- a/batchgenerators/transforms/spatial_transforms.py +++ b/batchgenerators/transforms/spatial_transforms.py @@ -16,7 
+16,7 @@ from batchgenerators.transforms.abstract_transforms import AbstractTransform from batchgenerators.augmentations.spatial_transformations import augment_spatial, augment_spatial_2, \ augment_channel_translation, \ - augment_mirroring, augment_transpose_axes, augment_zoom, augment_resize, augment_rot90 + augment_mirroring, augment_transpose_axes, augment_zoom, augment_resize, augment_rot90, augment_mirroring_batched import numpy as np @@ -203,19 +203,14 @@ def __call__(self, **data_dict): data = data_dict.get(self.data_key) seg = data_dict.get(self.label_key) - for b in range(len(data)): - if np.random.uniform() < self.p_per_sample: - sample_seg = None - if seg is not None: - sample_seg = seg[b] - ret_val = augment_mirroring(data[b], sample_seg, axes=self.axes) - data[b] = ret_val[0] - if seg is not None: - seg[b] = ret_val[1] - - data_dict[self.data_key] = data - if seg is not None: - data_dict[self.label_key] = seg + mask = np.random.uniform(size=len(data)) < self.p_per_sample + if np.any(mask): + if seg is None: + data[mask], _ = augment_mirroring_batched(data[mask], None, self.axes) + else: + data[mask], seg[mask] = augment_mirroring_batched(data[mask], seg[mask], self.axes) + data_dict[self.label_key] = seg + data_dict[self.data_key] = data return data_dict From c87865e3fb9d2f6493797a5c820e15d6d18744ed Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Tue, 22 Aug 2023 17:07:53 +0300 Subject: [PATCH 33/60] Misc change removed explicit axis keyword from np.expand_dims --- batchgenerators/augmentations/color_augmentations.py | 2 +- batchgenerators/augmentations/normalizations.py | 8 ++++---- batchgenerators/augmentations/utils.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/batchgenerators/augmentations/color_augmentations.py b/batchgenerators/augmentations/color_augmentations.py index c3c1a00..8c21eb1 100644 --- a/batchgenerators/augmentations/color_augmentations.py +++ b/batchgenerators/augmentations/color_augmentations.py @@ -111,7 +111,7 @@ def setup_augment_brightness_multiplicative(per_channel: bool, batched: bool, sh def augment_brightness_multiplicative(data_sample, multiplier_range=(0.5, 2), per_channel=True, batched=False): size, axes = setup_augment_brightness_multiplicative(per_channel, batched, data_sample.shape) - data_sample *= np.expand_dims(np.random.uniform(multiplier_range[0], multiplier_range[1], size=size), axis=axes) + data_sample *= np.expand_dims(np.random.uniform(multiplier_range[0], multiplier_range[1], size=size), axes) return data_sample diff --git a/batchgenerators/augmentations/normalizations.py b/batchgenerators/augmentations/normalizations.py index 029a288..c5ddc64 100644 --- a/batchgenerators/augmentations/normalizations.py +++ b/batchgenerators/augmentations/normalizations.py @@ -74,8 +74,8 @@ def mean_std_normalization(data, mean, std, per_channel=True): assert len(std) == channel_dimension broadcast_axes = tuple(range(2, len(data.shape))) - mean = np.expand_dims(np.broadcast_to(mean, (len(data), len(mean))), axis=broadcast_axes) - std = np.expand_dims(np.broadcast_to(std, (len(data), len(std))), axis=broadcast_axes) + mean = np.expand_dims(np.broadcast_to(mean, (len(data), len(mean))), broadcast_axes) + std = np.expand_dims(np.broadcast_to(std, (len(data), len(std))), broadcast_axes) data_normalized = (data - mean) / std return data_normalized @@ -88,7 +88,7 @@ def cut_off_outliers(data, percentile_lower=0.2, percentile_upper=99.8, per_chan axes = tuple(range(1, len(data.shape))) cut_off_lower, cut_off_upper = 
np.percentile(data, (percentile_lower, percentile_upper), axis=axes) - cut_off_lower = np.expand_dims(cut_off_lower, axis=axes) - cut_off_upper = np.expand_dims(cut_off_upper, axis=axes) + cut_off_lower = np.expand_dims(cut_off_lower, axes) + cut_off_upper = np.expand_dims(cut_off_upper, axes) np.clip(data, cut_off_lower, cut_off_upper, out=data) return data diff --git a/batchgenerators/augmentations/utils.py b/batchgenerators/augmentations/utils.py index c2563b8..aa7911f 100755 --- a/batchgenerators/augmentations/utils.py +++ b/batchgenerators/augmentations/utils.py @@ -54,7 +54,7 @@ def reverse_broadcast(a: np.ndarray, axes: Tuple[int]) -> np.ndarray: axes: (0, 1, ...) Returns: array of shape (len(a), 1, 1, ...) """ - return np.expand_dims(a, axis=axes).T + return np.expand_dims(a, axes).T @lru_cache(maxsize=None) # There will be only 1 miss, using maxsize None to remove locking and checks. From 17a08770190ed0b665897c93ce3781cf9eaede15 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Tue, 22 Aug 2023 17:11:38 +0300 Subject: [PATCH 34/60] Misc removed unused imports --- batchgenerators/augmentations/color_augmentations.py | 1 - batchgenerators/augmentations/normalizations.py | 2 -- batchgenerators/augmentations/resample_augmentations.py | 1 - batchgenerators/augmentations/spatial_transformations.py | 2 +- batchgenerators/transforms/noise_transforms.py | 1 - batchgenerators/transforms/spatial_transforms.py | 2 +- 6 files changed, 2 insertions(+), 7 deletions(-) diff --git a/batchgenerators/augmentations/color_augmentations.py b/batchgenerators/augmentations/color_augmentations.py index 8c21eb1..62a46ef 100644 --- a/batchgenerators/augmentations/color_augmentations.py +++ b/batchgenerators/augmentations/color_augmentations.py @@ -14,7 +14,6 @@ # limitations under the License. 
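# Note on the np.expand_dims(stat, axes) calls in the previous patch (a sketch, not library code):
# passing the tuple of axes positionally is equivalent to the axis= keyword form and requires
# NumPy >= 1.18; the new length-1 axes are inserted at those positions of the output shape, e.g.
#     a = np.zeros((4, 3))
#     np.expand_dims(a, (1, 2)).shape   # -> (4, 1, 1, 3)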
from builtins import range -from functools import lru_cache from typing import Tuple, Union, Callable import numpy as np diff --git a/batchgenerators/augmentations/normalizations.py b/batchgenerators/augmentations/normalizations.py index c5ddc64..45bc3c3 100644 --- a/batchgenerators/augmentations/normalizations.py +++ b/batchgenerators/augmentations/normalizations.py @@ -15,8 +15,6 @@ import numpy as np -from batchgenerators.augmentations.utils import get_broadcast_axes, reverse_broadcast - def range_normalization(data, rnge=(0, 1), per_channel=True, eps=1e-8): if per_channel: diff --git a/batchgenerators/augmentations/resample_augmentations.py b/batchgenerators/augmentations/resample_augmentations.py index ce13132..ad7e530 100644 --- a/batchgenerators/augmentations/resample_augmentations.py +++ b/batchgenerators/augmentations/resample_augmentations.py @@ -15,7 +15,6 @@ from builtins import range import numpy as np -import random from skimage.transform import resize from batchgenerators.augmentations.utils import uniform diff --git a/batchgenerators/augmentations/spatial_transformations.py b/batchgenerators/augmentations/spatial_transformations.py index b7093e0..d3c9e08 100644 --- a/batchgenerators/augmentations/spatial_transformations.py +++ b/batchgenerators/augmentations/spatial_transformations.py @@ -19,7 +19,7 @@ from batchgenerators.augmentations.utils import create_zero_centered_coordinate_mesh, elastic_deform_coordinates, \ interpolate_img, \ rotate_coords_2d, rotate_coords_3d, scale_coords, resize_segmentation, resize_multichannel_image, \ - elastic_deform_coordinates_2, get_broadcast_axes, reverse_broadcast + elastic_deform_coordinates_2 from batchgenerators.augmentations.crop_and_pad_augmentations import random_crop as random_crop_aug from batchgenerators.augmentations.crop_and_pad_augmentations import center_crop as center_crop_aug diff --git a/batchgenerators/transforms/noise_transforms.py b/batchgenerators/transforms/noise_transforms.py index d3de83b..a214a78 100644 --- a/batchgenerators/transforms/noise_transforms.py +++ b/batchgenerators/transforms/noise_transforms.py @@ -20,7 +20,6 @@ import numpy as np from typing import Union, Tuple -from scipy import ndimage from scipy.ndimage import median_filter from scipy.signal import convolve diff --git a/batchgenerators/transforms/spatial_transforms.py b/batchgenerators/transforms/spatial_transforms.py index ea998b7..1c4abe4 100644 --- a/batchgenerators/transforms/spatial_transforms.py +++ b/batchgenerators/transforms/spatial_transforms.py @@ -16,7 +16,7 @@ from batchgenerators.transforms.abstract_transforms import AbstractTransform from batchgenerators.augmentations.spatial_transformations import augment_spatial, augment_spatial_2, \ augment_channel_translation, \ - augment_mirroring, augment_transpose_axes, augment_zoom, augment_resize, augment_rot90, augment_mirroring_batched + augment_transpose_axes, augment_zoom, augment_resize, augment_rot90, augment_mirroring_batched import numpy as np From d8d9d5c9abc7b4c46d5af8894c562ecdbaa955fb Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Tue, 22 Aug 2023 17:33:10 +0300 Subject: [PATCH 35/60] Misc Minor improvements --- batchgenerators/augmentations/noise_augmentations.py | 4 ++-- batchgenerators/transforms/color_transforms.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/batchgenerators/augmentations/noise_augmentations.py b/batchgenerators/augmentations/noise_augmentations.py index 347b36b..c2c6b72 100644 --- 
a/batchgenerators/augmentations/noise_augmentations.py +++ b/batchgenerators/augmentations/noise_augmentations.py @@ -34,9 +34,9 @@ def setup_augment_gaussian_noise(noise_variance: Tuple[float, float], per_channe if not per_channel: variance = noise_variance[0] if noise_variance[0] == noise_variance[1] else \ random.uniform(noise_variance[0], noise_variance[1]) - variance = np.repeat(variance, size) + variance = np.array((variance,) * size) else: - variance = np.repeat(noise_variance[0], size) if noise_variance[0] == noise_variance[1] else \ + variance = np.array((noise_variance[0],) * size) if noise_variance[0] == noise_variance[1] else \ np.random.uniform(noise_variance[0], noise_variance[1], size=size) return variance diff --git a/batchgenerators/transforms/color_transforms.py b/batchgenerators/transforms/color_transforms.py index 4ad3247..93c480e 100644 --- a/batchgenerators/transforms/color_transforms.py +++ b/batchgenerators/transforms/color_transforms.py @@ -122,9 +122,10 @@ def __init__(self, multiplier_range=(0.5, 2), per_channel=True, data_key="data", self.per_channel = per_channel def __call__(self, **data_dict): - mask = np.random.uniform(size=len(data_dict[self.data_key])) < self.p_per_sample + data = data_dict[self.data_key] + mask = np.random.uniform(size=len(data)) < self.p_per_sample if np.any(mask): - data_dict[self.data_key][mask] = augment_brightness_multiplicative(data_dict[self.data_key][mask], + data_dict[self.data_key][mask] = augment_brightness_multiplicative(data[mask], self.multiplier_range, self.per_channel, batched=True) From c4d7c4b7bb0d6fc9078fb644461d06c0e424a58c Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Tue, 22 Aug 2023 18:40:43 +0300 Subject: [PATCH 36/60] Solving bug with unpickle-able NumpyToTensor --- .../transforms/utility_transforms.py | 51 ++++++++++++++----- 1 file changed, 39 insertions(+), 12 deletions(-) diff --git a/batchgenerators/transforms/utility_transforms.py b/batchgenerators/transforms/utility_transforms.py index 5cf1d88..834afa4 100644 --- a/batchgenerators/transforms/utility_transforms.py +++ b/batchgenerators/transforms/utility_transforms.py @@ -37,18 +37,41 @@ def __init__(self, keys=None, cast_to=None): self.keys = keys if cast_to is None: - self.cast = lambda x: x + self.cast = self.no_cast elif cast_to == 'half': - self.cast = lambda x: x.to(torch.half) + self.cast = self.half_cast elif cast_to == 'float': - self.cast = lambda x: x.to(torch.float) + self.cast = self.float_cast elif cast_to == 'long': - self.cast = lambda x: x.to(torch.long) + self.cast = self.long_cast elif cast_to == 'bool': - self.cast = lambda x: x.to(torch.bool) + self.cast = self.bool_cast else: raise ValueError(f'Unknown value for cast_to: {cast_to}') + def cast(self, x): + pass + + @staticmethod + def no_cast(x): + return x + + @staticmethod + def float_cast(x): + return x.to(torch.float) + + @staticmethod + def long_cast(x): + return x.to(torch.long) + + @staticmethod + def bool_cast( x): + return x.to(torch.bool) + + @staticmethod + def half_cast(x): + return x.to(torch.half) + def __call__(self, **data_dict): if self.keys is None: for key, val in data_dict.items(): @@ -60,13 +83,13 @@ def __call__(self, **data_dict): for key in self.keys: if isinstance(data_dict[key], np.ndarray): data_dict[key] = self.cast(torch.from_numpy(data_dict[key])).contiguous() - elif isinstance(data_dict[key], (list, tuple)) and all([isinstance(i, np.ndarray) for i in data_dict[key]]): + elif isinstance(data_dict[key], (list, tuple)) and all( + [isinstance(i, 
np.ndarray) for i in data_dict[key]]): data_dict[key] = [self.cast(torch.from_numpy(i)).contiguous() for i in data_dict[key]] return data_dict - class ListToNumpy(AbstractTransform): """Utility function for pytorch. Converts data (and seg) numpy ndarrays to pytorch tensors """ @@ -195,13 +218,15 @@ def __call__(self, **data_dict): if seg is not None: if not seg.shape[1] % self.output_channels == 0: from warnings import warn - warn("Calling ConvertMultiSegToArgmaxTransform but number of input channels {} cannot be divided into {} output channels.".format(seg.shape[1], self.output_channels)) + warn( + "Calling ConvertMultiSegToArgmaxTransform but number of input channels {} cannot be divided into {} output channels.".format( + seg.shape[1], self.output_channels)) n_labels = seg.shape[1] // self.output_channels target_size = list(seg.shape) target_size[1] = self.output_channels output = np.zeros(target_size, dtype=seg.dtype) for i in range(self.output_channels): - output[:, i] = np.argmax(seg[:, i*n_labels:(i+1)*n_labels], 1) + output[:, i] = np.argmax(seg[:, i * n_labels:(i + 1) * n_labels], 1) if self.labels is not None: if list(self.labels) != list(range(n_labels)): for index, value in enumerate(reversed(self.labels)): @@ -225,7 +250,8 @@ def __init__(self, dim, get_rois_from_seg_flag=False, class_specific_seg_flag=Fa self.class_specific_seg_flag = class_specific_seg_flag def __call__(self, **data_dict): - data_dict = convert_seg_to_bounding_box_coordinates(data_dict, self.dim, self.get_rois_from_seg_flag, class_specific_seg_flag=self.class_specific_seg_flag) + data_dict = convert_seg_to_bounding_box_coordinates(data_dict, self.dim, self.get_rois_from_seg_flag, + class_specific_seg_flag=self.class_specific_seg_flag) return data_dict @@ -233,6 +259,7 @@ class MoveSegToDataChannel(AbstractTransform): """ concatenates data_dict['seg'] to data_dict['data'] """ + def __call__(self, **data_dict): data_dict['data'] = np.concatenate((data_dict['data'], data_dict['seg']), axis=1) return data_dict @@ -400,7 +427,7 @@ def __call__(self, **data_dict): selected_channels = inp[:, self.channel_indexes] if outp is None: - #warn("output key %s is not present in dict, it will be created" % self.output_key) + # warn("output key %s is not present in dict, it will be created" % self.output_key) outp = selected_channels data_dict[self.output_key] = outp else: @@ -477,7 +504,7 @@ def __call__(self, **data_dict): # expected to have the same length some_value = data_dict.get(self.relevant_keys[0]) for b in range(len(some_value)): - new_dict = {i: data_dict[i][b:b+1] for i in self.relevant_keys} + new_dict = {i: data_dict[i][b:b + 1] for i in self.relevant_keys} random_transform = np.random.choice(len(self.list_of_transforms), p=self.p) ret = self.list_of_transforms[random_transform](**new_dict) for i in self.relevant_keys: From 707f3ebe6994fb8c14463e460862dd519cc6f5f6 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Wed, 23 Aug 2023 10:14:42 +0300 Subject: [PATCH 37/60] Fixing mirroring --- .../augmentations/spatial_transformations.py | 33 ++++++++++++++----- .../transforms/spatial_transforms.py | 2 +- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/batchgenerators/augmentations/spatial_transformations.py b/batchgenerators/augmentations/spatial_transformations.py index d3c9e08..de3e89f 100644 --- a/batchgenerators/augmentations/spatial_transformations.py +++ b/batchgenerators/augmentations/spatial_transformations.py @@ -116,13 +116,17 @@ def augment_mirroring_batched(sample_data, 
sample_seg=None, axes=(0, 1, 2)): assert len(sample_data.shape) == 5 or len(sample_data.shape) == 4, \ "Invalid dimension for sample_data and sample_seg. sample_data and sample_seg should be either " \ "[batch, channels, x, y] or [batch, channels, x, y, z]" + size = len(sample_data) workon = np.expand_dims(sample_data, 0) if sample_seg is None else np.stack((sample_data, sample_seg)) - if 0 in axes and np.random.uniform() < 0.5: - workon[:, :, :, :] = workon[:, :, :, ::-1] - if 1 in axes and np.random.uniform() < 0.5: - workon[:, :, :, :, :] = workon[:, :, :, :, ::-1] - if 2 in axes and len(sample_data.shape) == 6 and np.random.uniform() < 0.5: - workon[:, :, :, :, :, :] = workon[:, :, :, :, :, ::-1] + if 0 in axes: + mask = np.random.uniform(size=size) < 0.5 + workon[:, mask] = np.flip(workon[:, mask], 3) + if 1 in axes: + mask = np.random.uniform(size=size) < 0.5 + workon[:, mask] = np.flip(workon[:, mask], 4) + if 2 in axes and size == 5: + mask = np.random.uniform(size=size) < 0.5 + workon[:, mask] = np.flip(workon[:, mask], 5) if sample_seg is None: return workon[0], None return workon @@ -132,8 +136,21 @@ def augment_mirroring(sample_data, sample_seg=None, axes=(0, 1, 2)): sample_data = np.expand_dims(sample_data, 0) if sample_seg is not None: sample_seg = np.expand_dims(sample_seg, 0) - sample_data, sample_seg = augment_mirroring_batched(sample_data, sample_seg, axes) - return sample_data[0], sample_seg[0] if sample_seg is not None else None + return augment_mirroring_batched(sample_data, sample_seg, axes) + if 0 in axes and np.random.uniform() < 0.5: + sample_data[:, :] = sample_data[:, ::-1] + if sample_seg is not None: + sample_seg[:, :] = sample_seg[:, ::-1] + if 1 in axes and np.random.uniform() < 0.5: + sample_data[:, :, :] = sample_data[:, :, ::-1] + if sample_seg is not None: + sample_seg[:, :, :] = sample_seg[:, :, ::-1] + if 2 in axes and len(sample_data.shape) == 4: + if np.random.uniform() < 0.5: + sample_data[:, :, :, :] = sample_data[:, :, :, ::-1] + if sample_seg is not None: + sample_seg[:, :, :, :] = sample_seg[:, :, :, ::-1] + return sample_data, sample_seg def augment_channel_translation(data, const_channel=0, max_shifts=None): diff --git a/batchgenerators/transforms/spatial_transforms.py b/batchgenerators/transforms/spatial_transforms.py index 1c4abe4..bec7357 100644 --- a/batchgenerators/transforms/spatial_transforms.py +++ b/batchgenerators/transforms/spatial_transforms.py @@ -200,7 +200,7 @@ def __init__(self, axes=(0, 1, 2), data_key="data", label_key="seg", p_per_sampl "is now axes=(0, 1, 2). 
Please adapt your scripts accordingly.") def __call__(self, **data_dict): - data = data_dict.get(self.data_key) + data = data_dict[self.data_key] seg = data_dict.get(self.label_key) mask = np.random.uniform(size=len(data)) < self.p_per_sample From 70adf88ec5411dfb58d0acd66b6f02e45afb62d9 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Mon, 28 Aug 2023 16:20:14 +0300 Subject: [PATCH 38/60] Using keepdims instead of broadcasting again --- .../augmentations/color_augmentations.py | 39 ++++++++----------- .../augmentations/normalizations.py | 20 ++++------ 2 files changed, 23 insertions(+), 36 deletions(-) diff --git a/batchgenerators/augmentations/color_augmentations.py b/batchgenerators/augmentations/color_augmentations.py index 62a46ef..5ba4487 100644 --- a/batchgenerators/augmentations/color_augmentations.py +++ b/batchgenerators/augmentations/color_augmentations.py @@ -23,7 +23,8 @@ def get_augment_contrast_factor(contrast_range: Union[Tuple[float, float], Callable[[], float]], per_channel: bool, - size: int): + size: int, + broadcast_size: int): # TODO: callable contrast_range is not used. Remove this feature. if per_channel: if callable(contrast_range): @@ -37,7 +38,7 @@ def get_augment_contrast_factor(contrast_range: Union[Tuple[float, float], Calla else: factor.append(np.random.uniform(contrast_l, contrast_range[1])) - factor = np.array(factor) + factor = reverse_broadcast(np.array(factor), get_broadcast_axes(broadcast_size)) else: if callable(contrast_range): factor = contrast_range() @@ -59,19 +60,16 @@ def augment_contrast(data_sample: np.ndarray, mask = np.random.uniform(size=data_sample.shape[:2] if batched else data_sample.shape[0]) < p_per_channel if np.any(mask): workon = data_sample[mask] - factor = get_augment_contrast_factor(contrast_range, per_channel, len(workon)) + factor = get_augment_contrast_factor(contrast_range, per_channel, len(workon), len(workon.shape)) axes = tuple(range(1, len(workon.shape))) - mean = workon.mean(axis=axes) + mean = workon.mean(axis=axes, keepdims=True) if preserve_range: - minm = workon.min(axis=axes) - maxm = workon.max(axis=axes) + minm = workon.min(axis=axes, keepdims=True) + maxm = workon.max(axis=axes, keepdims=True) - data_sample[mask] = (workon.T * factor + mean * (1 - factor)).T # writing directly in data_sample + data_sample[mask] = workon * factor + mean * (1 - factor) # writing directly in data_sample if preserve_range: - broadcast_axes = get_broadcast_axes(len(workon.shape)) - minm = reverse_broadcast(minm, broadcast_axes) - maxm = reverse_broadcast(maxm, broadcast_axes) np.clip(data_sample[mask], minm, maxm, out=data_sample[mask]) return data_sample @@ -155,25 +153,20 @@ def augment_gamma(data_sample, gamma_range=(0.5, 2), invert_image=False, epsilon retain_any_stats = np.any(retain_stats_here) if retain_any_stats: - mn = data_sample[retain_stats_here].mean(axis=axes) - sd = data_sample[retain_stats_here].mean(axis=axes) + mn = data_sample[retain_stats_here].mean(axis=axes, keepdims=True) + sd = data_sample[retain_stats_here].mean(axis=axes, keepdims=True) - minm = data_sample.min(axis=axes) - rnge = data_sample.max(axis=axes) - minm + epsilon + minm = data_sample.min(axis=axes, keepdims=True) + rnge = data_sample.max(axis=axes, keepdims=True) - minm + epsilon - # aux = (np.power(((data_sample.T - minm) / rnge), gamma) * rnge + minm).T # This is slower broadcast_axes = get_broadcast_axes(len(data_sample.shape)) - minm = reverse_broadcast(minm, broadcast_axes) - rnge = reverse_broadcast(rnge, broadcast_axes) - gamma = 
reverse_broadcast(gamma, broadcast_axes) + gamma = reverse_broadcast(gamma, broadcast_axes) # TODO: Remove data_sample = np.power((data_sample - minm) / rnge, gamma) * rnge + minm if retain_any_stats: - data_sample[retain_stats_here] -= reverse_broadcast( - data_sample[retain_stats_here].mean(axis=axes), broadcast_axes) - data_sample[retain_stats_here] *= reverse_broadcast( - sd / (data_sample[retain_stats_here].std(axis=axes) + 1e-8), broadcast_axes) - data_sample[retain_stats_here] += reverse_broadcast(mn, broadcast_axes) + data_sample[retain_stats_here] -= data_sample[retain_stats_here].mean(axis=axes, keepdims=True) + data_sample[retain_stats_here] *= sd / (data_sample[retain_stats_here].std(axis=axes, keepdims=True) + 1e-8) + data_sample[retain_stats_here] += mn if invert_image: data_sample = - data_sample diff --git a/batchgenerators/augmentations/normalizations.py b/batchgenerators/augmentations/normalizations.py index 45bc3c3..d6d813e 100644 --- a/batchgenerators/augmentations/normalizations.py +++ b/batchgenerators/augmentations/normalizations.py @@ -29,15 +29,11 @@ def range_normalization(data, rnge=(0, 1), per_channel=True, eps=1e-8): def min_max_normalization_batched(data, eps, axes): - mn = data.min(axis=axes) - mx = data.max(axis=axes) + mn = data.min(axis=axes, keepdims=True) + mx = data.max(axis=axes, keepdims=True) old_range = mx - mn + eps - data_normalized = ((data.T - mn.T) / old_range.T).T - # broadcast_axes = get_broadcast_axes(len(data.shape)) - # mn = reverse_broadcast(mn, broadcast_axes) - # old_range = reverse_broadcast(old_range, broadcast_axes) - # data_normalized = (data - mn) / old_range + data_normalized = (data - mn) / old_range return data_normalized @@ -55,9 +51,9 @@ def zero_mean_unit_variance_normalization(data, per_channel=True, epsilon=1e-8): else: axes = tuple(range(1, len(data.shape))) - mean = np.mean(data, axis=axes) - std = np.std(data, axis=axes) + epsilon - data_normalized = ((data.T - mean.T) / std.T).T + mean = np.mean(data, axis=axes, keepdims=True) + std = np.std(data, axis=axes, keepdims=True) + epsilon + data_normalized = (data - mean) / std return data_normalized @@ -85,8 +81,6 @@ def cut_off_outliers(data, percentile_lower=0.2, percentile_upper=99.8, per_chan else: axes = tuple(range(1, len(data.shape))) - cut_off_lower, cut_off_upper = np.percentile(data, (percentile_lower, percentile_upper), axis=axes) - cut_off_lower = np.expand_dims(cut_off_lower, axes) - cut_off_upper = np.expand_dims(cut_off_upper, axes) + cut_off_lower, cut_off_upper = np.percentile(data, (percentile_lower, percentile_upper), axis=axes, keepdims=True) np.clip(data, cut_off_lower, cut_off_upper, out=data) return data From 251e74ac084e7494623edbacae6c32cdc38ec3a2 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Mon, 28 Aug 2023 16:29:14 +0300 Subject: [PATCH 39/60] Fixed augment mirroring again --- batchgenerators/augmentations/spatial_transformations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/batchgenerators/augmentations/spatial_transformations.py b/batchgenerators/augmentations/spatial_transformations.py index de3e89f..268b03f 100644 --- a/batchgenerators/augmentations/spatial_transformations.py +++ b/batchgenerators/augmentations/spatial_transformations.py @@ -124,7 +124,7 @@ def augment_mirroring_batched(sample_data, sample_seg=None, axes=(0, 1, 2)): if 1 in axes: mask = np.random.uniform(size=size) < 0.5 workon[:, mask] = np.flip(workon[:, mask], 4) - if 2 in axes and size == 5: + if 2 in axes and len(workon.shape) == 6: mask = 
np.random.uniform(size=size) < 0.5 workon[:, mask] = np.flip(workon[:, mask], 5) if sample_seg is None: From 5cf312b31ba6cb8a087fd5f3c604f2a7bf171c47 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Wed, 30 Aug 2023 10:16:51 +0300 Subject: [PATCH 40/60] Sorting pd.unique when needed --- batchgenerators/augmentations/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/batchgenerators/augmentations/utils.py b/batchgenerators/augmentations/utils.py index aa7911f..b5febb7 100755 --- a/batchgenerators/augmentations/utils.py +++ b/batchgenerators/augmentations/utils.py @@ -74,7 +74,7 @@ def convert_seg_image_to_one_hot_encoding(image, classes=None): Prefer convert_seg_image_to_one_hot_encoding_batched. ''' if classes is None: - classes = pd.unique(image.reshape(-1)) + classes = np.sort(pd.unique(image.reshape(-1))) out_image = np.zeros((len(classes), *image.shape), dtype=image.dtype) for i, c in enumerate(classes): out_image[i][image == c] = 1 @@ -86,7 +86,7 @@ def convert_seg_image_to_one_hot_encoding_batched(image, classes=None): same as convert_seg_image_to_one_hot_encoding, but expects image to be (b, x, y, z) or (b, x, y) ''' if classes is None: - classes = pd.unique(image.reshape(-1)) + classes = np.sort(pd.unique(image.reshape(-1))) out_image = np.zeros((image.shape[0], len(classes), *image.shape[1:]), dtype=image.dtype) for i, c in enumerate(classes): out_image[:, i][image == c] = 1 @@ -163,7 +163,7 @@ def uncenter_coords(coords): def interpolate_img(img, coords, order=3, mode='nearest', cval=0.0, is_seg=False): if is_seg and order != 0: - unique_labels = pd.unique(img.reshape(-1)) + unique_labels = pd.unique(img.reshape(-1)) # does not need sorting result = np.zeros(coords.shape[1:], img.dtype) for c in unique_labels: res_new = map_coordinates((img == c).astype(float), coords, order=order, mode=mode, cval=cval) @@ -609,12 +609,12 @@ def resize_segmentation(segmentation, new_shape, order=3): :return: ''' tpe = segmentation.dtype - unique_labels = pd.unique(segmentation.reshape(-1)) assert len(segmentation.shape) == len(new_shape), "new shape must have same dimensionality as segmentation" if order == 0: return resize(segmentation.astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False).astype( tpe) else: + unique_labels = pd.unique(segmentation.reshape(-1)) # does not need sorting reshaped = np.zeros(new_shape, dtype=segmentation.dtype) for c in unique_labels: From 8f7da406f1e8a2eb6eb51105e6c1626a62f622c2 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Wed, 30 Aug 2023 10:41:23 +0300 Subject: [PATCH 41/60] Minimizing array copy when data was already np.ndarray --- batchgenerators/augmentations/crop_and_pad_augmentations.py | 4 ++-- batchgenerators/augmentations/spatial_transformations.py | 2 +- batchgenerators/augmentations/utils.py | 3 +-- batchgenerators/transforms/spatial_transforms.py | 2 +- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/batchgenerators/augmentations/crop_and_pad_augmentations.py b/batchgenerators/augmentations/crop_and_pad_augmentations.py index c12013f..e416359 100644 --- a/batchgenerators/augmentations/crop_and_pad_augmentations.py +++ b/batchgenerators/augmentations/crop_and_pad_augmentations.py @@ -82,11 +82,11 @@ def crop(data: Union[Sequence[np.ndarray], np.ndarray], seg: Union[Sequence[np.n else: assert len(crop_size) == dim, ("If you provide a list/tuple as center crop make sure it has the same dimension " "as your data (2d/3d)") - crop_size = np.array(crop_size) + crop_size = 
np.asarray(crop_size) if not isinstance(margins, (np.ndarray, tuple, list)): margins = (margins,) * dim - margins = np.array(margins) + margins = np.asarray(margins) data_return = np.zeros((data_shape[0], data_shape[1], *crop_size), dtype=data_dtype) if seg is not None: diff --git a/batchgenerators/augmentations/spatial_transformations.py b/batchgenerators/augmentations/spatial_transformations.py index 268b03f..bd3f7d6 100644 --- a/batchgenerators/augmentations/spatial_transformations.py +++ b/batchgenerators/augmentations/spatial_transformations.py @@ -220,7 +220,7 @@ def augment_spatial(data, seg, patch_size, patch_center_dist_from_border=30, if not isinstance(patch_center_dist_from_border, (list, tuple, np.ndarray)): patch_center_dist_from_border = (patch_center_dist_from_border,) * dim - patch_center_dist_from_border = np.array(patch_center_dist_from_border) + patch_center_dist_from_border = np.asarray(patch_center_dist_from_border) for sample_id in range(data.shape[0]): coords = create_zero_centered_coordinate_mesh(patch_size) diff --git a/batchgenerators/augmentations/utils.py b/batchgenerators/augmentations/utils.py index b5febb7..49fb1b6 100755 --- a/batchgenerators/augmentations/utils.py +++ b/batchgenerators/augmentations/utils.py @@ -713,8 +713,7 @@ def pad_nd_image(image, new_shape=None, mode="constant", kwargs=None, return_sli num_axes_nopad = len(image.shape) - len(new_shape) - new_shape = [max(new_shape[i], old_shape[i]) for i in range(len(new_shape))] - new_shape = np.array(new_shape) + new_shape = np.maximum(new_shape, old_shape) if shape_must_be_divisible_by is not None: if not isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray)): diff --git a/batchgenerators/transforms/spatial_transforms.py b/batchgenerators/transforms/spatial_transforms.py index bec7357..bc5c0e6 100644 --- a/batchgenerators/transforms/spatial_transforms.py +++ b/batchgenerators/transforms/spatial_transforms.py @@ -325,7 +325,7 @@ def __init__(self, patch_size, patch_center_dist_from_border=30, self.p_independent_scale_per_axis = p_independent_scale_per_axis def __call__(self, **data_dict): - data = data_dict.get(self.data_key) + data = data_dict[self.data_key] seg = data_dict.get(self.label_key) if self.patch_size is None: From 03e6d2c0da268702f99b0a68f3e280ea6b57b397 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Wed, 30 Aug 2023 12:23:51 +0300 Subject: [PATCH 42/60] Adjusted test crop --- batchgenerators/augmentations/utils.py | 1 - tests/test_crop.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/batchgenerators/augmentations/utils.py b/batchgenerators/augmentations/utils.py index 49fb1b6..5660aac 100755 --- a/batchgenerators/augmentations/utils.py +++ b/batchgenerators/augmentations/utils.py @@ -712,7 +712,6 @@ def pad_nd_image(image, new_shape=None, mode="constant", kwargs=None, return_sli old_shape = new_shape num_axes_nopad = len(image.shape) - len(new_shape) - new_shape = np.maximum(new_shape, old_shape) if shape_must_be_divisible_by is not None: diff --git a/tests/test_crop.py b/tests/test_crop.py index 3933a96..d6bbe2e 100644 --- a/tests/test_crop.py +++ b/tests/test_crop.py @@ -268,7 +268,7 @@ def test_pad_nd_image_and_seg_2D(self): print('Zero padding with new_shape.shape smaller than data.shape. [DONE]') print('Zero padding with new_shape.shape bigger than data.shape. 
[START]') - self.assertRaises(IndexError, pad_nd_image_and_seg, data, seg, new_shape=new_shape6) + self.assertRaises(ValueError, pad_nd_image_and_seg, data, seg, new_shape=new_shape6) print('Zero padding with new_shape.shape bigger than data.shape. [DONE]') print('Padding to bigger output shape in all dimensions with constant_value=1 for segmentation padding . [START]') @@ -352,7 +352,7 @@ def test_pad_nd_image_and_seg_3D(self): print('Zero padding with new_shape.shape smaller than data.shape. [DONE]') print('Zero padding with new_shape.shape bigger than data.shape. [START]') - self.assertRaises(IndexError, pad_nd_image_and_seg, data, seg, new_shape=new_shape6) + self.assertRaises(ValueError, pad_nd_image_and_seg, data, seg, new_shape=new_shape6) print('Zero padding with new_shape.shape bigger than data.shape. [DONE]') print('Padding to bigger output shape in all dimensions with constant_value=1 for segmentation padding . [START]') From d69737698ad79fb6aa981b3cd25599aafec9cfbe Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Thu, 31 Aug 2023 11:16:13 +0300 Subject: [PATCH 43/60] Numpy To Tensor Revisited, removed contiguous and from numpy calls and included them into the cast file --- .../transforms/utility_transforms.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/batchgenerators/transforms/utility_transforms.py b/batchgenerators/transforms/utility_transforms.py index 834afa4..6bb807a 100644 --- a/batchgenerators/transforms/utility_transforms.py +++ b/batchgenerators/transforms/utility_transforms.py @@ -49,43 +49,43 @@ def __init__(self, keys=None, cast_to=None): else: raise ValueError(f'Unknown value for cast_to: {cast_to}') - def cast(self, x): + def cast(self, x: np.ndarray) -> torch.Tensor: pass @staticmethod - def no_cast(x): - return x + def no_cast(x: np.ndarray) -> torch.Tensor: + return torch.from_numpy(x).contiguous() @staticmethod - def float_cast(x): - return x.to(torch.float) + def float_cast(x: np.ndarray) -> torch.Tensor: + return torch.from_numpy(x).to(torch.float, memory_format=torch.contiguous_format) @staticmethod - def long_cast(x): - return x.to(torch.long) + def long_cast(x: np.ndarray) -> torch.Tensor: + return torch.from_numpy(x).to(torch.long, memory_format=torch.contiguous_format) @staticmethod - def bool_cast( x): - return x.to(torch.bool) + def bool_cast(x: np.ndarray) -> torch.Tensor: + return torch.from_numpy(x).to(torch.bool, memory_format=torch.contiguous_format) @staticmethod - def half_cast(x): - return x.to(torch.half) + def half_cast(x: np.ndarray) -> torch.Tensor: + return torch.from_numpy(x).to(torch.half, memory_format=torch.contiguous_format) def __call__(self, **data_dict): if self.keys is None: for key, val in data_dict.items(): if isinstance(val, np.ndarray): - data_dict[key] = self.cast(torch.from_numpy(val)).contiguous() + data_dict[key] = self.cast(val) elif isinstance(val, (list, tuple)) and all([isinstance(i, np.ndarray) for i in val]): - data_dict[key] = [self.cast(torch.from_numpy(i)).contiguous() for i in val] + data_dict[key] = [self.cast(i) for i in val] else: for key in self.keys: if isinstance(data_dict[key], np.ndarray): - data_dict[key] = self.cast(torch.from_numpy(data_dict[key])).contiguous() + data_dict[key] = self.cast(data_dict[key]) elif isinstance(data_dict[key], (list, tuple)) and all( [isinstance(i, np.ndarray) for i in data_dict[key]]): - data_dict[key] = [self.cast(torch.from_numpy(i)).contiguous() for i in data_dict[key]] + data_dict[key] = [self.cast(i) for i in 
data_dict[key]] return data_dict From 4a1075d78f09adc5aa190a1092f0a1bcebba3bc6 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Wed, 6 Sep 2023 10:56:16 +0300 Subject: [PATCH 44/60] Replaced 'get_range_val' with 'uniform' --- .../augmentations/noise_augmentations.py | 14 +++++++------- batchgenerators/augmentations/utils.py | 14 +++++++------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/batchgenerators/augmentations/noise_augmentations.py b/batchgenerators/augmentations/noise_augmentations.py index c2c6b72..157aaac 100644 --- a/batchgenerators/augmentations/noise_augmentations.py +++ b/batchgenerators/augmentations/noise_augmentations.py @@ -17,7 +17,7 @@ from typing import Tuple import numpy as np -from batchgenerators.augmentations.utils import get_range_val, mask_random_squares +from batchgenerators.augmentations.utils import mask_random_squares, uniform from builtins import range from scipy.ndimage import gaussian_filter @@ -66,18 +66,18 @@ def augment_gaussian_blur(data_sample: np.ndarray, sigma_range: Tuple[float, flo # Godzilla revived if not different_sigma_per_axis or np.random.uniform() < p_isotropic: - sigma = get_range_val(sigma_range) + sigma = uniform(sigma_range[0], sigma_range[1]) else: - sigma = [get_range_val(sigma_range) for _ in data_sample.shape[1:]] + sigma = [uniform(sigma_range[0], sigma_range[1]) for _ in data_sample.shape[1:]] else: sigma = None for c in range(data_sample.shape[0]): if np.random.uniform() <= p_per_channel: if per_channel: if not different_sigma_per_axis or np.random.uniform() < p_isotropic: - sigma = get_range_val(sigma_range) + sigma = uniform(sigma_range[0], sigma_range[1]) else: - sigma = [get_range_val(sigma_range) for _ in data_sample.shape[1:]] + sigma = [uniform(sigma_range[0], sigma_range[1]) for _ in data_sample.shape[1:]] data_sample[c] = gaussian_filter(data_sample[c], sigma, order=0) return data_sample @@ -86,8 +86,8 @@ def augment_gaussian_blur(data_sample: np.ndarray, sigma_range: Tuple[float, flo def augment_blank_square_noise(data_sample, square_size, n_squares, noise_val=(0, 0), channel_wise_n_val=False, square_pos=None): # rnd_n_val = get_range_val(noise_val) - rnd_square_size = get_range_val(square_size) - rnd_n_squares = get_range_val(n_squares) + rnd_square_size = uniform(square_size[0], square_size[1]) + rnd_n_squares = uniform(n_squares[0], n_squares[1]) data_sample = mask_random_squares(data_sample, square_size=rnd_square_size, n_squares=rnd_n_squares, n_val=noise_val, channel_wise_n_val=channel_wise_n_val, diff --git a/batchgenerators/augmentations/utils.py b/batchgenerators/augmentations/utils.py index 5660aac..cac82c9 100755 --- a/batchgenerators/augmentations/utils.py +++ b/batchgenerators/augmentations/utils.py @@ -674,7 +674,7 @@ def uniform(low, high, size=None): if size is None: return low else: - return np.ones(size) * low + return np.full(size, low) else: return np.random.uniform(low, high, size) @@ -740,7 +740,7 @@ def pad_nd_image(image, new_shape=None, mode="constant", kwargs=None, return_sli else: pad_list = np.array(pad_list) pad_list[:, 1] = np.array(res.shape) - pad_list[:, 1] - slicer = list(slice(*i) for i in pad_list) + slicer = [slice(*i) for i in pad_list] return res, slicer @@ -761,23 +761,23 @@ def mask_random_square(img, square_size, n_val, channel_wise_n_val=False, square h_start = pos_wh[1] if img.ndim == 2: - rnd_n_val = get_range_val(n_val) + rnd_n_val = uniform(n_val[0], n_val[1]) img[h_start:(h_start + square_size), w_start:(w_start + square_size)] = rnd_n_val elif 
img.ndim == 3: if channel_wise_n_val: for i in range(img.shape[0]): - rnd_n_val = get_range_val(n_val) + rnd_n_val = uniform(n_val[0], n_val[1]) img[i, h_start:(h_start + square_size), w_start:(w_start + square_size)] = rnd_n_val else: - rnd_n_val = get_range_val(n_val) + rnd_n_val = uniform(n_val[0], n_val[1]) img[:, h_start:(h_start + square_size), w_start:(w_start + square_size)] = rnd_n_val elif img.ndim == 4: if channel_wise_n_val: for i in range(img.shape[0]): - rnd_n_val = get_range_val(n_val) + rnd_n_val = uniform(n_val[0], n_val[1]) img[:, i, h_start:(h_start + square_size), w_start:(w_start + square_size)] = rnd_n_val else: - rnd_n_val = get_range_val(n_val) + rnd_n_val = uniform(n_val[0], n_val[1]) img[:, :, h_start:(h_start + square_size), w_start:(w_start + square_size)] = rnd_n_val return img From 5f890e84269630eb3bae2798ee41c1b365a207f9 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Wed, 6 Sep 2023 12:22:48 +0300 Subject: [PATCH 45/60] Misc --- batchgenerators/augmentations/crop_and_pad_augmentations.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/batchgenerators/augmentations/crop_and_pad_augmentations.py b/batchgenerators/augmentations/crop_and_pad_augmentations.py index e416359..9bf8e82 100644 --- a/batchgenerators/augmentations/crop_and_pad_augmentations.py +++ b/batchgenerators/augmentations/crop_and_pad_augmentations.py @@ -94,7 +94,6 @@ def crop(data: Union[Sequence[np.ndarray], np.ndarray], seg: Union[Sequence[np.n else: seg_return = None - for b in range(data_shape[0]): data_first_dim = data[b].shape[0] data_shape_here = np.array(data[b].shape[1:]) @@ -119,11 +118,11 @@ def crop(data: Union[Sequence[np.ndarray], np.ndarray], seg: Union[Sequence[np.n ubs = np.minimum(data_shape_here, lbs_plus_crop_size) lbs = np.maximum(zero, lbs) - slicer_data = (slice(0, data_first_dim), *(slice(lbs[d], ubs[d]) for d in range(dim))) + slicer_data = (slice(0, data_first_dim), *[slice(lbs[d], ubs[d]) for d in range(dim)]) data_cropped = data[b][slicer_data] if seg_return is not None: - slicer_data = (slice(0, seg_first_dim), *slicer_data[1:]) + slicer_data = (slice(0, seg_first_dim),) + slicer_data[1:] seg_cropped = seg[b][slicer_data] if np.any(need_to_pad): From ffb582466143f2a946e9bd6360b1fa3c9c06eb9a Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Wed, 6 Sep 2023 12:41:48 +0300 Subject: [PATCH 46/60] Faster mirroring --- .../augmentations/spatial_transformations.py | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/batchgenerators/augmentations/spatial_transformations.py b/batchgenerators/augmentations/spatial_transformations.py index bd3f7d6..ae8d749 100644 --- a/batchgenerators/augmentations/spatial_transformations.py +++ b/batchgenerators/augmentations/spatial_transformations.py @@ -103,9 +103,8 @@ def augment_zoom(sample_data, sample_seg, zoom_factors, order=3, order_seg=1): sample_data = resize_multichannel_image(sample_data, target_shape_here, order) if sample_seg is not None: - target_seg = np.ones([sample_seg.shape[0]] + target_shape_here) - for c in range(sample_seg.shape[0]): - target_seg[c] = resize_segmentation(sample_seg[c], target_shape_here, order_seg) + target_seg = np.array([ + resize_segmentation(sample_seg[c], target_shape_here, order_seg) for c in range(sample_seg.shape[0])]) else: target_seg = None @@ -117,19 +116,23 @@ def augment_mirroring_batched(sample_data, sample_seg=None, axes=(0, 1, 2)): "Invalid dimension for sample_data and sample_seg. 
sample_data and sample_seg should be either " \ "[batch, channels, x, y] or [batch, channels, x, y, z]" size = len(sample_data) - workon = np.expand_dims(sample_data, 0) if sample_seg is None else np.stack((sample_data, sample_seg)) + has_sample_seg = sample_seg is not None if 0 in axes: mask = np.random.uniform(size=size) < 0.5 - workon[:, mask] = np.flip(workon[:, mask], 3) + sample_data[mask] = np.flip(sample_data[mask], 2) + if has_sample_seg: + sample_seg[mask] = np.flip(sample_seg[mask], 2) if 1 in axes: mask = np.random.uniform(size=size) < 0.5 - workon[:, mask] = np.flip(workon[:, mask], 4) - if 2 in axes and len(workon.shape) == 6: + sample_data[mask] = np.flip(sample_data[mask], 3) + if has_sample_seg: + sample_seg[mask] = np.flip(sample_seg[mask], 3) + if 2 in axes and len(sample_data.shape) == 5: mask = np.random.uniform(size=size) < 0.5 - workon[:, mask] = np.flip(workon[:, mask], 5) - if sample_seg is None: - return workon[0], None - return workon + sample_data[mask] = np.flip(sample_data[mask], 4) + if has_sample_seg: + sample_seg[mask] = np.flip(sample_seg[mask], 4) + return sample_data, sample_seg def augment_mirroring(sample_data, sample_seg=None, axes=(0, 1, 2)): From 64387297f824782c1edd5897716bf1f4286548c6 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Wed, 6 Sep 2023 14:48:10 +0300 Subject: [PATCH 47/60] Replaced len(ndarray.shape) to ndarray.ndim --- .../augmentations/color_augmentations.py | 10 ++-- .../augmentations/normalizations.py | 14 +++--- .../augmentations/spatial_transformations.py | 20 ++++---- batchgenerators/augmentations/utils.py | 50 ++++++++----------- .../transforms/crop_and_pad_transforms.py | 4 +- .../transforms/local_transforms.py | 5 +- .../transforms/noise_transforms.py | 4 +- .../transforms/spatial_transforms.py | 8 +-- .../transforms/utility_transforms.py | 6 +-- 9 files changed, 57 insertions(+), 64 deletions(-) diff --git a/batchgenerators/augmentations/color_augmentations.py b/batchgenerators/augmentations/color_augmentations.py index 5ba4487..76c7d7d 100644 --- a/batchgenerators/augmentations/color_augmentations.py +++ b/batchgenerators/augmentations/color_augmentations.py @@ -60,8 +60,8 @@ def augment_contrast(data_sample: np.ndarray, mask = np.random.uniform(size=data_sample.shape[:2] if batched else data_sample.shape[0]) < p_per_channel if np.any(mask): workon = data_sample[mask] - factor = get_augment_contrast_factor(contrast_range, per_channel, len(workon), len(workon.shape)) - axes = tuple(range(1, len(workon.shape))) + factor = get_augment_contrast_factor(contrast_range, per_channel, len(workon), workon.ndim) + axes = tuple(range(1, workon.ndim)) mean = workon.mean(axis=axes, keepdims=True) if preserve_range: minm = workon.min(axis=axes, keepdims=True) @@ -92,7 +92,7 @@ def augment_brightness_additive(data_sample, mu: float, sigma: float, per_channe else: rnd_nb = np.repeat(np.random.normal(mu, sigma), size) rnd_nb[np.random.uniform(size=size) > p_per_channel] = 0.0 - data_sample += reverse_broadcast(rnd_nb, get_broadcast_axes(len(data_sample.shape))) + data_sample += reverse_broadcast(rnd_nb, get_broadcast_axes(data_sample.ndim)) return data_sample @@ -149,7 +149,7 @@ def augment_gamma(data_sample, gamma_range=(0.5, 2), invert_image=False, epsilon gamma.append(np.random.uniform(gamma_l, gamma_range[1])) gamma = np.array(gamma) - axes = tuple(range(1, len(data_sample.shape))) + axes = tuple(range(1, data_sample.ndim)) retain_any_stats = np.any(retain_stats_here) if retain_any_stats: @@ -159,7 +159,7 @@ def 
augment_gamma(data_sample, gamma_range=(0.5, 2), invert_image=False, epsilon minm = data_sample.min(axis=axes, keepdims=True) rnge = data_sample.max(axis=axes, keepdims=True) - minm + epsilon - broadcast_axes = get_broadcast_axes(len(data_sample.shape)) + broadcast_axes = get_broadcast_axes(data_sample.ndim) gamma = reverse_broadcast(gamma, broadcast_axes) # TODO: Remove data_sample = np.power((data_sample - minm) / rnge, gamma) * rnge + minm diff --git a/batchgenerators/augmentations/normalizations.py b/batchgenerators/augmentations/normalizations.py index d6d813e..2f56a99 100644 --- a/batchgenerators/augmentations/normalizations.py +++ b/batchgenerators/augmentations/normalizations.py @@ -18,9 +18,9 @@ def range_normalization(data, rnge=(0, 1), per_channel=True, eps=1e-8): if per_channel: - axes = tuple(range(2, len(data.shape))) + axes = tuple(range(2, data.ndim)) else: - axes = tuple(range(1, len(data.shape))) + axes = tuple(range(1, data.ndim)) data_normalized = min_max_normalization_batched(data, eps, axes) data_normalized *= (rnge[1] - rnge[0]) @@ -47,9 +47,9 @@ def min_max_normalization(data, eps): def zero_mean_unit_variance_normalization(data, per_channel=True, epsilon=1e-8): if per_channel: - axes = tuple(range(2, len(data.shape))) + axes = tuple(range(2, data.ndim)) else: - axes = tuple(range(1, len(data.shape))) + axes = tuple(range(1, data.ndim)) mean = np.mean(data, axis=axes, keepdims=True) std = np.std(data, axis=axes, keepdims=True) + epsilon @@ -67,7 +67,7 @@ def mean_std_normalization(data, mean, std, per_channel=True): assert len(mean) == channel_dimension assert len(std) == channel_dimension - broadcast_axes = tuple(range(2, len(data.shape))) + broadcast_axes = tuple(range(2, data.ndim)) mean = np.expand_dims(np.broadcast_to(mean, (len(data), len(mean))), broadcast_axes) std = np.expand_dims(np.broadcast_to(std, (len(data), len(std))), broadcast_axes) @@ -77,9 +77,9 @@ def mean_std_normalization(data, mean, std, per_channel=True): def cut_off_outliers(data, percentile_lower=0.2, percentile_upper=99.8, per_channel=False): if per_channel: - axes = tuple(range(2, len(data.shape))) + axes = tuple(range(2, data.ndim)) else: - axes = tuple(range(1, len(data.shape))) + axes = tuple(range(1, data.ndim)) cut_off_lower, cut_off_upper = np.percentile(data, (percentile_lower, percentile_upper), axis=axes, keepdims=True) np.clip(data, cut_off_lower, cut_off_upper, out=data) diff --git a/batchgenerators/augmentations/spatial_transformations.py b/batchgenerators/augmentations/spatial_transformations.py index ae8d749..4640a8d 100644 --- a/batchgenerators/augmentations/spatial_transformations.py +++ b/batchgenerators/augmentations/spatial_transformations.py @@ -56,7 +56,7 @@ def augment_resize(sample_data, sample_seg, target_size, order=3, order_seg=1): np.ndarray (just like data). 
Must also be (c, x, y(, z)) :return: """ - dimensionality = len(sample_data.shape) - 1 + dimensionality = sample_data.ndim - 1 if not isinstance(target_size, (list, tuple)): target_size_here = [target_size] * dimensionality else: @@ -90,7 +90,7 @@ def augment_zoom(sample_data, sample_seg, zoom_factors, order=3, order_seg=1): :return: """ - dimensionality = len(sample_data.shape) - 1 + dimensionality = sample_data.ndim - 1 shape = np.array(sample_data.shape[1:]) if not isinstance(zoom_factors, (list, tuple)): zoom_factors_here = np.array([zoom_factors] * dimensionality) @@ -112,7 +112,7 @@ def augment_zoom(sample_data, sample_seg, zoom_factors, order=3, order_seg=1): def augment_mirroring_batched(sample_data, sample_seg=None, axes=(0, 1, 2)): - assert len(sample_data.shape) == 5 or len(sample_data.shape) == 4, \ + assert sample_data.ndim == 5 or sample_data.ndim == 4, \ "Invalid dimension for sample_data and sample_seg. sample_data and sample_seg should be either " \ "[batch, channels, x, y] or [batch, channels, x, y, z]" size = len(sample_data) @@ -127,7 +127,7 @@ def augment_mirroring_batched(sample_data, sample_seg=None, axes=(0, 1, 2)): sample_data[mask] = np.flip(sample_data[mask], 3) if has_sample_seg: sample_seg[mask] = np.flip(sample_seg[mask], 3) - if 2 in axes and len(sample_data.shape) == 5: + if 2 in axes and sample_data.ndim == 5: mask = np.random.uniform(size=size) < 0.5 sample_data[mask] = np.flip(sample_data[mask], 4) if has_sample_seg: @@ -148,7 +148,7 @@ def augment_mirroring(sample_data, sample_seg=None, axes=(0, 1, 2)): sample_data[:, :, :] = sample_data[:, :, ::-1] if sample_seg is not None: sample_seg[:, :, :] = sample_seg[:, :, ::-1] - if 2 in axes and len(sample_data.shape) == 4: + if 2 in axes and sample_data.ndim == 4: if np.random.uniform() < 0.5: sample_data[:, :, :, :] = sample_data[:, :, :, ::-1] if sample_seg is not None: @@ -287,7 +287,7 @@ def augment_spatial(data, seg, patch_size, patch_center_dist_from_border=30, for d in range(dim): coords[d] += ctr[d] # vectorized version, seems a bit slower - # coords += reverse_broadcast(ctr, get_broadcast_axes(len(coords.shape))) + # coords += reverse_broadcast(ctr, get_broadcast_axes(coords.ndim)) for channel_id in range(data.shape[1]): data_result[sample_id, channel_id] = interpolate_img(data[sample_id, channel_id], coords, order_data, @@ -381,7 +381,7 @@ def augment_spatial_2(data, seg, patch_size, patch_center_dist_from_border=30, # one scale per case, scale is in percent of patch_size def_scale = np.random.uniform(deformation_scale[0], deformation_scale[1]) - for d in range(len(data[sample_id].shape) - 1): + for d in range(data[sample_id].ndim - 1): # transform relative def_scale in pixels sigmas.append(def_scale * patch_size[d]) @@ -446,7 +446,7 @@ def augment_spatial_2(data, seg, patch_size, patch_center_dist_from_border=30, # now find a nice center location if modified_coords: # recenter coordinates - coords_mean = coords.mean(axis=tuple(range(1, len(coords.shape))), keepdims=True) + coords_mean = coords.mean(axis=tuple(range(1, coords.ndim)), keepdims=True) coords -= coords_mean for d in range(dim): @@ -490,8 +490,8 @@ def augment_transpose_axes(data_sample, seg_sample, axes=(0, 1, 2)): """ axes = list(np.array(axes) + 1) # need list to allow shuffle; +1 to accomodate for color channel - assert np.max(axes) <= len(data_sample.shape), "axes must only contain valid axis ids" - static_axes = list(range(len(data_sample.shape))) + assert np.max(axes) <= data_sample.ndim, "axes must only contain valid axis ids" + 
static_axes = list(range(data_sample.ndim)) for i in axes: static_axes[i] = -1 np.random.shuffle(axes) diff --git a/batchgenerators/augmentations/utils.py b/batchgenerators/augmentations/utils.py index cac82c9..2b501b2 100755 --- a/batchgenerators/augmentations/utils.py +++ b/batchgenerators/augmentations/utils.py @@ -148,7 +148,7 @@ def rotate_coords_2d(coords, angle): def scale_coords(coords: np.ndarray, scale): if isinstance(scale, (tuple, list, np.ndarray)): - scale = reverse_broadcast(scale, get_broadcast_axes(len(coords.shape))) + scale = reverse_broadcast(scale, get_broadcast_axes(coords.ndim)) coords *= scale return coords @@ -189,11 +189,10 @@ def find_entries_in_array(entries, myarray): def center_crop_3D_image(img, crop_size): center = np.array(img.shape) / 2. if type(crop_size) not in (tuple, list): - center_crop = [int(crop_size)] * len(img.shape) + center_crop = [int(crop_size)] * img.ndim else: center_crop = crop_size - assert len(center_crop) == len( - img.shape), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (3d)" + assert len(center_crop) == img.ndim, "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (3d)" return img[int(center[0] - center_crop[0] / 2.):int(center[0] + center_crop[0] / 2.), int(center[1] - center_crop[1] / 2.):int(center[1] + center_crop[1] / 2.), int(center[2] - center_crop[2] / 2.):int(center[2] + center_crop[2] / 2.)] @@ -203,11 +202,10 @@ def center_crop_3D_image_batched(img, crop_size): # dim 0 is batch, dim 1 is channel, dim 2, 3 and 4 are x y z center = np.array(img.shape[2:]) / 2. if type(crop_size) not in (tuple, list): - center_crop = [int(crop_size)] * (len(img.shape) - 2) + center_crop = [int(crop_size)] * (img.ndim - 2) else: center_crop = crop_size - assert len(center_crop) == (len( - img.shape) - 2), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (3d)" + assert len(center_crop) == (img.ndim - 2), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (3d)" return img[:, :, int(center[0] - center_crop[0] / 2.):int(center[0] + center_crop[0] / 2.), int(center[1] - center_crop[1] / 2.):int(center[1] + center_crop[1] / 2.), int(center[2] - center_crop[2] / 2.):int(center[2] + center_crop[2] / 2.)] @@ -216,11 +214,10 @@ def center_crop_3D_image_batched(img, crop_size): def center_crop_2D_image(img, crop_size): center = np.array(img.shape) / 2. if type(crop_size) not in (tuple, list): - center_crop = [int(crop_size)] * len(img.shape) + center_crop = [int(crop_size)] * img.ndim else: center_crop = crop_size - assert len(center_crop) == len( - img.shape), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (2d)" + assert len(center_crop) == img.ndim, "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (2d)" return img[int(center[0] - center_crop[0] / 2.):int(center[0] + center_crop[0] / 2.), int(center[1] - center_crop[1] / 2.):int(center[1] + center_crop[1] / 2.)] @@ -229,21 +226,19 @@ def center_crop_2D_image_batched(img, crop_size): # dim 0 is batch, dim 1 is channel, dim 2 and 3 are x y center = np.array(img.shape[2:]) / 2. 
if type(crop_size) not in (tuple, list): - center_crop = [int(crop_size)] * (len(img.shape) - 2) + center_crop = [int(crop_size)] * (img.ndim - 2) else: center_crop = crop_size - assert len(center_crop) == (len( - img.shape) - 2), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (2d)" + assert len(center_crop) == (img.ndim - 2), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (2d)" return img[:, :, int(center[0] - center_crop[0] / 2.):int(center[0] + center_crop[0] / 2.), int(center[1] - center_crop[1] / 2.):int(center[1] + center_crop[1] / 2.)] def random_crop_3D_image(img, crop_size): if type(crop_size) not in (tuple, list): - crop_size = [crop_size] * len(img.shape) + crop_size = [crop_size] * img.ndim else: - assert len(crop_size) == len( - img.shape), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (3d)" + assert len(crop_size) == img.ndim, "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (3d)" if crop_size[0] < img.shape[0]: lb_x = np.random.randint(0, img.shape[0] - crop_size[0]) @@ -271,10 +266,9 @@ def random_crop_3D_image(img, crop_size): def random_crop_3D_image_batched(img, crop_size): if type(crop_size) not in (tuple, list): - crop_size = [crop_size] * (len(img.shape) - 2) + crop_size = [crop_size] * (img.ndim - 2) else: - assert len(crop_size) == (len( - img.shape) - 2), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (3d)" + assert len(crop_size) == (img.ndim - 2), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (3d)" if crop_size[0] < img.shape[2]: lb_x = np.random.randint(0, img.shape[2] - crop_size[0]) @@ -302,10 +296,9 @@ def random_crop_3D_image_batched(img, crop_size): def random_crop_2D_image(img, crop_size): if type(crop_size) not in (tuple, list): - crop_size = [crop_size] * len(img.shape) + crop_size = [crop_size] * img.ndim else: - assert len(crop_size) == len( - img.shape), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (2d)" + assert len(crop_size) == img.ndim, "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (2d)" if crop_size[0] < img.shape[0]: lb_x = np.random.randint(0, img.shape[0] - crop_size[0]) @@ -326,10 +319,9 @@ def random_crop_2D_image(img, crop_size): def random_crop_2D_image_batched(img, crop_size): if type(crop_size) not in (tuple, list): - crop_size = [crop_size] * (len(img.shape) - 2) + crop_size = [crop_size] * (img.ndim - 2) else: - assert len(crop_size) == (len( - img.shape) - 2), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (2d)" + assert len(crop_size) == (img.ndim - 2), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (2d)" if crop_size[0] < img.shape[2]: lb_x = np.random.randint(0, img.shape[2] - crop_size[0]) @@ -590,9 +582,9 @@ def convert_seg_to_bounding_box_coordinates(data_dict, dim, get_rois_from_seg_fl def transpose_channels(batch): - if len(batch.shape) == 4: + if batch.ndim == 4: return np.transpose(batch, axes=[0, 2, 3, 1]) - elif len(batch.shape) == 5: + elif batch.ndim == 5: return np.transpose(batch, axes=[0, 4, 2, 3, 1]) else: raise ValueError("wrong dimensions in transpose_channel generator!") @@ -609,7 +601,7 @@ def resize_segmentation(segmentation, 
new_shape, order=3): :return: ''' tpe = segmentation.dtype - assert len(segmentation.shape) == len(new_shape), "new shape must have same dimensionality as segmentation" + assert segmentation.ndim == len(new_shape), "new shape must have same dimensionality as segmentation" if order == 0: return resize(segmentation.astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False).astype( tpe) @@ -711,7 +703,7 @@ def pad_nd_image(image, new_shape=None, mode="constant", kwargs=None, return_sli new_shape = image.shape[-len(shape_must_be_divisible_by):] old_shape = new_shape - num_axes_nopad = len(image.shape) - len(new_shape) + num_axes_nopad = image.ndim - len(new_shape) new_shape = np.maximum(new_shape, old_shape) if shape_must_be_divisible_by is not None: diff --git a/batchgenerators/transforms/crop_and_pad_transforms.py b/batchgenerators/transforms/crop_and_pad_transforms.py index 9735774..ae1836b 100644 --- a/batchgenerators/transforms/crop_and_pad_transforms.py +++ b/batchgenerators/transforms/crop_and_pad_transforms.py @@ -128,7 +128,7 @@ def __call__(self, **data_dict): data = data_dict.get(self.data_key) seg = data_dict.get(self.label_key) - assert len(self.new_size) + 2 == len(data.shape), "new size must be a tuple/list/np.ndarray with shape " \ + assert len(self.new_size) + 2 == data.ndim, "new size must be a tuple/list/np.ndarray with shape " \ "(x, y(, z))" data, seg = pad_nd_image_and_seg(data, seg, self.new_size, None, np_pad_kwargs_data=self.np_pad_kwargs_data, @@ -180,7 +180,7 @@ def __call__(self, **data_dict): for c in range(workon.shape[1]): if np.random.uniform(0, 1) < self.p_per_channel: shift_here = [] - for d in range(len(workon.shape) - 2): + for d in range(workon.ndim - 2): shift_here.append(int(np.round(np.random.normal( self.shift_mu[d] if isinstance(self.shift_mu, (list, tuple)) else self.shift_mu, self.shift_sigma[d] if isinstance(self.shift_sigma, diff --git a/batchgenerators/transforms/local_transforms.py b/batchgenerators/transforms/local_transforms.py index 42363aa..4873b7e 100644 --- a/batchgenerators/transforms/local_transforms.py +++ b/batchgenerators/transforms/local_transforms.py @@ -255,14 +255,15 @@ def __call__(self, **data_dict): def _apply_gamma_gradient(self, img: np.ndarray, kernel: np.ndarray) -> np.ndarray: # store keep original image range mn, mx = img.min(), img.max() + rng = mx - mn # rescale tp [0, 1] - img = (img - mn) / (max(mx - mn, 1e-8)) + img = (img - mn) / (max(rng, 1e-8)) gamma = sample_scalar(self.gamma) img_modified = np.power(img, gamma) - return self.run_interpolation(img, img_modified, kernel) * (mx - mn) + mn + return self.run_interpolation(img, img_modified, kernel) * rng + mn class LocalSmoothingTransform(LocalTransform): diff --git a/batchgenerators/transforms/noise_transforms.py b/batchgenerators/transforms/noise_transforms.py index a214a78..e0f2414 100644 --- a/batchgenerators/transforms/noise_transforms.py +++ b/batchgenerators/transforms/noise_transforms.py @@ -314,7 +314,7 @@ def __call__(self, **data_dict): mn, mx = data[b].min(), data[b].max() strength_here = self.strength if isinstance(self.strength, float) else np.random.uniform( *self.strength) - if len(data.shape) == 4: + if data.ndim == 4: filter_here = self.filter_2d * strength_here filter_here[1, 1] += 1 else: @@ -333,7 +333,7 @@ def __call__(self, **data_dict): mn, mx = data[b, c].min(), data[b, c].max() strength_here = self.strength if isinstance(self.strength, float) else np.random.uniform( *self.strength) - if len(data.shape) == 4: + if data.ndim 
== 4:
                         filter_here = self.filter_2d * strength_here
                         filter_here[1, 1] += 1
                     else:
diff --git a/batchgenerators/transforms/spatial_transforms.py b/batchgenerators/transforms/spatial_transforms.py
index bc5c0e6..7adc19a 100644
--- a/batchgenerators/transforms/spatial_transforms.py
+++ b/batchgenerators/transforms/spatial_transforms.py
@@ -329,9 +329,9 @@ def __call__(self, **data_dict):
         seg = data_dict.get(self.label_key)
 
         if self.patch_size is None:
-            if len(data.shape) == 4:
+            if data.ndim == 4:
                 patch_size = (data.shape[2], data.shape[3])
-            elif len(data.shape) == 5:
+            elif data.ndim == 5:
                 patch_size = (data.shape[2], data.shape[3], data.shape[4])
             else:
                 raise ValueError("only support 2D/3D batch data.")
@@ -444,9 +444,9 @@ def __call__(self, **data_dict):
         seg = data_dict.get(self.label_key)
 
         if self.patch_size is None:
-            if len(data.shape) == 4:
+            if data.ndim == 4:
                 patch_size = (data.shape[2], data.shape[3])
-            elif len(data.shape) == 5:
+            elif data.ndim == 5:
                 patch_size = (data.shape[2], data.shape[3], data.shape[4])
             else:
                 raise ValueError("only support 2D/3D batch data.")
diff --git a/batchgenerators/transforms/utility_transforms.py b/batchgenerators/transforms/utility_transforms.py
index 6bb807a..a4acc1f 100644
--- a/batchgenerators/transforms/utility_transforms.py
+++ b/batchgenerators/transforms/utility_transforms.py
@@ -455,13 +455,13 @@ def __call__(self, **data_dict):
             if data is None:
                 print("WARNING in ConvertToChannelLastTransform: data_dict has no key named", k)
             else:
-                if len(data.shape) == 4:
+                if data.ndim == 4:
                     new_ordering = (0, 2, 3, 1)
-                elif len(data.shape) == 5:
+                elif data.ndim == 5:
                     new_ordering = (0, 2, 3, 4, 1)
                 else:
                     raise RuntimeError("unsupported dimensionality for ConvertToChannelLastTransform:",
-                                       len(data.shape),
+                                       data.ndim,
                                        ". Only 2d (b, c, x, y) and 3d (b, c, x, y, z) are supported for now.")
                 assert isinstance(data, np.ndarray), "data_dict[k] must be a numpy array"
                 data = data.transpose(new_ordering)

From 8ffb23a7ae2ff3acf474eeee38e9f34b082347b5 Mon Sep 17 00:00:00 2001
From: ancestor-mithril
Date: Mon, 11 Sep 2023 11:49:52 +0300
Subject: [PATCH 48/60] Removed callable feature from augment contrast and augment gamma; optimized augment brightness multiplicative

---
 .../augmentations/color_augmentations.py      | 72 ++++++++-----------
 .../augmentations/spatial_transformations.py  |  2 +-
 batchgenerators/augmentations/utils.py        |  2 +-
 .../transforms/color_transforms.py            |  8 +--
 .../transforms/spatial_transforms.py          |  4 +-
 5 files changed, 36 insertions(+), 52 deletions(-)

diff --git a/batchgenerators/augmentations/color_augmentations.py b/batchgenerators/augmentations/color_augmentations.py
index 76c7d7d..ddee8b3 100644
--- a/batchgenerators/augmentations/color_augmentations.py
+++ b/batchgenerators/augmentations/color_augmentations.py
@@ -21,38 +21,31 @@
     reverse_broadcast
 
 
-def get_augment_contrast_factor(contrast_range: Union[Tuple[float, float], Callable[[], float]],
+def get_augment_contrast_factor(contrast_range: Tuple[float, float],
                                 per_channel: bool,
                                 size: int,
                                 broadcast_size: int):
-    # TODO: callable contrast_range is not used. Remove this feature.
if per_channel: - if callable(contrast_range): - factor = [contrast_range() for _ in range(size)] - else: - factor = [] - contrast_l = max(contrast_range[0], 1) - for _ in range(size): - if contrast_range[0] < 1 and np.random.random() < 0.5: - factor.append(np.random.uniform(contrast_range[0], 1)) - else: - factor.append(np.random.uniform(contrast_l, contrast_range[1])) + factor = [] + contrast_l = max(contrast_range[0], 1) + for _ in range(size): + if contrast_range[0] < 1 and np.random.random() < 0.5: + factor.append(np.random.uniform(contrast_range[0], 1)) + else: + factor.append(np.random.uniform(contrast_l, contrast_range[1])) factor = reverse_broadcast(np.array(factor), get_broadcast_axes(broadcast_size)) else: - if callable(contrast_range): - factor = contrast_range() + if contrast_range[0] < 1 and np.random.random() < 0.5: + factor = np.random.uniform(contrast_range[0], 1) else: - if contrast_range[0] < 1 and np.random.random() < 0.5: - factor = np.random.uniform(contrast_range[0], 1) - else: - factor = np.random.uniform(max(contrast_range[0], 1), contrast_range[1]) + factor = np.random.uniform(max(contrast_range[0], 1), contrast_range[1]) return factor def augment_contrast(data_sample: np.ndarray, - contrast_range: Union[Tuple[float, float], Callable[[], float]] = (0.75, 1.25), + contrast_range: Tuple[float, float] = (0.75, 1.25), preserve_range: bool = True, per_channel: bool = True, p_per_channel: float = 1, @@ -96,30 +89,29 @@ def augment_brightness_additive(data_sample, mu: float, sigma: float, per_channe return data_sample -def setup_augment_brightness_multiplicative(per_channel: bool, batched: bool, shape: Tuple[int]): +def setup_augment_brightness_multiplicative(per_channel: bool, batched: bool, shape: Tuple[int, ...]): if per_channel: if batched: - return shape[:2], tuple(range(2, len(shape))) - return shape[0], tuple(range(1, len(shape))) + return shape[:2] + (1,) * (len(shape) - 2) + return (shape[0],) + (1,) * (len(shape) - 1) if batched: - return shape[0], tuple(range(1, len(shape))) - return 1, tuple(range(1, len(shape))) + return (shape[0],) + (1,) * (len(shape) - 1) + return (1,) * len(shape) def augment_brightness_multiplicative(data_sample, multiplier_range=(0.5, 2), per_channel=True, batched=False): - size, axes = setup_augment_brightness_multiplicative(per_channel, batched, data_sample.shape) - data_sample *= np.expand_dims(np.random.uniform(multiplier_range[0], multiplier_range[1], size=size), axes) + size = setup_augment_brightness_multiplicative(per_channel, batched, data_sample.shape) + data_sample *= np.random.uniform(multiplier_range[0], multiplier_range[1], size=size) return data_sample def augment_gamma(data_sample, gamma_range=(0.5, 2), invert_image=False, epsilon=1e-7, per_channel=False, - retain_stats: Union[bool, Callable[[], bool]] = False): + retain_stats: bool = False): if invert_image: data_sample = - data_sample if not per_channel: - retain_stats_here = retain_stats() if callable(retain_stats) else retain_stats - if retain_stats_here: + if retain_stats: mn = data_sample.mean() sd = data_sample.std() if gamma_range[0] < 1 and np.random.random() < 0.5: @@ -129,17 +121,12 @@ def augment_gamma(data_sample, gamma_range=(0.5, 2), invert_image=False, epsilon minm = data_sample.min() rnge = data_sample.max() - minm data_sample = np.power(((data_sample - minm) / float(rnge + epsilon)), gamma) * rnge + minm - if retain_stats_here: + if retain_stats: data_sample -= data_sample.mean() data_sample *= sd / (data_sample.std() + 1e-8) data_sample += mn else: 
shape_0 = data_sample.shape[0] - if callable(retain_stats): - retain_stats_here = [retain_stats() for _ in range(shape_0)] - else: - retain_stats_here = (retain_stats,) * shape_0 - retain_stats_here = np.array(retain_stats_here) gamma = [] gamma_l = max(gamma_range[0], 1) for i in range(shape_0): @@ -151,10 +138,9 @@ def augment_gamma(data_sample, gamma_range=(0.5, 2), invert_image=False, epsilon axes = tuple(range(1, data_sample.ndim)) - retain_any_stats = np.any(retain_stats_here) - if retain_any_stats: - mn = data_sample[retain_stats_here].mean(axis=axes, keepdims=True) - sd = data_sample[retain_stats_here].mean(axis=axes, keepdims=True) + if retain_stats: + mn = data_sample.mean(axis=axes, keepdims=True) + sd = data_sample.mean(axis=axes, keepdims=True) minm = data_sample.min(axis=axes, keepdims=True) rnge = data_sample.max(axis=axes, keepdims=True) - minm + epsilon @@ -163,10 +149,10 @@ def augment_gamma(data_sample, gamma_range=(0.5, 2), invert_image=False, epsilon gamma = reverse_broadcast(gamma, broadcast_axes) # TODO: Remove data_sample = np.power((data_sample - minm) / rnge, gamma) * rnge + minm - if retain_any_stats: - data_sample[retain_stats_here] -= data_sample[retain_stats_here].mean(axis=axes, keepdims=True) - data_sample[retain_stats_here] *= sd / (data_sample[retain_stats_here].std(axis=axes, keepdims=True) + 1e-8) - data_sample[retain_stats_here] += mn + if retain_stats: + data_sample -= data_sample.mean(axis=axes, keepdims=True) + data_sample *= sd / (data_sample.std(axis=axes, keepdims=True) + 1e-8) + data_sample += mn if invert_image: data_sample = - data_sample diff --git a/batchgenerators/augmentations/spatial_transformations.py b/batchgenerators/augmentations/spatial_transformations.py index 4640a8d..62a9297 100644 --- a/batchgenerators/augmentations/spatial_transformations.py +++ b/batchgenerators/augmentations/spatial_transformations.py @@ -303,7 +303,7 @@ def augment_spatial(data, seg, patch_size, patch_center_dist_from_border=30, else: s = seg[sample_id:sample_id + 1] if random_crop: - margin = patch_center_dist_from_border - np.array(patch_size) // 2 + margin = patch_center_dist_from_border - np.asarray(patch_size) // 2 d, s = random_crop_aug(data[sample_id:sample_id + 1], s, patch_size, margin) else: d, s = center_crop_aug(data[sample_id:sample_id + 1], patch_size, s) diff --git a/batchgenerators/augmentations/utils.py b/batchgenerators/augmentations/utils.py index 2b501b2..dae0d83 100755 --- a/batchgenerators/augmentations/utils.py +++ b/batchgenerators/augmentations/utils.py @@ -607,7 +607,7 @@ def resize_segmentation(segmentation, new_shape, order=3): tpe) else: unique_labels = pd.unique(segmentation.reshape(-1)) # does not need sorting - reshaped = np.zeros(new_shape, dtype=segmentation.dtype) + reshaped = np.zeros(new_shape, dtype=tpe) for c in unique_labels: mask = segmentation == c diff --git a/batchgenerators/transforms/color_transforms.py b/batchgenerators/transforms/color_transforms.py index 93c480e..f81152f 100644 --- a/batchgenerators/transforms/color_transforms.py +++ b/batchgenerators/transforms/color_transforms.py @@ -24,7 +24,7 @@ class ContrastAugmentationTransform(AbstractTransform): def __init__(self, - contrast_range: Union[Tuple[float, float], Callable[[], float]] = (0.75, 1.25), + contrast_range: Tuple[float, float] = (0.75, 1.25), preserve_range: bool = True, per_channel: bool = True, data_key: str = "data", @@ -36,7 +36,6 @@ def __init__(self, (float, float): range from which to sample a random contrast that is applied to the data. 
If one value is smaller and one is larger than 1, half of the contrast modifiers will be >1 and the other half <1 (in the inverval that was specified) - callable : must be contrast_range() -> float :param preserve_range: if True then the intensity values after contrast augmentation will be cropped to min and max values of the data before augmentation. :param per_channel: whether to use the same contrast modifier for all color channels or a separate one for each @@ -134,7 +133,7 @@ def __call__(self, **data_dict): class GammaTransform(AbstractTransform): def __init__(self, gamma_range=(0.5, 2), invert_image=False, per_channel=False, data_key="data", - retain_stats: Union[bool, Callable[[], bool]] = False, p_per_sample=1): + retain_stats: bool = False, p_per_sample=1): """ Augments by changing 'gamma' of the image (same as gamma correction in photos or computer monitors @@ -146,8 +145,7 @@ def __init__(self, gamma_range=(0.5, 2), invert_image=False, per_channel=False, :param per_channel: :param data_key: :param retain_stats: Gamma transformation will alter the mean and std of the data in the patch. If retain_stats=True, - the data will be transformed to match the mean and standard deviation before gamma augmentation. retain_stats - can also be callable (signature retain_stats() -> bool) + the data will be transformed to match the mean and standard deviation before gamma augmentation. :param p_per_sample: """ self.p_per_sample = p_per_sample diff --git a/batchgenerators/transforms/spatial_transforms.py b/batchgenerators/transforms/spatial_transforms.py index 7adc19a..781fe3f 100644 --- a/batchgenerators/transforms/spatial_transforms.py +++ b/batchgenerators/transforms/spatial_transforms.py @@ -330,9 +330,9 @@ def __call__(self, **data_dict): if self.patch_size is None: if data.ndim == 4: - patch_size = (data.shape[2], data.shape[3]) + patch_size = data.shape[2:4] elif data.ndim == 5: - patch_size = (data.shape[2], data.shape[3], data.shape[4]) + patch_size = data.shape[2:5] else: raise ValueError("only support 2D/3D batch data.") else: From ca82730fd9cf3e3f9504b9a3474e0ebad48e75c3 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Tue, 12 Sep 2023 15:02:16 +0300 Subject: [PATCH 49/60] Using flynt for conversion to fstring --- .../dataloading/multi_threaded_augmenter.py | 4 ++-- .../dataloading/nondet_multi_threaded_augmenter.py | 2 +- batchgenerators/transforms/abstract_transforms.py | 4 ++-- batchgenerators/transforms/local_transforms.py | 10 +++++----- batchgenerators/transforms/utility_transforms.py | 7 +++---- 5 files changed, 13 insertions(+), 14 deletions(-) diff --git a/batchgenerators/dataloading/multi_threaded_augmenter.py b/batchgenerators/dataloading/multi_threaded_augmenter.py index 6006fbe..8fde7b9 100755 --- a/batchgenerators/dataloading/multi_threaded_augmenter.py +++ b/batchgenerators/dataloading/multi_threaded_augmenter.py @@ -61,7 +61,7 @@ def producer(queue, data_loader, transform, thread_id, seed, abort_event, wait_t abort_event.set() return except Exception as e: - print("Exception in background worker %d:\n" % thread_id, e) + print(f"Exception in background worker {thread_id}:\n", e) traceback.print_exc() abort_event.set() return @@ -216,7 +216,7 @@ def __next__(self): return item except KeyboardInterrupt: - logging.error("MultiThreadedGenerator: caught exception: {}".format(sys.exc_info())) + logging.error(f"MultiThreadedGenerator: caught exception: {sys.exc_info()}") self.abort_event.set() self._finish() raise KeyboardInterrupt diff --git 
a/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py b/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py index c2b0317..25c6e12 100755 --- a/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py +++ b/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py @@ -68,7 +68,7 @@ def producer(queue: Queue, data_loader, transform, thread_id: int, seed, return except Exception as e: - print("Exception in background worker %d:\n" % thread_id, e) + print(f"Exception in background worker {thread_id}:\n", e) traceback.print_exc() abort_event.set() return diff --git a/batchgenerators/transforms/abstract_transforms.py b/batchgenerators/transforms/abstract_transforms.py index 2a535cc..69902f5 100644 --- a/batchgenerators/transforms/abstract_transforms.py +++ b/batchgenerators/transforms/abstract_transforms.py @@ -28,7 +28,7 @@ def __call__(self, **data_dict): def __repr__(self): ret_str = str(type(self).__name__) + "( " + ", ".join( - [key + " = " + repr(val) for key, val in self.__dict__.items()]) + " )" + [f"{key} = {repr(val)}" for key, val in self.__dict__.items()]) + " )" return ret_str @@ -89,4 +89,4 @@ def __call__(self, **data_dict): return data_dict def __repr__(self): - return str(type(self).__name__) + " ( " + repr(self.transforms) + " )" + return f"{str(type(self).__name__)} ( {repr(self.transforms)} )" diff --git a/batchgenerators/transforms/local_transforms.py b/batchgenerators/transforms/local_transforms.py index 4873b7e..9c69dff 100644 --- a/batchgenerators/transforms/local_transforms.py +++ b/batchgenerators/transforms/local_transforms.py @@ -150,7 +150,7 @@ def __init__(self, def __call__(self, **data_dict): data = data_dict.get(self.data_key) - assert data is not None, "Could not find data key '%s'" % self.data_key + assert data is not None, f"Could not find data key '{self.data_key}'" b, c, *img_shape = data.shape for bi in range(b): if np.random.uniform() < self.p_per_sample: @@ -235,7 +235,7 @@ def __init__(self, def __call__(self, **data_dict): data = data_dict.get(self.data_key) - assert data is not None, "Could not find data key '%s'" % self.data_key + assert data is not None, f"Could not find data key '{self.data_key}'" b, c, *img_shape = data.shape for bi in range(b): if np.random.uniform() < self.p_per_sample: @@ -302,7 +302,7 @@ def __init__(self, def __call__(self, **data_dict): data = data_dict.get(self.data_key) - assert data is not None, "Could not find data key '%s'" % self.data_key + assert data is not None, f"Could not find data key '{self.data_key}'" b, c, *img_shape = data.shape for bi in range(b): if np.random.uniform() < self.p_per_sample: @@ -324,7 +324,7 @@ def _apply_local_smoothing(self, img: np.ndarray, kernel: np.ndarray) -> np.ndar kernel = np.copy(kernel) smoothing = sample_scalar(self.smoothing_strength) - assert 0 <= smoothing <= 1, 'smoothing_strength must be between 0 and 1, is %f' % smoothing + assert 0 <= smoothing <= 1, f'smoothing_strength must be between 0 and 1, is {smoothing}' # prepare kernel by rescaling it to gamma_range # kernel is already [0, 1] @@ -354,7 +354,7 @@ def __init__(self, def __call__(self, **data_dict): data = data_dict.get(self.data_key) - assert data is not None, "Could not find data key '%s'" % self.data_key + assert data is not None, f"Could not find data key '{self.data_key}'" b, c, *img_shape = data.shape for bi in range(b): if np.random.uniform() < self.p_per_sample: diff --git a/batchgenerators/transforms/utility_transforms.py b/batchgenerators/transforms/utility_transforms.py 
index a4acc1f..59db137 100644 --- a/batchgenerators/transforms/utility_transforms.py +++ b/batchgenerators/transforms/utility_transforms.py @@ -219,8 +219,7 @@ def __call__(self, **data_dict): if not seg.shape[1] % self.output_channels == 0: from warnings import warn warn( - "Calling ConvertMultiSegToArgmaxTransform but number of input channels {} cannot be divided into {} output channels.".format( - seg.shape[1], self.output_channels)) + f"Calling ConvertMultiSegToArgmaxTransform but number of input channels {seg.shape[1]} cannot be divided into {self.output_channels} output channels.") n_labels = seg.shape[1] // self.output_channels target_size = list(seg.shape) target_size[1] = self.output_channels @@ -355,7 +354,7 @@ def __call__(self, **data_dict): return new_dict def __repr__(self): - return str(type(self).__name__) + " ( " + repr(self.transforms) + " )" + return f"{str(type(self).__name__)} ( {repr(self.transforms)} )" class ReshapeTransform(AbstractTransform): @@ -422,7 +421,7 @@ def __call__(self, **data_dict): inp = data_dict.get(self.input_key) outp = data_dict.get(self.output_key) - assert inp is not None, "input_key %s is not present in data_dict" % self.input_key + assert inp is not None, f"input_key {self.input_key} is not present in data_dict" selected_channels = inp[:, self.channel_indexes] From 3733e75623485e863881c280772764b5fe26e6c8 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Tue, 12 Sep 2023 15:08:09 +0300 Subject: [PATCH 50/60] Optimize imports --- batchgenerators/augmentations/color_augmentations.py | 4 ++-- batchgenerators/augmentations/crop_and_pad_augmentations.py | 1 - batchgenerators/augmentations/noise_augmentations.py | 1 - batchgenerators/augmentations/resample_augmentations.py | 1 - batchgenerators/augmentations/spatial_transformations.py | 2 -- batchgenerators/dataloading/data_loader.py | 1 - .../dataloading/nondet_multi_threaded_augmenter.py | 1 - batchgenerators/transforms/color_transforms.py | 2 +- batchgenerators/utilities/custom_types.py | 4 ++-- 9 files changed, 5 insertions(+), 12 deletions(-) diff --git a/batchgenerators/augmentations/color_augmentations.py b/batchgenerators/augmentations/color_augmentations.py index ddee8b3..0db8c87 100644 --- a/batchgenerators/augmentations/color_augmentations.py +++ b/batchgenerators/augmentations/color_augmentations.py @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from builtins import range -from typing import Tuple, Union, Callable +from typing import Tuple import numpy as np + from batchgenerators.augmentations.utils import general_cc_var_num_channels, illumination_jitter, get_broadcast_axes, \ reverse_broadcast diff --git a/batchgenerators/augmentations/crop_and_pad_augmentations.py b/batchgenerators/augmentations/crop_and_pad_augmentations.py index 9bf8e82..1ba4f0a 100644 --- a/batchgenerators/augmentations/crop_and_pad_augmentations.py +++ b/batchgenerators/augmentations/crop_and_pad_augmentations.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from builtins import range import numpy as np from batchgenerators.augmentations.utils import pad_nd_image from typing import Union, Sequence diff --git a/batchgenerators/augmentations/noise_augmentations.py b/batchgenerators/augmentations/noise_augmentations.py index 157aaac..add8b1f 100644 --- a/batchgenerators/augmentations/noise_augmentations.py +++ b/batchgenerators/augmentations/noise_augmentations.py @@ -18,7 +18,6 @@ import numpy as np from batchgenerators.augmentations.utils import mask_random_squares, uniform -from builtins import range from scipy.ndimage import gaussian_filter diff --git a/batchgenerators/augmentations/resample_augmentations.py b/batchgenerators/augmentations/resample_augmentations.py index ad7e530..fe28a46 100644 --- a/batchgenerators/augmentations/resample_augmentations.py +++ b/batchgenerators/augmentations/resample_augmentations.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from builtins import range import numpy as np from skimage.transform import resize from batchgenerators.augmentations.utils import uniform diff --git a/batchgenerators/augmentations/spatial_transformations.py b/batchgenerators/augmentations/spatial_transformations.py index 62a9297..37dad0f 100644 --- a/batchgenerators/augmentations/spatial_transformations.py +++ b/batchgenerators/augmentations/spatial_transformations.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from builtins import range - import numpy as np from batchgenerators.augmentations.utils import create_zero_centered_coordinate_mesh, elastic_deform_coordinates, \ interpolate_img, \ diff --git a/batchgenerators/dataloading/data_loader.py b/batchgenerators/dataloading/data_loader.py index fc9f84f..71dcbf0 100644 --- a/batchgenerators/dataloading/data_loader.py +++ b/batchgenerators/dataloading/data_loader.py @@ -14,7 +14,6 @@ # limitations under the License. from abc import ABCMeta, abstractmethod -from builtins import object import warnings from collections import OrderedDict from warnings import warn diff --git a/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py b/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py index 25c6e12..b85d2d6 100755 --- a/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py +++ b/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py @@ -17,7 +17,6 @@ from copy import deepcopy from typing import List, Union import threading -from builtins import range from multiprocessing import Process from multiprocessing import Queue from queue import Queue as thrQueue diff --git a/batchgenerators/transforms/color_transforms.py b/batchgenerators/transforms/color_transforms.py index f81152f..6b75b66 100644 --- a/batchgenerators/transforms/color_transforms.py +++ b/batchgenerators/transforms/color_transforms.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union, Tuple, Callable +from typing import Tuple import numpy as np diff --git a/batchgenerators/utilities/custom_types.py b/batchgenerators/utilities/custom_types.py index 2a4b9b1..0098dfe 100644 --- a/batchgenerators/utilities/custom_types.py +++ b/batchgenerators/utilities/custom_types.py @@ -13,9 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Union, Tuple, Any, Callable
-import numpy as np
+from typing import Union, Tuple, Callable
 
+import numpy as np
 ScalarType = Union[int, float, Tuple[float, float], Callable[..., Union[int, float]]]
 

From b81b4369d7d2a350300722f2965443ead9c41054 Mon Sep 17 00:00:00 2001
From: ancestor-mithril
Date: Mon, 18 Sep 2023 16:10:13 +0300
Subject: [PATCH 51/60] Preferring tuple to list and dtype to astype

---
 .../augmentations/spatial_transformations.py  |  8 ++++----
 batchgenerators/augmentations/utils.py        | 11 +++++------
 batchgenerators/dataloading/data_loader.py    |  6 +++---
 3 files changed, 12 insertions(+), 13 deletions(-)

diff --git a/batchgenerators/augmentations/spatial_transformations.py b/batchgenerators/augmentations/spatial_transformations.py
index 37dad0f..946487b 100644
--- a/batchgenerators/augmentations/spatial_transformations.py
+++ b/batchgenerators/augmentations/spatial_transformations.py
@@ -56,16 +56,16 @@ def augment_resize(sample_data, sample_seg, target_size, order=3, order_seg=1):
     """
     dimensionality = sample_data.ndim - 1
     if not isinstance(target_size, (list, tuple)):
-        target_size_here = [target_size] * dimensionality
+        target_size_here = (target_size,) * dimensionality
     else:
         assert len(target_size) == dimensionality, "If you give a tuple/list as target size, make sure it has " \
                                                    "the same dimensionality as data!"
-        target_size_here = list(target_size)
+        target_size_here = tuple(target_size)
 
     sample_data = resize_multichannel_image(sample_data, target_size_here, order)
 
     if sample_seg is not None:
-        target_seg = np.ones([sample_seg.shape[0]] + target_size_here)
+        target_seg = np.ones((sample_seg.shape[0],) + target_size_here)
         for c in range(sample_seg.shape[0]):
             target_seg[c] = resize_segmentation(sample_seg[c], target_size_here, order_seg)
     else:
@@ -96,7 +96,7 @@ def augment_zoom(sample_data, sample_seg, zoom_factors, order=3, order_seg=1):
         assert len(zoom_factors) == dimensionality, "If you give a tuple/list as target size, make sure it has " \
                                                     "the same dimensionality as data!"
         zoom_factors_here = np.array(zoom_factors)
-    target_shape_here = list(np.round(shape * zoom_factors_here).astype(int))
+    target_shape_here = tuple(np.round(shape * zoom_factors_here).astype(int))
 
     sample_data = resize_multichannel_image(sample_data, target_shape_here, order)
 
diff --git a/batchgenerators/augmentations/utils.py b/batchgenerators/augmentations/utils.py
index dae0d83..5cd9e44 100755
--- a/batchgenerators/augmentations/utils.py
+++ b/batchgenerators/augmentations/utils.py
@@ -561,7 +561,7 @@ def convert_seg_to_bounding_box_coordinates(data_dict, dim, get_rois_from_seg_fl
                 out_seg[b][data_dict['seg'][b] > 0] = 1
 
             bb_target.append(np.array(p_coords_list))
-            roi_masks.append(np.array(p_roi_masks_list).astype('uint8'))
+            roi_masks.append(np.array(p_roi_masks_list, dtype=np.uint8))
             roi_labels.append(np.array(p_roi_labels_list))
 
 
@@ -590,7 +590,7 @@ def transpose_channels(batch):
         raise ValueError("wrong dimensions in transpose_channel generator!")
 
 
-def resize_segmentation(segmentation, new_shape, order=3):
+def resize_segmentation(segmentation, new_shape: tuple, order=3):
     '''
     Resizes a segmentation map. Supports all orders (see skimage documentation). Will transform segmentation map to
    one hot encoding which is resized and transformed back to a segmentation map.
@@ -617,7 +617,7 @@ def resize_segmentation(segmentation, new_shape, order=3): return reshaped -def resize_multichannel_image(multichannel_image, new_shape, order=3): +def resize_multichannel_image(multichannel_image, new_shape: tuple, order=3): ''' Resizes multichannel_image. Resizes each channel in c separately and fuses results back together @@ -626,12 +626,11 @@ def resize_multichannel_image(multichannel_image, new_shape, order=3): :param order: :return: ''' - tpe = multichannel_image.dtype - new_shp = [multichannel_image.shape[0]] + list(new_shape) + new_shp = (multichannel_image.shape[0], ) + new_shape result = np.zeros(new_shp, dtype=multichannel_image.dtype) for i in range(multichannel_image.shape[0]): result[i] = resize(multichannel_image[i].astype(float), new_shape, order, clip=True, anti_aliasing=False) - return result.astype(tpe) + return result def get_range_val(value, rnd_type="uniform"): diff --git a/batchgenerators/dataloading/data_loader.py b/batchgenerators/dataloading/data_loader.py index 71dcbf0..581ecd4 100644 --- a/batchgenerators/dataloading/data_loader.py +++ b/batchgenerators/dataloading/data_loader.py @@ -231,11 +231,11 @@ def default_collate(batch): if isinstance(batch[0], np.ndarray): return np.vstack(batch) elif isinstance(batch[0], (int, np.int64)): - return np.array(batch).astype(np.int32) + return np.array(batch, dtype=np.int32) elif isinstance(batch[0], (float, np.float32)): - return np.array(batch).astype(np.float32) + return np.array(batch, dtype=np.float32) elif isinstance(batch[0], (np.float64,)): - return np.array(batch).astype(np.float64) + return np.array(batch, dtype=np.float64) elif isinstance(batch[0], (dict, OrderedDict)): return {key: default_collate([d[key] for d in batch]) for key in batch[0]} elif isinstance(batch[0], (tuple, list)): From 1f86cc0ec9d54d8f7cf198f6862aeef62312df78 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Mon, 25 Sep 2023 16:11:25 +0300 Subject: [PATCH 52/60] Using rint instead of round --- batchgenerators/augmentations/resample_augmentations.py | 4 ++-- batchgenerators/augmentations/spatial_transformations.py | 2 +- batchgenerators/transforms/crop_and_pad_transforms.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/batchgenerators/augmentations/resample_augmentations.py b/batchgenerators/augmentations/resample_augmentations.py index fe28a46..b749b9c 100644 --- a/batchgenerators/augmentations/resample_augmentations.py +++ b/batchgenerators/augmentations/resample_augmentations.py @@ -58,7 +58,7 @@ def augment_linear_downsampling_scipy(data_sample, zoom_range=(0.5, 1), per_chan else: zoom = uniform(zoom_range[0], zoom_range[1]) - target_shape = np.round(shp * zoom).astype(int) + target_shape = np.rint(shp * zoom).astype(int) if ignore_axes is not None: for i in ignore_axes: @@ -76,7 +76,7 @@ def augment_linear_downsampling_scipy(data_sample, zoom_range=(0.5, 1), per_chan else: zoom = uniform(zoom_range[0], zoom_range[1]) - target_shape = np.round(shp * zoom).astype(int) + target_shape = np.rint(shp * zoom).astype(int) if ignore_axes is not None: for i in ignore_axes: target_shape[i] = shp[i] diff --git a/batchgenerators/augmentations/spatial_transformations.py b/batchgenerators/augmentations/spatial_transformations.py index 946487b..6c20647 100644 --- a/batchgenerators/augmentations/spatial_transformations.py +++ b/batchgenerators/augmentations/spatial_transformations.py @@ -96,7 +96,7 @@ def augment_zoom(sample_data, sample_seg, zoom_factors, order=3, order_seg=1): assert len(zoom_factors) == 
dimensionality, "If you give a tuple/list as target size, make sure it has " \ "the same dimensionality as data!" zoom_factors_here = np.array(zoom_factors) - target_shape_here = tuple(np.round(shape * zoom_factors_here).astype(int)) + target_shape_here = tuple(np.rint(shape * zoom_factors_here).astype(int)) sample_data = resize_multichannel_image(sample_data, target_shape_here, order) diff --git a/batchgenerators/transforms/crop_and_pad_transforms.py b/batchgenerators/transforms/crop_and_pad_transforms.py index ae1836b..87b20cc 100644 --- a/batchgenerators/transforms/crop_and_pad_transforms.py +++ b/batchgenerators/transforms/crop_and_pad_transforms.py @@ -181,7 +181,7 @@ def __call__(self, **data_dict): if np.random.uniform(0, 1) < self.p_per_channel: shift_here = [] for d in range(workon.ndim - 2): - shift_here.append(int(np.round(np.random.normal( + shift_here.append(int(np.rint(np.random.normal( self.shift_mu[d] if isinstance(self.shift_mu, (list, tuple)) else self.shift_mu, self.shift_sigma[d] if isinstance(self.shift_sigma, (list, tuple)) else self.shift_sigma, From c98f477d528aab95b2fe4344919fe5e6787a514f Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Wed, 6 Dec 2023 11:19:28 +0200 Subject: [PATCH 53/60] Added new TODOS --- batchgenerators/augmentations/spatial_transformations.py | 2 +- batchgenerators/transforms/noise_transforms.py | 3 +++ batchgenerators/transforms/spatial_transforms.py | 4 +++- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/batchgenerators/augmentations/spatial_transformations.py b/batchgenerators/augmentations/spatial_transformations.py index dc081ff..4f8f5d9 100644 --- a/batchgenerators/augmentations/spatial_transformations.py +++ b/batchgenerators/augmentations/spatial_transformations.py @@ -521,7 +521,7 @@ def augment_anatomy_informed(data, seg, t, u, v = get_organ_gradient_field(seg == organ_idx + 2, spacing_ratio=spacing_ratio, blur=blur) - + # TODO: if directions_of_trans[organ_idx][0]: coords[0, :, :, :] = coords[0, :, :, :] + t * dil_magnitude * spacing_ratio if directions_of_trans[organ_idx][1]: diff --git a/batchgenerators/transforms/noise_transforms.py b/batchgenerators/transforms/noise_transforms.py index 1e1f0d7..1a09d82 100644 --- a/batchgenerators/transforms/noise_transforms.py +++ b/batchgenerators/transforms/noise_transforms.py @@ -138,6 +138,7 @@ def __init__(self, rectangle_value): self.rectangle_value = rectangle_value def __call__(self, x): + # TODO: Change this if np.isscalar(self.rectangle_value): return self.rectangle_value elif callable(self.rectangle_value): @@ -148,6 +149,8 @@ def __call__(self, x): raise RuntimeError("unrecognized format for rectangle_value") + + class BlankRectangleTransform(AbstractTransform): def __init__(self, rectangle_size, rectangle_value, num_rectangles, force_square=False, p_per_sample=0.5, p_per_channel=0.5, apply_to_keys=('data',)): diff --git a/batchgenerators/transforms/spatial_transforms.py b/batchgenerators/transforms/spatial_transforms.py index 5fdfbda..86c95a8 100644 --- a/batchgenerators/transforms/spatial_transforms.py +++ b/batchgenerators/transforms/spatial_transforms.py @@ -567,6 +567,7 @@ def __call__(self, **data_dict): self.dim = 3 active_organs = [] + # TODO: Optimize this for prob in self.p_per_sample: if np.random.uniform() < prob: active_organs.append(1) @@ -669,9 +670,10 @@ def __init__(self, data_key="data", label_key="seg", self.border_cval_seg = border_cval_seg def __call__(self, **data_dict): - data = data_dict.get(self.data_key) + data = 
data_dict[self.data_key] seg = data_dict.get(self.label_key) + # TODO if data.shape[1] < 2: raise ValueError("only support multi-modal images") else: From 57ec2fbb084f1e348bfbd9efe05ae228e106c580 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Wed, 6 Dec 2023 11:30:09 +0200 Subject: [PATCH 54/60] Disabling test --- tests/test_DataLoader.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_DataLoader.py b/tests/test_DataLoader.py index 1fa59e7..e632248 100644 --- a/tests/test_DataLoader.py +++ b/tests/test_DataLoader.py @@ -200,6 +200,9 @@ def test_return_incomplete_multi_threaded(self): self.assertTrue(len(np.unique(all_return)) == len(data)) def test_thoroughly(self): + really_test_this = False + if not really_test_this: + print("This test takes too much time. Run me if you really want to test me.") data_list = [list(range(123)), list(range(1243)), list(range(1)), From 01e0d106b93086765ab16c5311bf87ac7fa52ce9 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Wed, 6 Dec 2023 15:00:47 +0200 Subject: [PATCH 55/60] Removing redundant dependencies --- setup.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/setup.py b/setup.py index e13e845..bb9abc4 100755 --- a/setup.py +++ b/setup.py @@ -10,12 +10,10 @@ license='Apache License Version 2.0, January 2004', packages=find_packages(exclude=["tests"]), install_requires=[ - "pillow>=7.1.2", "numpy>=1.10.2", "scipy", "scikit-image", "scikit-learn", - "future", "threadpoolctl", "pandas" ], From 6388e43730bafc8eab4a77529dd3169144f19a6b Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Wed, 10 Jan 2024 15:51:50 +0200 Subject: [PATCH 56/60] More inplace np.clip --- batchgenerators/transforms/color_transforms.py | 2 +- batchgenerators/transforms/noise_transforms.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/batchgenerators/transforms/color_transforms.py b/batchgenerators/transforms/color_transforms.py index 6b75b66..3d294ab 100644 --- a/batchgenerators/transforms/color_transforms.py +++ b/batchgenerators/transforms/color_transforms.py @@ -205,5 +205,5 @@ def __init__(self, min=None, max=None, data_key="data"): self.max = max def __call__(self, **data_dict): - data_dict[self.data_key] = np.clip(data_dict[self.data_key], self.min, self.max) + np.clip(data_dict[self.data_key], self.min, self.max, out=data_dict[self.data_key]) return data_dict diff --git a/batchgenerators/transforms/noise_transforms.py b/batchgenerators/transforms/noise_transforms.py index 1a09d82..e2e40c3 100644 --- a/batchgenerators/transforms/noise_transforms.py +++ b/batchgenerators/transforms/noise_transforms.py @@ -334,7 +334,7 @@ def __call__(self, **data_dict): filter_here, mode='same' ) - data[b, c] = np.clip(data[b, c], mn, mx) + np.clip(data[b, c], mn, mx, out=data[b, c]) else: for c in range(data.shape[1]): if np.random.uniform() < self.p_per_channel: @@ -351,7 +351,7 @@ def __call__(self, **data_dict): filter_here, mode='same' ) - data[b, c] = np.clip(data[b, c], mn, mx) + np.clip(data[b, c], mn, mx, out=data[b, c]) return data_dict From f3b111738328f8653ed1c2af9cecc454c3e17eaf Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Thu, 11 Jan 2024 13:25:39 +0200 Subject: [PATCH 57/60] Using and reducing memory allocations when casting to new type --- batchgenerators/augmentations/resample_augmentations.py | 4 ++-- batchgenerators/augmentations/spatial_transformations.py | 2 +- batchgenerators/augmentations/utils.py | 9 +++++---- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git 
a/batchgenerators/augmentations/resample_augmentations.py b/batchgenerators/augmentations/resample_augmentations.py index b749b9c..955e43d 100644 --- a/batchgenerators/augmentations/resample_augmentations.py +++ b/batchgenerators/augmentations/resample_augmentations.py @@ -81,8 +81,8 @@ def augment_linear_downsampling_scipy(data_sample, zoom_range=(0.5, 1), per_chan for i in ignore_axes: target_shape[i] = shp[i] - downsampled = resize(data_sample[c].astype(float), target_shape, order=order_downsample, mode='edge', - anti_aliasing=False) + downsampled = resize(data_sample[c].astype(float, copy=False), target_shape, order=order_downsample, + mode='edge', anti_aliasing=False) data_sample[c] = resize(downsampled, shp, order=order_upsample, mode='edge', anti_aliasing=False) return data_sample diff --git a/batchgenerators/augmentations/spatial_transformations.py b/batchgenerators/augmentations/spatial_transformations.py index 4f8f5d9..be97738 100644 --- a/batchgenerators/augmentations/spatial_transformations.py +++ b/batchgenerators/augmentations/spatial_transformations.py @@ -98,7 +98,7 @@ def augment_zoom(sample_data, sample_seg, zoom_factors, order=3, order_seg=1): assert len(zoom_factors) == dimensionality, "If you give a tuple/list as target size, make sure it has " \ "the same dimensionality as data!" zoom_factors_here = np.array(zoom_factors) - target_shape_here = tuple(np.rint(shape * zoom_factors_here).astype(int)) + target_shape_here = tuple(np.rint(shape * zoom_factors_here).astype(int, copy=False)) sample_data = resize_multichannel_image(sample_data, target_shape_here, order) diff --git a/batchgenerators/augmentations/utils.py b/batchgenerators/augmentations/utils.py index 84485ed..335f64d 100755 --- a/batchgenerators/augmentations/utils.py +++ b/batchgenerators/augmentations/utils.py @@ -170,7 +170,8 @@ def interpolate_img(img, coords, order=3, mode='nearest', cval=0.0, is_seg=False result[res_new >= 0.5] = c return result else: - return map_coordinates(img.astype(float), coords, order=order, mode=mode, cval=cval).astype(img.dtype) + return map_coordinates( + img.astype(float, copy=False), coords, order=order, mode=mode, cval=cval).astype(img.dtype, copy=False) def generate_noise(shape, alpha, sigma): @@ -183,7 +184,7 @@ def find_entries_in_array(entries, myarray): entries = np.array(entries, dtype=int) lut = np.zeros(np.max(myarray) + 1, 'bool') lut[entries] = True - return np.take(lut, myarray.astype(int)) + return np.take(lut, myarray.astype(int, copy=False)) def center_crop_3D_image(img, crop_size): @@ -629,7 +630,7 @@ def resize_multichannel_image(multichannel_image, new_shape: tuple, order=3): new_shp = (multichannel_image.shape[0], ) + new_shape result = np.zeros(new_shp, dtype=multichannel_image.dtype) for i in range(multichannel_image.shape[0]): - result[i] = resize(multichannel_image[i].astype(float), new_shape, order, clip=True, anti_aliasing=False) + result[i] = resize(multichannel_image[i].astype(float, copy=False), new_shape, order, clip=True, anti_aliasing=False) return result @@ -789,7 +790,7 @@ def get_organ_gradient_field(organ, spacing_ratio=0.3125/3.0, blur=32): :param spacing_ratio: ratio of the axial spacing and the slice thickness, needed for the right vector field calculation :param blur: kernel constant """ - organ_blurred = gaussian_filter(organ.astype(float), + organ_blurred = gaussian_filter(organ.astype(float, copy=False), sigma=(blur * spacing_ratio, blur, blur), order=0, mode='nearest') From c4eda10e52cf34668118e5816bcd312fe1004468 Mon Sep 17 00:00:00 
2001 From: ancestor-mithril Date: Wed, 31 Jan 2024 16:42:03 +0200 Subject: [PATCH 58/60] Making resize segmentation faster without additional casting --- batchgenerators/augmentations/utils.py | 39 +++++++++++++++++--------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/batchgenerators/augmentations/utils.py b/batchgenerators/augmentations/utils.py index 335f64d..f28ce5d 100755 --- a/batchgenerators/augmentations/utils.py +++ b/batchgenerators/augmentations/utils.py @@ -193,7 +193,8 @@ def center_crop_3D_image(img, crop_size): center_crop = [int(crop_size)] * img.ndim else: center_crop = crop_size - assert len(center_crop) == img.ndim, "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (3d)" + assert len( + center_crop) == img.ndim, "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (3d)" return img[int(center[0] - center_crop[0] / 2.):int(center[0] + center_crop[0] / 2.), int(center[1] - center_crop[1] / 2.):int(center[1] + center_crop[1] / 2.), int(center[2] - center_crop[2] / 2.):int(center[2] + center_crop[2] / 2.)] @@ -206,7 +207,8 @@ def center_crop_3D_image_batched(img, crop_size): center_crop = [int(crop_size)] * (img.ndim - 2) else: center_crop = crop_size - assert len(center_crop) == (img.ndim - 2), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (3d)" + assert len(center_crop) == ( + img.ndim - 2), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (3d)" return img[:, :, int(center[0] - center_crop[0] / 2.):int(center[0] + center_crop[0] / 2.), int(center[1] - center_crop[1] / 2.):int(center[1] + center_crop[1] / 2.), int(center[2] - center_crop[2] / 2.):int(center[2] + center_crop[2] / 2.)] @@ -218,7 +220,8 @@ def center_crop_2D_image(img, crop_size): center_crop = [int(crop_size)] * img.ndim else: center_crop = crop_size - assert len(center_crop) == img.ndim, "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (2d)" + assert len( + center_crop) == img.ndim, "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (2d)" return img[int(center[0] - center_crop[0] / 2.):int(center[0] + center_crop[0] / 2.), int(center[1] - center_crop[1] / 2.):int(center[1] + center_crop[1] / 2.)] @@ -230,7 +233,8 @@ def center_crop_2D_image_batched(img, crop_size): center_crop = [int(crop_size)] * (img.ndim - 2) else: center_crop = crop_size - assert len(center_crop) == (img.ndim - 2), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (2d)" + assert len(center_crop) == ( + img.ndim - 2), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (2d)" return img[:, :, int(center[0] - center_crop[0] / 2.):int(center[0] + center_crop[0] / 2.), int(center[1] - center_crop[1] / 2.):int(center[1] + center_crop[1] / 2.)] @@ -239,7 +243,8 @@ def random_crop_3D_image(img, crop_size): if type(crop_size) not in (tuple, list): crop_size = [crop_size] * img.ndim else: - assert len(crop_size) == img.ndim, "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (3d)" + assert len( + crop_size) == img.ndim, "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (3d)" if crop_size[0] < img.shape[0]: lb_x = np.random.randint(0, img.shape[0] - crop_size[0]) @@ -269,7 +274,8 
@@ def random_crop_3D_image_batched(img, crop_size): if type(crop_size) not in (tuple, list): crop_size = [crop_size] * (img.ndim - 2) else: - assert len(crop_size) == (img.ndim - 2), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (3d)" + assert len(crop_size) == ( + img.ndim - 2), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (3d)" if crop_size[0] < img.shape[2]: lb_x = np.random.randint(0, img.shape[2] - crop_size[0]) @@ -299,7 +305,8 @@ def random_crop_2D_image(img, crop_size): if type(crop_size) not in (tuple, list): crop_size = [crop_size] * img.ndim else: - assert len(crop_size) == img.ndim, "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (2d)" + assert len( + crop_size) == img.ndim, "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (2d)" if crop_size[0] < img.shape[0]: lb_x = np.random.randint(0, img.shape[0] - crop_size[0]) @@ -322,7 +329,8 @@ def random_crop_2D_image_batched(img, crop_size): if type(crop_size) not in (tuple, list): crop_size = [crop_size] * (img.ndim - 2) else: - assert len(crop_size) == (img.ndim - 2), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (2d)" + assert len(crop_size) == ( + img.ndim - 2), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (2d)" if crop_size[0] < img.shape[2]: lb_x = np.random.randint(0, img.shape[2] - crop_size[0]) @@ -604,15 +612,15 @@ def resize_segmentation(segmentation, new_shape: tuple, order=3): tpe = segmentation.dtype assert segmentation.ndim == len(new_shape), "new shape must have same dimensionality as segmentation" if order == 0: - return resize(segmentation.astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False).astype( - tpe) + return resize(segmentation.astype(np.float64, copy=False), new_shape, order, mode="edge", clip=True, + anti_aliasing=False).astype(tpe) else: unique_labels = pd.unique(segmentation.reshape(-1)) # does not need sorting reshaped = np.zeros(new_shape, dtype=tpe) for c in unique_labels: mask = segmentation == c - reshaped_multihot = resize(mask.astype(float), new_shape, order, mode="edge", clip=True, + reshaped_multihot = resize(mask.astype(np.float64, copy=False), new_shape, order, mode="edge", clip=True, anti_aliasing=False) reshaped[reshaped_multihot >= 0.5] = c return reshaped @@ -627,10 +635,11 @@ def resize_multichannel_image(multichannel_image, new_shape: tuple, order=3): :param order: :return: ''' - new_shp = (multichannel_image.shape[0], ) + new_shape + new_shp = (multichannel_image.shape[0],) + new_shape result = np.zeros(new_shp, dtype=multichannel_image.dtype) for i in range(multichannel_image.shape[0]): - result[i] = resize(multichannel_image[i].astype(float, copy=False), new_shape, order, clip=True, anti_aliasing=False) + result[i] = resize(multichannel_image[i].astype(float, copy=False), new_shape, order, clip=True, + anti_aliasing=False) return result @@ -782,7 +791,8 @@ def mask_random_squares(img, square_size, n_squares, n_val, channel_wise_n_val=F square_pos=square_pos) return img -def get_organ_gradient_field(organ, spacing_ratio=0.3125/3.0, blur=32): + +def get_organ_gradient_field(organ, spacing_ratio=0.3125 / 3.0, blur=32): """ Calculates the gradient field around the organ segmentations for the anatomy-informed augmentation @@ -800,6 +810,7 @@ def 
get_organ_gradient_field(organ, spacing_ratio=0.3125/3.0, blur=32): return t, u, v + def ignore_anatomy(segm, max_annotation_value=1, replace_value=0): segm[segm > max_annotation_value] = replace_value return segm From d06ed91cd8d5103a1bbc4dce3ff22f6649292121 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Wed, 31 Jan 2024 16:44:47 +0200 Subject: [PATCH 59/60] Making resize segmentation faster without additional casting --- batchgenerators/augmentations/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/batchgenerators/augmentations/utils.py b/batchgenerators/augmentations/utils.py index f28ce5d..b50f3e4 100755 --- a/batchgenerators/augmentations/utils.py +++ b/batchgenerators/augmentations/utils.py @@ -613,7 +613,7 @@ def resize_segmentation(segmentation, new_shape: tuple, order=3): assert segmentation.ndim == len(new_shape), "new shape must have same dimensionality as segmentation" if order == 0: return resize(segmentation.astype(np.float64, copy=False), new_shape, order, mode="edge", clip=True, - anti_aliasing=False).astype(tpe) + anti_aliasing=False).astype(tpe, copy=False) else: unique_labels = pd.unique(segmentation.reshape(-1)) # does not need sorting reshaped = np.zeros(new_shape, dtype=tpe) From 6d8058cb429ae2ab1a64276dce552fe08a1b2262 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Thu, 1 Feb 2024 13:24:09 +0200 Subject: [PATCH 60/60] Added callable retain_stats and contrast_range arguments back to color augmentation functions --- .../augmentations/color_augmentations.py | 41 +++++++++++-------- tests/test_DataLoader.py | 1 + 2 files changed, 26 insertions(+), 16 deletions(-) diff --git a/batchgenerators/augmentations/color_augmentations.py b/batchgenerators/augmentations/color_augmentations.py index 0db8c87..3e940dc 100644 --- a/batchgenerators/augmentations/color_augmentations.py +++ b/batchgenerators/augmentations/color_augmentations.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
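The copy=False casts introduced above lean on a documented NumPy behaviour: astype returns the input array unchanged when it already has the requested dtype, and only allocates when an actual conversion is needed. A minimal standalone sketch (made-up arrays, not repository code):

import numpy as np

seg = np.zeros((4, 4), dtype=np.float64)        # already float64
same = seg.astype(np.float64, copy=False)       # no conversion needed
assert same is seg                              # the very same array, no copy made

ints = np.arange(16).reshape(4, 4)              # int64, conversion required
converted = ints.astype(np.float64, copy=False)
assert converted is not ints                    # copy=False still copies when it must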
-from typing import Tuple +from typing import Tuple, Callable, Union import numpy as np @@ -21,22 +21,25 @@ reverse_broadcast -def get_augment_contrast_factor(contrast_range: Tuple[float, float], +def get_augment_contrast_factor(contrast_range: Union[Tuple[float, float], Callable[[], float]], per_channel: bool, size: int, broadcast_size: int): if per_channel: factor = [] - contrast_l = max(contrast_range[0], 1) for _ in range(size): - if contrast_range[0] < 1 and np.random.random() < 0.5: + if callable(contrast_range): + factor.append(contrast_range()) + elif contrast_range[0] < 1 and np.random.random() < 0.5: factor.append(np.random.uniform(contrast_range[0], 1)) else: - factor.append(np.random.uniform(contrast_l, contrast_range[1])) + factor.append(np.random.uniform(max(contrast_range[0], 1), contrast_range[1])) factor = reverse_broadcast(np.array(factor), get_broadcast_axes(broadcast_size)) else: - if contrast_range[0] < 1 and np.random.random() < 0.5: + if callable(contrast_range): + factor = contrast_range() + elif contrast_range[0] < 1 and np.random.random() < 0.5: factor = np.random.uniform(contrast_range[0], 1) else: factor = np.random.uniform(max(contrast_range[0], 1), contrast_range[1]) @@ -45,7 +48,7 @@ def get_augment_contrast_factor(contrast_range: Tuple[float, float], def augment_contrast(data_sample: np.ndarray, - contrast_range: Tuple[float, float] = (0.75, 1.25), + contrast_range: Union[Tuple[float, float], Callable[[], float]] = (0.75, 1.25), preserve_range: bool = True, per_channel: bool = True, p_per_channel: float = 1, @@ -106,11 +109,12 @@ def augment_brightness_multiplicative(data_sample, multiplier_range=(0.5, 2), pe def augment_gamma(data_sample, gamma_range=(0.5, 2), invert_image=False, epsilon=1e-7, per_channel=False, - retain_stats: bool = False): + retain_stats: Union[bool, Callable[[], bool]] = False): if invert_image: data_sample = - data_sample if not per_channel: + retain_stats = retain_stats() if callable(retain_stats) else retain_stats if retain_stats: mn = data_sample.mean() sd = data_sample.std() @@ -138,21 +142,26 @@ def augment_gamma(data_sample, gamma_range=(0.5, 2), invert_image=False, epsilon axes = tuple(range(1, data_sample.ndim)) - if retain_stats: - mn = data_sample.mean(axis=axes, keepdims=True) - sd = data_sample.mean(axis=axes, keepdims=True) + if callable(retain_stats): + retain_stats = [retain_stats() for _ in range(shape_0)] + else: + retain_stats = [retain_stats] * shape_0 + retain_stats_here = any(retain_stats) + if retain_stats_here: + mn = data_sample[retain_stats].mean(axis=axes, keepdims=True) + sd = data_sample[retain_stats].mean(axis=axes, keepdims=True) minm = data_sample.min(axis=axes, keepdims=True) rnge = data_sample.max(axis=axes, keepdims=True) - minm + epsilon broadcast_axes = get_broadcast_axes(data_sample.ndim) - gamma = reverse_broadcast(gamma, broadcast_axes) # TODO: Remove + gamma = reverse_broadcast(gamma, broadcast_axes) data_sample = np.power((data_sample - minm) / rnge, gamma) * rnge + minm - if retain_stats: - data_sample -= data_sample.mean(axis=axes, keepdims=True) - data_sample *= sd / (data_sample.std(axis=axes, keepdims=True) + 1e-8) - data_sample += mn + if retain_stats_here: + data_sample[retain_stats] -= data_sample[retain_stats].mean(axis=axes, keepdims=True) + data_sample[retain_stats] *= sd / (data_sample[retain_stats].std(axis=axes, keepdims=True) + 1e-8) + data_sample[retain_stats] += mn if invert_image: data_sample = - data_sample diff --git a/tests/test_DataLoader.py b/tests/test_DataLoader.py index 
e632248..21efebe 100644 --- a/tests/test_DataLoader.py +++ b/tests/test_DataLoader.py @@ -203,6 +203,7 @@ def test_thoroughly(self): really_test_this = False if not really_test_this: print("This test takes too much time. Run me if you really want to test me.") + return data_list = [list(range(123)), list(range(1243)), list(range(1)),