diff --git a/nnunetv2/evaluation/find_best_configuration.py b/nnunetv2/evaluation/find_best_configuration.py
index 7e9f77420..f585b80d9 100644
--- a/nnunetv2/evaluation/find_best_configuration.py
+++ b/nnunetv2/evaluation/find_best_configuration.py
@@ -3,8 +3,9 @@
 from copy import deepcopy
 from typing import Union, List, Tuple
 
-from batchgenerators.utilities.file_and_folder_operations import load_json, join, isdir, save_json
-
+from batchgenerators.utilities.file_and_folder_operations import (
+    load_json, join, isdir, listdir, save_json
+)
 from nnunetv2.configuration import default_num_processes
 from nnunetv2.ensembling.ensemble import ensemble_crossvalidations
 from nnunetv2.evaluation.accumulate_cv_results import accumulate_cv_results
@@ -320,6 +321,11 @@ def accumulate_crossval_results_entry_point():
         merged_output_folder = join(trained_model_folder, f'crossval_results_folds_{folds_tuple_to_string(args.f)}')
     else:
         merged_output_folder = args.o
+    if isdir(merged_output_folder) and len(listdir(merged_output_folder)) > 0:
+        raise FileExistsError(
+            f"Output folder {merged_output_folder} exists and is not empty. "
+            f"To avoid data loss, nnUNet requires an empty output folder."
+        )
 
     accumulate_cv_results(trained_model_folder, merged_output_folder, args.f)
 
diff --git a/nnunetv2/experiment_planning/experiment_planners/resampling/__init__.py b/nnunetv2/experiment_planning/experiment_planners/resampling/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/nnunetv2/experiment_planning/experiment_planners/resampling/resample_with_torch.py b/nnunetv2/experiment_planning/experiment_planners/resampling/resample_with_torch.py
new file mode 100644
index 000000000..ee5adc7b9
--- /dev/null
+++ b/nnunetv2/experiment_planning/experiment_planners/resampling/resample_with_torch.py
@@ -0,0 +1,181 @@
+from typing import Union, List, Tuple
+
+from nnunetv2.configuration import ANISO_THRESHOLD
+from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner
+from nnunetv2.experiment_planning.experiment_planners.residual_unets.residual_encoder_unet_planners import \
+    nnUNetPlannerResEncL
+from nnunetv2.preprocessing.resampling.resample_torch import resample_torch_fornnunet
+
+
+class nnUNetPlannerResEncL_torchres(nnUNetPlannerResEncL):
+    def __init__(self, dataset_name_or_id: Union[str, int],
+                 gpu_memory_target_in_gb: float = 24,
+                 preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLPlans_torchres',
+                 overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
+                 suppress_transpose: bool = False):
+        super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
+                         overwrite_target_spacing, suppress_transpose)
+
+    def generate_data_identifier(self, configuration_name: str) -> str:
+        """
+        configurations are unique within each plans file but different plans file can have configurations with the
+        same name. In order to distinguish the associated data we need a data identifier that reflects not just the
+        config but also the plans it originates from
+        """
+        return self.plans_identifier + '_' + configuration_name
+
+    def determine_resampling(self, *args, **kwargs):
+        """
+        returns what functions to use for resampling data and seg, respectively. Also returns kwargs
+        resampling function must be callable(data, current_spacing, new_spacing, **kwargs)
+
+        determine_resampling is called within get_plans_for_configuration to allow for different functions for each
+        configuration
+        """
+        resampling_data = resample_torch_fornnunet
+        resampling_data_kwargs = {
+            "is_seg": False,
+            'force_separate_z': False,
+            'memefficient_seg_resampling': False
+        }
+        resampling_seg = resample_torch_fornnunet
+        resampling_seg_kwargs = {
+            "is_seg": True,
+            'force_separate_z': False,
+            'memefficient_seg_resampling': False
+        }
+        return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs
+
+    def determine_segmentation_softmax_export_fn(self, *args, **kwargs):
+        """
+        function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be
+        used as target. current_spacing and new_spacing are merely there in case we want to use it somehow
+
+        determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different
+        functions for each configuration
+
+        """
+        resampling_fn = resample_torch_fornnunet
+        resampling_fn_kwargs = {
+            "is_seg": False,
+            'force_separate_z': False,
+            'memefficient_seg_resampling': False
+        }
+        return resampling_fn, resampling_fn_kwargs
+
+
+class nnUNetPlannerResEncL_torchres_sepz(nnUNetPlannerResEncL):
+    def __init__(self, dataset_name_or_id: Union[str, int],
+                 gpu_memory_target_in_gb: float = 24,
+                 preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLPlans_torchres_sepz',
+                 overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
+                 suppress_transpose: bool = False):
+        super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
+                         overwrite_target_spacing, suppress_transpose)
+
+    def generate_data_identifier(self, configuration_name: str) -> str:
+        """
+        configurations are unique within each plans file but different plans file can have configurations with the
+        same name. In order to distinguish the associated data we need a data identifier that reflects not just the
+        config but also the plans it originates from
+        """
+        return self.plans_identifier + '_' + configuration_name
+
+    def determine_resampling(self, *args, **kwargs):
+        """
+        returns what functions to use for resampling data and seg, respectively. Also returns kwargs
+        resampling function must be callable(data, current_spacing, new_spacing, **kwargs)
+
+        determine_resampling is called within get_plans_for_configuration to allow for different functions for each
+        configuration
+        """
+        resampling_data = resample_torch_fornnunet
+        resampling_data_kwargs = {
+            "is_seg": False,
+            'force_separate_z': None,
+            'memefficient_seg_resampling': False,
+            'separate_z_anisotropy_threshold': ANISO_THRESHOLD
+        }
+        resampling_seg = resample_torch_fornnunet
+        resampling_seg_kwargs = {
+            "is_seg": True,
+            'force_separate_z': None,
+            'memefficient_seg_resampling': False,
+            'separate_z_anisotropy_threshold': ANISO_THRESHOLD
+        }
+        return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs
+
+    def determine_segmentation_softmax_export_fn(self, *args, **kwargs):
+        """
+        function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be
+        used as target. current_spacing and new_spacing are merely there in case we want to use it somehow
+
+        determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different
+        functions for each configuration
+
+        """
+        resampling_fn = resample_torch_fornnunet
+        resampling_fn_kwargs = {
+            "is_seg": False,
+            'force_separate_z': None,
+            'memefficient_seg_resampling': False,
+            'separate_z_anisotropy_threshold': ANISO_THRESHOLD
+        }
+        return resampling_fn, resampling_fn_kwargs
+
+
+class nnUNetPlanner_torchres(ExperimentPlanner):
+    def __init__(self, dataset_name_or_id: Union[str, int],
+                 gpu_memory_target_in_gb: float = 8,
+                 preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetPlans_torchres',
+                 overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
+                 suppress_transpose: bool = False):
+        super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
+                         overwrite_target_spacing, suppress_transpose)
+
+    def generate_data_identifier(self, configuration_name: str) -> str:
+        """
+        configurations are unique within each plans file but different plans file can have configurations with the
+        same name. In order to distinguish the associated data we need a data identifier that reflects not just the
+        config but also the plans it originates from
+        """
+        return self.plans_identifier + '_' + configuration_name
+
+    def determine_resampling(self, *args, **kwargs):
+        """
+        returns what functions to use for resampling data and seg, respectively. Also returns kwargs
+        resampling function must be callable(data, current_spacing, new_spacing, **kwargs)
+
+        determine_resampling is called within get_plans_for_configuration to allow for different functions for each
+        configuration
+        """
+        resampling_data = resample_torch_fornnunet
+        resampling_data_kwargs = {
+            "is_seg": False,
+            'force_separate_z': False,
+            'memefficient_seg_resampling': False
+        }
+        resampling_seg = resample_torch_fornnunet
+        resampling_seg_kwargs = {
+            "is_seg": True,
+            'force_separate_z': False,
+            'memefficient_seg_resampling': False
+        }
+        return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs
+
+    def determine_segmentation_softmax_export_fn(self, *args, **kwargs):
+        """
+        function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be
+        used as target. current_spacing and new_spacing are merely there in case we want to use it somehow
+
+        determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different
+        functions for each configuration
+
+        """
+        resampling_fn = resample_torch_fornnunet
+        resampling_fn_kwargs = {
+            "is_seg": False,
+            'force_separate_z': False,
+            'memefficient_seg_resampling': False
+        }
+        return resampling_fn, resampling_fn_kwargs
diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/residual_encoder_unet_planners.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/residual_encoder_unet_planners.py
index f7026e311..012950b82 100644
--- a/nnunetv2/experiment_planning/experiment_planners/residual_unets/residual_encoder_unet_planners.py
+++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets/residual_encoder_unet_planners.py
@@ -294,63 +294,6 @@ def __init__(self, dataset_name_or_id: Union[str, int],
         self.max_dataset_covered = 1
 
 
-class nnUNetPlannerResEncL_torchres(nnUNetPlannerResEncL):
-    def __init__(self, dataset_name_or_id: Union[str, int],
-                 gpu_memory_target_in_gb: float = 24,
-                 preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLPlans_torchres',
-                 overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
-                 suppress_transpose: bool = False):
-        super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
-                         overwrite_target_spacing, suppress_transpose)
-
-    def generate_data_identifier(self, configuration_name: str) -> str:
-        """
-        configurations are unique within each plans file but different plans file can have configurations with the
-        same name. In order to distinguish the associated data we need a data identifier that reflects not just the
-        config but also the plans it originates from
-        """
-        return self.plans_identifier + '_' + configuration_name
-
-    def determine_resampling(self, *args, **kwargs):
-        """
-        returns what functions to use for resampling data and seg, respectively. Also returns kwargs
-        resampling function must be callable(data, current_spacing, new_spacing, **kwargs)
-
-        determine_resampling is called within get_plans_for_configuration to allow for different functions for each
-        configuration
-        """
-        resampling_data = resample_torch_fornnunet
-        resampling_data_kwargs = {
-            "is_seg": False,
-            'force_separate_z': False,
-            'memefficient_seg_resampling': False
-        }
-        resampling_seg = resample_torch_fornnunet
-        resampling_seg_kwargs = {
-            "is_seg": True,
-            'force_separate_z': False,
-            'memefficient_seg_resampling': False
-        }
-        return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs
-
-    def determine_segmentation_softmax_export_fn(self, *args, **kwargs):
-        """
-        function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be
-        used as target. current_spacing and new_spacing are merely there in case we want to use it somehow
-
-        determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different
-        functions for each configuration
-
-        """
-        resampling_fn = resample_torch_fornnunet
-        resampling_fn_kwargs = {
-            "is_seg": False,
-            'force_separate_z': False,
-            'memefficient_seg_resampling': False
-        }
-        return resampling_fn, resampling_fn_kwargs
-
-
 if __name__ == '__main__':
     # we know both of these networks run with batch size 2 and 12 on ~8-10GB, respectively
     net = ResidualEncoderUNet(input_channels=1, n_stages=6, features_per_stage=(32, 64, 128, 256, 320, 320),
diff --git a/nnunetv2/imageio/nibabel_reader_writer.py b/nnunetv2/imageio/nibabel_reader_writer.py
index 78fb17ac1..2854da4b5 100644
--- a/nnunetv2/imageio/nibabel_reader_writer.py
+++ b/nnunetv2/imageio/nibabel_reader_writer.py
@@ -31,8 +31,6 @@ class NibabelIO(BaseReaderWriter):
     supported_file_endings = [
         '.nii',
         '.nii.gz',
-        '.nrrd',
-        '.mha'
     ]
 
     def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
@@ -110,8 +108,6 @@ class NibabelIOWithReorient(BaseReaderWriter):
     supported_file_endings = [
         '.nii',
         '.nii.gz',
-        '.nrrd',
-        '.mha'
     ]
 
     def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
diff --git a/nnunetv2/preprocessing/preprocessors/default_preprocessor.py b/nnunetv2/preprocessing/preprocessors/default_preprocessor.py
index 7e0068b9d..8b1abf7b2 100644
--- a/nnunetv2/preprocessing/preprocessors/default_preprocessor.py
+++ b/nnunetv2/preprocessing/preprocessors/default_preprocessor.py
@@ -230,15 +230,17 @@ def run(self, dataset_name_or_id: Union[int, str], configuration_name: str, plan
         # multiprocessing magic.
         r = []
         with multiprocessing.get_context("spawn").Pool(num_processes) as p:
+            remaining = list(range(len(dataset)))
+            # p is pretty nifti. If we kill workers they just respawn but don't do any work.
+            # So we need to store the original pool of workers.
+            workers = [j for j in p._pool]
+
             for k in dataset.keys():
                 r.append(p.starmap_async(self.run_case_save,
                                          ((join(output_directory, k), dataset[k]['images'], dataset[k]['label'],
                                            plans_manager, configuration_manager,
                                            dataset_json),)))
-            remaining = list(range(len(dataset)))
-            # p is pretty nifti. If we kill workers they just respawn but don't do any work.
-            # So we need to store the original pool of workers.
-            workers = [j for j in p._pool]
+
             with tqdm(desc=None, total=len(dataset), disable=self.verbose) as pbar:
                 while len(remaining) > 0:
                     all_alive = all([j.is_alive() for j in workers])
@@ -251,6 +253,8 @@ def run(self, dataset_name_or_id: Union[int, str], configuration_name: str, plan
                                           'an error message, out of RAM is likely the problem. In that case '
                                           'reducing the number of workers might help')
                     done = [i for i in remaining if r[i].ready()]
+                    # get done so that errors can be raised
+                    _ = [r[i].get() for i in done]
                     for _ in done:
                         r[_].get()  # allows triggering errors
                         pbar.update()
diff --git a/nnunetv2/preprocessing/resampling/default_resampling.py b/nnunetv2/preprocessing/resampling/default_resampling.py
index 299aa939c..c65dc142c 100644
--- a/nnunetv2/preprocessing/resampling/default_resampling.py
+++ b/nnunetv2/preprocessing/resampling/default_resampling.py
@@ -31,8 +31,12 @@ def compute_new_shape(old_shape: Union[Tuple[int, ...], List[int], np.ndarray],
     return new_shape
 
 
-def determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing,
-                                separate_z_anisotropy_threshold: float = ANISO_THRESHOLD) -> Tuple[bool, int]:
+
+def determine_do_sep_z_and_axis(
+        force_separate_z: bool,
+        current_spacing,
+        new_spacing,
+        separate_z_anisotropy_threshold: float = ANISO_THRESHOLD) -> Tuple[bool, Union[int, None]]:
     if force_separate_z is not None:
         do_separate_z = force_separate_z
         if force_separate_z:
diff --git a/nnunetv2/utilities/plans_handling/plans_handler.py b/nnunetv2/utilities/plans_handling/plans_handler.py
index 11b76dfcb..eab65885a 100644
--- a/nnunetv2/utilities/plans_handling/plans_handler.py
+++ b/nnunetv2/utilities/plans_handling/plans_handler.py
@@ -54,6 +54,8 @@ def __init__(self, configuration_dict: dict):
             conv_op = convert_dim_to_conv_op(dim)
             instnorm = get_matching_instancenorm(dimension=dim)
 
+            convs_or_blocks = "n_conv_per_stage" if unet_class_name == "PlainConvUNet" else "n_blocks_per_stage"
+
             arch_dict = {
                 'network_class_name': network_class_name,
                 'arch_kwargs': {
@@ -64,7 +66,7 @@ def __init__(self, configuration_dict: dict):
                     "conv_op": conv_op.__module__ + '.' + conv_op.__name__,
                     "kernel_sizes": deepcopy(self.configuration["conv_kernel_sizes"]),
                     "strides": deepcopy(self.configuration["pool_op_kernel_sizes"]),
-                    "n_conv_per_stage": deepcopy(self.configuration["n_conv_per_stage_encoder"]),
+                    convs_or_blocks: deepcopy(self.configuration["n_conv_per_stage_encoder"]),
                     "n_conv_per_stage_decoder": deepcopy(self.configuration["n_conv_per_stage_decoder"]),
                     "conv_bias": True,
                     "norm_op": instnorm.__module__ + '.' + instnorm.__name__,
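
Note on the resampling contract: the planner docstrings above require every resampling function to be callable as fn(data, current_spacing, new_spacing, **kwargs). The patch registers resample_torch_fornnunet for this. The following minimal sketch only illustrates the expected signature; the (c, x, y, z) array layout and the use of scipy.ndimage.zoom in place of the torch-based implementation are assumptions, not part of the patch.

import numpy as np
from scipy.ndimage import zoom


def zoom_resampling(data, current_spacing, new_spacing, is_seg: bool = False, **kwargs) -> np.ndarray:
    # data is assumed to be (c, x, y, z); a smaller target spacing means more voxels, hence current / new
    factors = [c / n for c, n in zip(current_spacing, new_spacing)]
    order = 0 if is_seg else 3  # order 0 keeps segmentation labels integral
    return np.stack([zoom(channel, factors, order=order) for channel in data])

Any callable with this signature, together with the kwargs dict a planner's determine_resampling returns for it, can be registered the same way the torchres planners register resample_torch_fornnunet.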
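
The default_preprocessor change is easiest to see in isolation: Pool.starmap_async stores a worker's exception inside the AsyncResult and only re-raises it when .get() is called, so polling ready() alone lets failures pass silently. A self-contained sketch of the pattern; work and the job loop are hypothetical stand-ins for run_case_save and the case loop.

import multiprocessing
import time


def work(i):
    if i == 3:
        raise RuntimeError('simulated failure in a background worker')
    return i * i


if __name__ == '__main__':
    with multiprocessing.get_context('spawn').Pool(2) as p:
        r = [p.starmap_async(work, ((i,),)) for i in range(5)]
        remaining = list(range(len(r)))
        while remaining:
            done = [i for i in remaining if r[i].ready()]
            for i in done:
                r[i].get()  # re-raises the worker's exception here, if any
            remaining = [i for i in remaining if i not in done]
            time.sleep(0.1)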
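
The plans_handler change reflects the constructor signatures in dynamic_network_architectures: PlainConvUNet takes an n_conv_per_stage argument, while the residual UNet classes take n_blocks_per_stage, so the encoder stage counts must be keyed under whichever name the target class accepts. Schematically (the class name and per-stage counts below are example values):

unet_class_name = 'ResidualEncoderUNet'  # derived from network_class_name in the real code
convs_or_blocks = 'n_conv_per_stage' if unet_class_name == 'PlainConvUNet' else 'n_blocks_per_stage'
arch_kwargs = {convs_or_blocks: [2, 3, 4, 6, 6, 6]}  # example per-stage counts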