From 947eafbb9adb5eb06b9171330b4688e006e6f301 Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Tue, 9 Jan 2024 15:44:35 +0100 Subject: [PATCH] if a splits_final.json exists in the raw dataset folder it will be copied to the preprocessed folder as part of the experiment planning --- .../default_experiment_planner.py | 23 ++++++++++++++++++- .../plan_and_preprocess_api.py | 9 ++++---- .../training/nnUNetTrainer/nnUNetTrainer.py | 4 ++-- 3 files changed, 28 insertions(+), 8 deletions(-) diff --git a/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py b/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py index 2b1c41247..ccb4a251e 100644 --- a/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py @@ -1,4 +1,3 @@ -import os.path import shutil from copy import deepcopy from functools import lru_cache @@ -79,6 +78,10 @@ def __init__(self, dataset_name_or_id: Union[str, int], self.plans = None + if isfile(join(self.raw_dataset_folder, 'splits_final.json')): + _maybe_copy_splits_file(join(self.raw_dataset_folder, 'splits_final.json'), + join(preprocessed_folder, 'splits_final.json')) + def determine_reader_writer(self): example_image = self.dataset[self.dataset.keys().__iter__().__next__()]['images'][0] return determine_reader_writer_from_dataset_json(self.dataset_json, example_image) @@ -530,5 +533,23 @@ def load_plans(self, fname: str): self.plans = load_json(fname) +def _maybe_copy_splits_file(splits_file: str, target_fname: str): + if not isfile(target_fname): + shutil.copy(splits_file, target_fname) + else: + # split already exists, do not copy, but check that the splits match. + # This code allows target_fname to contain more splits than splits_file. This is OK. + splits_source = load_json(splits_file) + splits_target = load_json(target_fname) + # all folds in the source file must match the target file + for i in range(len(splits_source)): + train_source = set(splits_source[i]['train']) + train_target = set(splits_target[i]['train']) + assert train_target == train_source + val_source = set(splits_source[i]['val']) + val_target = set(splits_target[i]['val']) + assert val_source == val_target + + if __name__ == '__main__': ExperimentPlanner(2, 8).plan_experiment() diff --git a/nnunetv2/experiment_planning/plan_and_preprocess_api.py b/nnunetv2/experiment_planning/plan_and_preprocess_api.py index eb94840d7..8c74f7c61 100644 --- a/nnunetv2/experiment_planning/plan_and_preprocess_api.py +++ b/nnunetv2/experiment_planning/plan_and_preprocess_api.py @@ -1,17 +1,16 @@ -import shutil from typing import List, Type, Optional, Tuple, Union -import nnunetv2 -from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p, subfiles, load_json +from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p, load_json +import nnunetv2 +from nnunetv2.configuration import default_num_processes from nnunetv2.experiment_planning.dataset_fingerprint.fingerprint_extractor import DatasetFingerprintExtractor from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner from nnunetv2.experiment_planning.verify_dataset_integrity import verify_dataset_integrity from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed -from nnunetv2.utilities.dataset_name_id_conversion import convert_id_to_dataset_name, maybe_convert_to_dataset_name +from nnunetv2.utilities.dataset_name_id_conversion import convert_id_to_dataset_name from nnunetv2.utilities.find_class_by_name import recursive_find_python_class from nnunetv2.utilities.plans_handling.plans_handler import PlansManager -from nnunetv2.configuration import default_num_processes from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets diff --git a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py index d355fd5ea..318be58e9 100644 --- a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py +++ b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py @@ -520,9 +520,9 @@ def plot_network_architecture(self): def do_split(self): """ The default split is a 5 fold CV on all available training cases. nnU-Net will create a split (it is seeded, - so always the same) and save it as splits_final.pkl file in the preprocessed data directory. + so always the same) and save it as splits_final.json file in the preprocessed data directory. Sometimes you may want to create your own split for various reasons. For this you will need to create your own - splits_final.pkl file. If this file is present, nnU-Net is going to use it and whatever splits are defined in + splits_final.json file. If this file is present, nnU-Net is going to use it and whatever splits are defined in it. You can create as many splits in this file as you want. Note that if you define only 4 splits (fold 0-3) and then set fold=4 when training (that would be the fifth split), nnU-Net will print a warning and proceed to use a random 80:20 data split.