
Commit

merge
FabianIsensee committed Jun 17, 2024
2 parents 24b5e48 + 8b6adc2 commit 23afd25
Showing 8 changed files with 206 additions and 70 deletions.
10 changes: 8 additions & 2 deletions nnunetv2/evaluation/find_best_configuration.py
@@ -3,8 +3,9 @@
from copy import deepcopy
from typing import Union, List, Tuple

-from batchgenerators.utilities.file_and_folder_operations import load_json, join, isdir, save_json
+from batchgenerators.utilities.file_and_folder_operations import (
+    load_json, join, isdir, listdir, save_json
+)
from nnunetv2.configuration import default_num_processes
from nnunetv2.ensembling.ensemble import ensemble_crossvalidations
from nnunetv2.evaluation.accumulate_cv_results import accumulate_cv_results
@@ -320,6 +321,11 @@ def accumulate_crossval_results_entry_point():
merged_output_folder = join(trained_model_folder, f'crossval_results_folds_{folds_tuple_to_string(args.f)}')
else:
merged_output_folder = args.o
+    if isdir(merged_output_folder) and len(listdir(merged_output_folder)) > 0:
+        raise FileExistsError(
+            f"Output folder {merged_output_folder} exists and is not empty. "
+            f"To avoid data loss, nnUNet requires an empty output folder."
+        )

accumulate_cv_results(trained_model_folder, merged_output_folder, args.f)
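The new check refuses to reuse a non-empty output folder rather than silently mixing old and new cross-validation results. A standalone sketch of the same guard, using stdlib os.path/os.listdir in place of batchgenerators' isdir/listdir (the folder path is hypothetical):

import os

def ensure_empty_output_folder(folder: str) -> None:
    # Fail fast if the folder exists and already contains files, so that
    # previous results are never overwritten or mixed with new ones.
    if os.path.isdir(folder) and len(os.listdir(folder)) > 0:
        raise FileExistsError(
            f"Output folder {folder} exists and is not empty. "
            f"To avoid data loss, an empty output folder is required."
        )

ensure_empty_output_folder('/tmp/crossval_results_folds_0_1_2_3_4')  # hypothetical path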

Empty file.
@@ -0,0 +1,181 @@
from typing import Union, List, Tuple

from nnunetv2.configuration import ANISO_THRESHOLD
from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner
from nnunetv2.experiment_planning.experiment_planners.residual_unets.residual_encoder_unet_planners import \
nnUNetPlannerResEncL
from nnunetv2.preprocessing.resampling.resample_torch import resample_torch_fornnunet


class nnUNetPlannerResEncL_torchres(nnUNetPlannerResEncL):
def __init__(self, dataset_name_or_id: Union[str, int],
gpu_memory_target_in_gb: float = 24,
preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLPlans_torchres',
overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
suppress_transpose: bool = False):
super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
overwrite_target_spacing, suppress_transpose)

def generate_data_identifier(self, configuration_name: str) -> str:
"""
configurations are unique within each plans file but different plans files can have configurations with the
same name. In order to distinguish the associated data we need a data identifier that reflects not just the
config but also the plans it originates from
"""
return self.plans_identifier + '_' + configuration_name

def determine_resampling(self, *args, **kwargs):
"""
returns what functions to use for resampling data and seg, respectively. Also returns kwargs
resampling function must be callable(data, current_spacing, new_spacing, **kwargs)
determine_resampling is called within get_plans_for_configuration to allow for different functions for each
configuration
"""
resampling_data = resample_torch_fornnunet
resampling_data_kwargs = {
"is_seg": False,
'force_separate_z': False,
'memefficient_seg_resampling': False
}
resampling_seg = resample_torch_fornnunet
resampling_seg_kwargs = {
"is_seg": True,
'force_separate_z': False,
'memefficient_seg_resampling': False
}
return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs

def determine_segmentation_softmax_export_fn(self, *args, **kwargs):
"""
function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be
used as the target. current_spacing and new_spacing are merely there in case we want to use them somehow
determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different
functions for each configuration
"""
resampling_fn = resample_torch_fornnunet
resampling_fn_kwargs = {
"is_seg": False,
'force_separate_z': False,
'memefficient_seg_resampling': False
}
return resampling_fn, resampling_fn_kwargs


class nnUNetPlannerResEncL_torchres_sepz(nnUNetPlannerResEncL):
def __init__(self, dataset_name_or_id: Union[str, int],
gpu_memory_target_in_gb: float = 24,
preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLPlans_torchres_sepz',
overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
suppress_transpose: bool = False):
super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
overwrite_target_spacing, suppress_transpose)

def generate_data_identifier(self, configuration_name: str) -> str:
"""
configurations are unique within each plans file but different plans files can have configurations with the
same name. In order to distinguish the associated data we need a data identifier that reflects not just the
config but also the plans it originates from
"""
return self.plans_identifier + '_' + configuration_name

def determine_resampling(self, *args, **kwargs):
"""
returns what functions to use for resampling data and seg, respectively. Also returns kwargs
resampling function must be callable(data, current_spacing, new_spacing, **kwargs)
determine_resampling is called within get_plans_for_configuration to allow for different functions for each
configuration
"""
resampling_data = resample_torch_fornnunet
resampling_data_kwargs = {
"is_seg": False,
'force_separate_z': None,
'memefficient_seg_resampling': False,
'separate_z_anisotropy_threshold': ANISO_THRESHOLD
}
resampling_seg = resample_torch_fornnunet
resampling_seg_kwargs = {
"is_seg": True,
'force_separate_z': None,
'memefficient_seg_resampling': False,
'separate_z_anisotropy_threshold': ANISO_THRESHOLD
}
return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs

def determine_segmentation_softmax_export_fn(self, *args, **kwargs):
"""
function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be
used as the target. current_spacing and new_spacing are merely there in case we want to use them somehow
determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different
functions for each configuration
"""
resampling_fn = resample_torch_fornnunet
resampling_fn_kwargs = {
"is_seg": False,
'force_separate_z': None,
'memefficient_seg_resampling': False,
'separate_z_anisotropy_threshold': ANISO_THRESHOLD
}
return resampling_fn, resampling_fn_kwargs


class nnUNetPlanner_torchres(ExperimentPlanner):
def __init__(self, dataset_name_or_id: Union[str, int],
gpu_memory_target_in_gb: float = 8,
preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetPlans_torchres',
overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
suppress_transpose: bool = False):
super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
overwrite_target_spacing, suppress_transpose)

def generate_data_identifier(self, configuration_name: str) -> str:
"""
configurations are unique within each plans file but different plans files can have configurations with the
same name. In order to distinguish the associated data we need a data identifier that reflects not just the
config but also the plans it originates from
"""
return self.plans_identifier + '_' + configuration_name

def determine_resampling(self, *args, **kwargs):
"""
returns what functions to use for resampling data and seg, respectively. Also returns kwargs
resampling function must be callable(data, current_spacing, new_spacing, **kwargs)
determine_resampling is called within get_plans_for_configuration to allow for different functions for each
configuration
"""
resampling_data = resample_torch_fornnunet
resampling_data_kwargs = {
"is_seg": False,
'force_separate_z': False,
'memefficient_seg_resampling': False
}
resampling_seg = resample_torch_fornnunet
resampling_seg_kwargs = {
"is_seg": True,
'force_separate_z': False,
'memefficient_seg_resampling': False
}
return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs

def determine_segmentation_softmax_export_fn(self, *args, **kwargs):
"""
function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be
used as the target. current_spacing and new_spacing are merely there in case we want to use them somehow
determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different
functions for each configuration
"""
resampling_fn = resample_torch_fornnunet
resampling_fn_kwargs = {
"is_seg": False,
'force_separate_z': False,
'memefficient_seg_resampling': False
}
return resampling_fn, resampling_fn_kwargs
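All three planners point both resampling and softmax export at resample_torch_fornnunet and only vary the kwargs (separate-z handling on or off). The docstrings above fix the calling convention; here is a minimal sketch of that convention with a stub resampler (the stub and all shapes/spacings are illustrative, not the real torch implementation):

import numpy as np

def stub_resampler(data, current_spacing, new_spacing, is_seg=False,
                   force_separate_z=False, memefficient_seg_resampling=False):
    # Stand-in for resample_torch_fornnunet, matching the promised
    # callable(data, current_spacing, new_spacing, **kwargs) signature.
    zoom = np.array(current_spacing, dtype=float) / np.array(new_spacing, dtype=float)
    new_shape = np.round(np.array(data.shape[1:]) * zoom).astype(int)
    return np.zeros((data.shape[0], *new_shape), dtype=data.dtype)

resampling_data = stub_resampler
resampling_data_kwargs = {"is_seg": False, "force_separate_z": False,
                          "memefficient_seg_resampling": False}

data = np.zeros((1, 32, 64, 64), dtype=np.float32)  # (c, x, y, z)
out = resampling_data(data, (3.0, 1.0, 1.0), (1.5, 1.0, 1.0), **resampling_data_kwargs)
print(out.shape)  # (1, 64, 64, 64)

generate_data_identifier then keeps the preprocessed data of the different plans apart by prefixing the configuration name, e.g. 'nnUNetResEncUNetLPlans_torchres' + '_' + '3d_fullres'.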
@@ -294,63 +294,6 @@ def __init__(self, dataset_name_or_id: Union[str, int],
self.max_dataset_covered = 1


class nnUNetPlannerResEncL_torchres(nnUNetPlannerResEncL):
def __init__(self, dataset_name_or_id: Union[str, int],
gpu_memory_target_in_gb: float = 24,
preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLPlans_torchres',
overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
suppress_transpose: bool = False):
super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
overwrite_target_spacing, suppress_transpose)

def generate_data_identifier(self, configuration_name: str) -> str:
"""
configurations are unique within each plans file but different plans file can have configurations with the
same name. In order to distinguish the associated data we need a data identifier that reflects not just the
config but also the plans it originates from
"""
return self.plans_identifier + '_' + configuration_name

def determine_resampling(self, *args, **kwargs):
"""
returns what functions to use for resampling data and seg, respectively. Also returns kwargs
resampling function must be callable(data, current_spacing, new_spacing, **kwargs)
determine_resampling is called within get_plans_for_configuration to allow for different functions for each
configuration
"""
resampling_data = resample_torch_fornnunet
resampling_data_kwargs = {
"is_seg": False,
'force_separate_z': False,
'memefficient_seg_resampling': False
}
resampling_seg = resample_torch_fornnunet
resampling_seg_kwargs = {
"is_seg": True,
'force_separate_z': False,
'memefficient_seg_resampling': False
}
return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs

def determine_segmentation_softmax_export_fn(self, *args, **kwargs):
"""
function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be
used as target. current_spacing and new_spacing are merely there in case we want to use it somehow
determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different
functions for each configuration
"""
resampling_fn = resample_torch_fornnunet
resampling_fn_kwargs = {
"is_seg": False,
'force_separate_z': False,
'memefficient_seg_resampling': False
}
return resampling_fn, resampling_fn_kwargs


if __name__ == '__main__':
# we know both of these networks run with batch size 2 and 12 on ~8-10GB, respectively
net = ResidualEncoderUNet(input_channels=1, n_stages=6, features_per_stage=(32, 64, 128, 256, 320, 320),
4 changes: 0 additions & 4 deletions nnunetv2/imageio/nibabel_reader_writer.py
@@ -31,8 +31,6 @@ class NibabelIO(BaseReaderWriter):
 supported_file_endings = [
     '.nii',
     '.nii.gz',
-    '.nrrd',
-    '.mha'
 ]

def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
@@ -110,8 +108,6 @@ class NibabelIOWithReorient(BaseReaderWriter):
 supported_file_endings = [
     '.nii',
     '.nii.gz',
-    '.nrrd',
-    '.mha'
 ]

def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
12 changes: 8 additions & 4 deletions nnunetv2/preprocessing/preprocessors/default_preprocessor.py
@@ -230,15 +230,17 @@ def run(self, dataset_name_or_id: Union[int, str], configuration_name: str, plan
# multiprocessing magic.
r = []
with multiprocessing.get_context("spawn").Pool(num_processes) as p:
-        remaining = list(range(len(dataset)))
-        # p is pretty nifti. If we kill workers they just respawn but don't do any work.
-        # So we need to store the original pool of workers.
-        workers = [j for j in p._pool]

for k in dataset.keys():
r.append(p.starmap_async(self.run_case_save,
((join(output_directory, k), dataset[k]['images'], dataset[k]['label'],
plans_manager, configuration_manager,
dataset_json),)))
+        remaining = list(range(len(dataset)))
+        # p is pretty nifti. If we kill workers they just respawn but don't do any work.
+        # So we need to store the original pool of workers.
+        workers = [j for j in p._pool]

with tqdm(desc=None, total=len(dataset), disable=self.verbose) as pbar:
while len(remaining) > 0:
all_alive = all([j.is_alive() for j in workers])
@@ -251,6 +253,8 @@ def run(self, dataset_name_or_id: Union[int, str], configuration_name: str, plan
'an error message, out of RAM is likely the problem. In that case '
'reducing the number of workers might help')
done = [i for i in remaining if r[i].ready()]
+            # get done so that errors can be raised
+            _ = [r[i].get() for i in done]
for _ in done:
r[_].get() # allows triggering errors
pbar.update()
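Moving the remaining/workers setup after the dispatch loop and adding the eager .get() makes worker failures surface in the parent process instead of stalling the progress bar. A condensed sketch of the polling pattern (a toy work function stands in for run_case_save; the sketch deliberately fails on case 3 to show the exception propagating):

import multiprocessing
import time

def work(i):
    # Toy stand-in for run_case_save; case 3 simulates a crashing worker task.
    if i == 3:
        raise RuntimeError(f"case {i} failed")
    return i

if __name__ == '__main__':
    with multiprocessing.get_context("spawn").Pool(2) as p:
        r = [p.starmap_async(work, ((i,),)) for i in range(5)]
        remaining = list(range(5))
        workers = [j for j in p._pool]  # remember the original workers to detect silent deaths
        while remaining:
            if not all(j.is_alive() for j in workers):
                raise RuntimeError('a background worker died unexpectedly (out of RAM is a common cause)')
            done = [i for i in remaining if r[i].ready()]
            _ = [r[i].get() for i in done]  # .get() re-raises worker exceptions in the parent
            remaining = [i for i in remaining if i not in done]
            time.sleep(0.1)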
8 changes: 6 additions & 2 deletions nnunetv2/preprocessing/resampling/default_resampling.py
@@ -31,8 +31,12 @@ def compute_new_shape(old_shape: Union[Tuple[int, ...], List[int], np.ndarray],
return new_shape


-def determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing,
-                                separate_z_anisotropy_threshold: float = ANISO_THRESHOLD) -> Tuple[bool, int]:
+def determine_do_sep_z_and_axis(
+        force_separate_z: bool,
+        current_spacing,
+        new_spacing,
+        separate_z_anisotropy_threshold: float = ANISO_THRESHOLD) -> Tuple[bool, Union[int, None]]:
if force_separate_z is not None:
do_separate_z = force_separate_z
if force_separate_z:
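Only the head of determine_do_sep_z_and_axis is visible in this hunk; the new return annotation makes explicit that the axis is None when no separate-z resampling is used. A simplified sketch of nnU-Net's usual decision rule (an assumption-laden reduction, not the real function body, which handles more edge cases):

from typing import Tuple, Union
import numpy as np

ANISO_THRESHOLD = 3  # nnU-Net's default anisotropy threshold

def do_sep_z_and_axis_sketch(force_separate_z: Union[bool, None],
                             current_spacing,
                             new_spacing,
                             threshold: float = ANISO_THRESHOLD) -> Tuple[bool, Union[int, None]]:
    if force_separate_z is not None:
        do_separate_z = force_separate_z  # caller overrides the heuristic
    else:
        # heuristic: resample z separately when one axis is much coarser
        # than the others (e.g. thick-slice CT)
        do_separate_z = (max(current_spacing) / min(current_spacing)) > threshold
    axis = int(np.argmax(current_spacing)) if do_separate_z else None
    return do_separate_z, axis

print(do_sep_z_and_axis_sketch(None, (5.0, 0.8, 0.8), (1.0, 1.0, 1.0)))  # (True, 0)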
4 changes: 3 additions & 1 deletion nnunetv2/utilities/plans_handling/plans_handler.py
@@ -54,6 +54,8 @@ def __init__(self, configuration_dict: dict):
conv_op = convert_dim_to_conv_op(dim)
instnorm = get_matching_instancenorm(dimension=dim)

+        convs_or_blocks = "n_conv_per_stage" if unet_class_name == "PlainConvUNet" else "n_blocks_per_stage"

arch_dict = {
'network_class_name': network_class_name,
'arch_kwargs': {
@@ -64,7 +66,7 @@ def __init__(self, configuration_dict: dict):
"conv_op": conv_op.__module__ + '.' + conv_op.__name__,
"kernel_sizes": deepcopy(self.configuration["conv_kernel_sizes"]),
"strides": deepcopy(self.configuration["pool_op_kernel_sizes"]),
"n_conv_per_stage": deepcopy(self.configuration["n_conv_per_stage_encoder"]),
convs_or_blocks: deepcopy(self.configuration["n_conv_per_stage_encoder"]),
"n_conv_per_stage_decoder": deepcopy(self.configuration["n_conv_per_stage_decoder"]),
"conv_bias": True,
"norm_op": instnorm.__module__ + '.' + instnorm.__name__,
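The fix makes the encoder stage counts land under the key the network class actually expects: PlainConvUNet reads n_conv_per_stage, while the residual encoder UNets read n_blocks_per_stage. A minimal sketch of just that key selection (the stage counts are invented values):

def encoder_stage_key(unet_class_name: str) -> str:
    # PlainConvUNet counts convolutions per stage; residual UNets count blocks.
    return ("n_conv_per_stage" if unet_class_name == "PlainConvUNet"
            else "n_blocks_per_stage")

arch_kwargs = {encoder_stage_key("ResidualEncoderUNet"): [1, 3, 4, 6, 6, 6]}
print(arch_kwargs)  # {'n_blocks_per_stage': [1, 3, 4, 6, 6, 6]}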
