diff --git a/nnunetv2/inference/predict_from_raw_data.py b/nnunetv2/inference/predict_from_raw_data.py
index 1fe69ea1b..d211b07ff 100644
--- a/nnunetv2/inference/predict_from_raw_data.py
+++ b/nnunetv2/inference/predict_from_raw_data.py
@@ -64,9 +64,13 @@ def __init__(self,
         self.device = device
         self.perform_everything_on_gpu = perform_everything_on_gpu
 
-    def initialize_from_trained_model_folder(self, model_training_output_dir: str,
-                                             use_folds: Union[Tuple[Union[int, str]], None],
-                                             checkpoint_name: str = 'checkpoint_final.pth'):
+    def initialize_from_trained_model_folder(
+        self,
+        model_training_output_dir: str,
+        use_folds: Union[Tuple[Union[int, str]], None],
+        checkpoint_name: str = 'checkpoint_final.pth',
+        disable_compilation: bool = False,
+    ):
         """
         This is used when making predictions with a trained model
         """
@@ -109,7 +113,7 @@ def initialize_from_trained_model_folder(self, model_training_output_dir: str,
         self.allowed_mirroring_axes = inference_allowed_mirroring_axes
         self.label_manager = plans_manager.get_label_manager(dataset_json)
         if ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) \
-                and not isinstance(self.network, OptimizedModule):
+                and not isinstance(self.network, OptimizedModule) and not disable_compilation:
             print('compiling network')
             self.network = torch.compile(self.network)
 
diff --git a/nnunetv2/model_sharing/entry_points.py b/nnunetv2/model_sharing/entry_points.py
index 1ab7c9351..29f05beae 100644
--- a/nnunetv2/model_sharing/entry_points.py
+++ b/nnunetv2/model_sharing/entry_points.py
@@ -1,28 +1,35 @@
+from pathlib import Path
+
 from nnunetv2.model_sharing.model_download import download_and_install_from_url
 from nnunetv2.model_sharing.model_export import export_pretrained_model
 from nnunetv2.model_sharing.model_import import install_model_from_zip_file
+from nnunetv2.model_sharing.onnx_export import export_onnx_model
 
 
 def print_license_warning():
-    print('')
-    print('######################################################')
-    print('!!!!!!!!!!!!!!!!!!!!!!!!WARNING!!!!!!!!!!!!!!!!!!!!!!!')
-    print('######################################################')
-    print("Using the pretrained model weights is subject to the license of the dataset they were trained on. Some "
-          "allow commercial use, others don't. It is your responsibility to make sure you use them appropriately! Use "
-          "nnUNet_print_pretrained_model_info(task_name) to see a summary of the dataset and where to find its license!")
-    print('######################################################')
-    print('')
+    print("")
+    print("######################################################")
+    print("!!!!!!!!!!!!!!!!!!!!!!!!WARNING!!!!!!!!!!!!!!!!!!!!!!!")
+    print("######################################################")
+    print(
+        "Using the pretrained model weights is subject to the license of the dataset they were trained on. Some "
+        "allow commercial use, others don't. It is your responsibility to make sure you use them appropriately! Use "
+        "nnUNet_print_pretrained_model_info(task_name) to see a summary of the dataset and where to find its license!"
+    )
+    print("######################################################")
+    print("")
 
 
 def download_by_url():
     import argparse
+
     parser = argparse.ArgumentParser(
         description="Use this to download pretrained models. This script is intended to download models via url only. "
-                    "CAREFUL: This script will overwrite "
-                    "existing models (if they share the same trainer class and plans as "
-                    "the pretrained model.")
-    parser.add_argument("url", type=str, help='URL of the pretrained model')
+        "CAREFUL: This script will overwrite "
+        "existing models (if they share the same trainer class and plans as "
+        "the pretrained model."
+    )
+    parser.add_argument("url", type=str, help="URL of the pretrained model")
     args = parser.parse_args()
     url = args.url
     download_and_install_from_url(url)
@@ -30,9 +37,11 @@ def download_by_url():
 
 def install_from_zip_entry_point():
     import argparse
+
     parser = argparse.ArgumentParser(
-        description="Use this to install a zip file containing a pretrained model.")
-    parser.add_argument("zip", type=str, help='zip file')
+        description="Use this to install a zip file containing a pretrained model."
+    )
+    parser.add_argument("zip", type=str, help="zip file")
     args = parser.parse_args()
     zip = args.zip
     install_model_from_zip_file(zip)
@@ -40,22 +49,157 @@ def install_from_zip_entry_point():
 
 def export_pretrained_model_entry():
     import argparse
+
+    parser = argparse.ArgumentParser(
+        description="Use this to export a trained model as a zip file."
+    )
+    parser.add_argument("-d", type=str, required=True, help="Dataset name or id")
+    parser.add_argument("-o", type=str, required=True, help="Output file name")
+    parser.add_argument(
+        "-c",
+        nargs="+",
+        type=str,
+        required=False,
+        default=("3d_lowres", "3d_fullres", "2d", "3d_cascade_fullres"),
+        help="List of configuration names",
+    )
+    parser.add_argument(
+        "-tr", required=False, type=str, default="nnUNetTrainer", help="Trainer class"
+    )
+    parser.add_argument(
+        "-p", required=False, type=str, default="nnUNetPlans", help="plans identifier"
+    )
+    parser.add_argument(
+        "-f",
+        required=False,
+        nargs="+",
+        type=str,
+        default=(0, 1, 2, 3, 4),
+        help="list of fold ids",
+    )
+    parser.add_argument(
+        "-chk",
+        required=False,
+        nargs="+",
+        type=str,
+        default=("checkpoint_final.pth",),
+        help="List of checkpoint names to export. Default: checkpoint_final.pth",
+    )
+    parser.add_argument(
+        "--not_strict",
+        action="store_true",
+        default=False,
+        required=False,
+        help="Set this to allow missing folds and/or configurations",
+    )
+    parser.add_argument(
+        "--exp_cv_preds",
+        action="store_true",
+        required=False,
+        help="Set this to export the cross-validation predictions as well",
+    )
+    args = parser.parse_args()
+
+    export_pretrained_model(
+        dataset_name_or_id=args.d,
+        output_file=args.o,
+        configurations=args.c,
+        trainer=args.tr,
+        plans_identifier=args.p,
+        folds=args.f,
+        strict=not args.not_strict,
+        save_checkpoints=args.chk,
+        export_crossval_predictions=args.exp_cv_preds,
+    )
+
+
+def export_pretrained_model_onnx_entry():
+    """CLI entry point: parse arguments, print a disclaimer, call export_onnx_model."""
+    import argparse
+
     parser = argparse.ArgumentParser(
-        description="Use this to export a trained model as a zip file.")
-    parser.add_argument('-d', type=str, required=True, help='Dataset name or id')
-    parser.add_argument('-o', type=str, required=True, help='Output file name')
-    parser.add_argument('-c', nargs='+', type=str, required=False,
-                        default=('3d_lowres', '3d_fullres', '2d', '3d_cascade_fullres'),
-                        help="List of configuration names")
-    parser.add_argument('-tr', required=False, type=str, default='nnUNetTrainer', help='Trainer class')
-    parser.add_argument('-p', required=False, type=str, default='nnUNetPlans', help='plans identifier')
-    parser.add_argument('-f', required=False, nargs='+', type=str, default=(0, 1, 2, 3, 4), help='list of fold ids')
-    parser.add_argument('-chk', required=False, nargs='+', type=str, default=('checkpoint_final.pth', ),
-                        help='Lis tof checkpoint names to export. Default: checkpoint_final.pth')
-    parser.add_argument('--not_strict', action='store_false', default=False, required=False, help='Set this to allow missing folds and/or configurations')
-    parser.add_argument('--exp_cv_preds', action='store_true', required=False, help='Set this to export the cross-validation predictions as well')
+        description="Use this to export a trained model to ONNX format. "
+        "You are responsible for creating the ONNX pipeline yourself."
+    )
+    parser.add_argument("-d", type=str, required=True, help="Dataset name or id")
+    parser.add_argument("-o", type=Path, required=True, help="Output directory")
+    parser.add_argument(
+        "-c",
+        nargs="+",
+        type=str,
+        required=False,
+        default=("3d_lowres", "3d_fullres", "2d", "3d_cascade_fullres"),
+        help="List of configuration names",
+    )
+    parser.add_argument(
+        "-tr", required=False, type=str, default="nnUNetTrainer", help="Trainer class"
+    )
+    parser.add_argument(
+        "-p", required=False, type=str, default="nnUNetPlans", help="plans identifier"
+    )
+    parser.add_argument(
+        "-f", required=False, nargs="+", type=str, default=None, help="list of fold ids"
+    )
+    parser.add_argument(
+        "-b",
+        required=False,
+        type=int,
+        default=0,
+        help="Batch size. Set to 0 for dynamic axes. Default: 0",
+    )
+    parser.add_argument(
+        "-chk",
+        required=False,
+        nargs="+",
+        type=str,
+        default=("checkpoint_final.pth",),
+        help="List of checkpoint names to export. Default: checkpoint_final.pth",
+    )
+    parser.add_argument(
+        "--not_strict",
+        action="store_true",
+        default=False,
+        required=False,
+        help="Set this to allow missing folds and/or configurations",
+    )
+    parser.add_argument(
+        "-v",
+        action="store_true",
+        default=False,
+        required=False,
+        help="Set this to get verbose output",
+    )
     args = parser.parse_args()
 
-    export_pretrained_model(dataset_name_or_id=args.d, output_file=args.o, configurations=args.c, trainer=args.tr,
-                            plans_identifier=args.p, folds=args.f, strict=not args.not_strict, save_checkpoints=args.chk,
-                            export_crossval_predictions=args.exp_cv_preds)
+    print("######################################################")
+    print("!!!!!!!!!!!!!!!!!!!!!!!!WARNING!!!!!!!!!!!!!!!!!!!!!!!")
+    print("######################################################")
+    print(
+        "Exported models are provided as-is, without any\n"
+        "guarantees, warranties and/or support from MIC-DKFZ,\n"
+        "any associated persons and/or other entities.\n"
+    )
+    print(
+        "You will bear sole responsibility for the proper\n"
+        "use of the exported models.\n"
+    )
+    print(
+        "You are responsible for creating and validating\n"
+        "the ONNX pipeline yourself. To this end we provide\n"
+        "the .onnx file, and a config.json containing any\n"
+        "details you might need."
+    )
+    print("######################################################\n")
+
+    export_onnx_model(
+        dataset_name_or_id=args.d,
+        output_dir=args.o,
+        configurations=args.c,
+        batch_size=args.b,
+        trainer=args.tr,
+        plans_identifier=args.p,
+        folds=args.f,
+        strict=not args.not_strict,
+        save_checkpoints=args.chk,
+        verbose=args.v,
+    )
diff --git a/nnunetv2/model_sharing/onnx_export.py b/nnunetv2/model_sharing/onnx_export.py
new file mode 100644
index 000000000..4d099e0a4
--- /dev/null
+++ b/nnunetv2/model_sharing/onnx_export.py
@@ -0,0 +1,249 @@
+import json
+from os.path import isdir, join
+from pathlib import Path
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import onnx
+import onnxruntime
+import torch
+from batchgenerators.utilities.file_and_folder_operations import load_json
+
+from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
+from nnunetv2.utilities.dataset_name_id_conversion import \
+    maybe_convert_to_dataset_name
+from nnunetv2.utilities.file_path_utilities import get_output_folder
+
+
+def _num_input_channels(network: torch.nn.Module, fallback: int) -> int:
+    """Return ``in_channels`` of the first convolution in ``network``.
+
+    ``fallback`` is returned when the network contains no convolution module.
+    """
+    for module in network.modules():
+        if isinstance(module, torch.nn.modules.conv._ConvNd):
+            return module.in_channels
+    return fallback
+
+
+def export_onnx_model(
+    dataset_name_or_id: Union[int, str],
+    output_dir: Path,
+    configurations: Tuple[str] = (
+        "2d",
+        "3d_lowres",
+        "3d_fullres",
+        "3d_cascade_fullres",
+    ),
+    batch_size: int = 0,
+    trainer: str = "nnUNetTrainer",
+    plans_identifier: str = "nnUNetPlans",
+    folds: Optional[Tuple[Union[int, str], ...]] = (0, 1, 2, 3, 4),
+    strict: bool = True,
+    save_checkpoints: Tuple[str, ...] = ("checkpoint_final.pth",),
+    output_names: Optional[Tuple[str, ...]] = None,
+    verbose: bool = False,
+) -> None:
+    """
+    Export trained checkpoints to ONNX and write a ``config.json`` next to them.
+
+    Files are written to ``output_dir/{configuration}/fold_{fold}/``. Each
+    exported model is checked with ``onnx.checker`` and its output on a random
+    input is compared against the torch output; mismatches only print a warning.
+
+    :param dataset_name_or_id: dataset name or id
+    :param output_dir: root directory of the export
+    :param configurations: configuration names to export
+    :param batch_size: fixed batch size of the exported graph; 0 exports a
+        dynamic batch axis
+    :param trainer: trainer class name
+    :param plans_identifier: plans identifier
+    :param folds: fold ids to export; None selects (0, 1, 2, 3, 4)
+    :param strict: raise if a configuration has no trained model folder,
+        otherwise skip it
+    :param save_checkpoints: checkpoint file names to export
+    :param output_names: ONNX file names, one per checkpoint; defaults to the
+        checkpoint name with an ``.onnx`` suffix
+    :param verbose: forwarded to ``torch.onnx.export``
+    :raises ValueError: negative batch_size, or output_names and
+        save_checkpoints differ in length
+    :raises RuntimeError: missing configuration (strict) or non-empty output
+        directory
+    """
+    if batch_size < 0:
+        raise ValueError("batch_size must be non-negative")
+
+    # Materialize the names: a generator would be exhausted after the first
+    # configuration, silently skipping every later one.
+    save_checkpoints = tuple(save_checkpoints)
+    if not output_names:
+        output_names = tuple(f"{Path(chk).stem}.onnx" for chk in save_checkpoints)
+    else:
+        output_names = tuple(output_names)
+    if len(output_names) != len(save_checkpoints):
+        raise ValueError(
+            "output_names and save_checkpoints must have the same length"
+        )
+
+    # The CLI passes folds=None when -f is omitted; fall back to the default.
+    folds = (0, 1, 2, 3, 4) if folds is None else tuple(folds)
+
+    use_dynamic_axes = batch_size == 0
+    output_dir = Path(output_dir)
+    # Directories already validated in this run: several checkpoints of one
+    # fold share a directory and must not trip the "not empty" check.
+    prepared_dirs = set()
+
+    dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)
+    for c in configurations:
+        print(f"Configuration {c}")
+        trainer_output_dir = get_output_folder(
+            dataset_name, trainer, plans_identifier, c
+        )
+
+        # Check for the folder before reading from it, otherwise the
+        # non-strict skip below can never be reached.
+        if not isdir(trainer_output_dir):
+            if strict:
+                raise RuntimeError(
+                    f"{dataset_name} is missing the trained model of configuration {c}"
+                )
+            print(f"Skipping configuration {c}, does not exist")
+            continue
+
+        dataset_json = load_json(join(trainer_output_dir, "dataset.json"))
+
+        # While we load in this file indirectly, we need the plans file to
+        # determine the foreground intensity properties.
+        plans = load_json(join(trainer_output_dir, "plans.json"))
+        foreground_intensity_properties = plans[
+            "foreground_intensity_properties_per_channel"
+        ]
+
+        predictor = nnUNetPredictor(
+            perform_everything_on_gpu=False,
+            device=torch.device("cpu"),
+        )
+
+        for checkpoint_name, output_name in zip(save_checkpoints, output_names):
+            predictor.initialize_from_trained_model_folder(
+                model_training_output_dir=trainer_output_dir,
+                use_folds=folds,
+                checkpoint_name=checkpoint_name,
+                disable_compilation=True,
+            )
+
+            list_of_parameters = predictor.list_of_parameters
+            network = predictor.network
+            config = predictor.configuration_manager
+            # Read the channel count off the network instead of assuming 1,
+            # so multi-channel datasets export as well.
+            num_input_channels = _num_input_channels(
+                network, len(dataset_json["channel_names"])
+            )
+
+            for fold, params in zip(folds, list_of_parameters):
+                network.load_state_dict(params)
+                network.eval()
+
+                curr_output_dir = output_dir / c / f"fold_{fold}"
+                if curr_output_dir not in prepared_dirs:
+                    curr_output_dir.mkdir(parents=True, exist_ok=True)
+                    if any(curr_output_dir.iterdir()):
+                        raise RuntimeError(
+                            f"Output directory {curr_output_dir} is not empty"
+                        )
+                    prepared_dirs.add(curr_output_dir)
+                # Plain str path for torch.onnx.export, onnx.load and onnxruntime.
+                onnx_path = str(curr_output_dir / output_name)
+
+                # batch_size == 0 means "dynamic": trace with one sample and
+                # mark axis 0 of input and output as variable.
+                traced_batch = 1 if use_dynamic_axes else batch_size
+                rand_input = torch.rand(
+                    (traced_batch, num_input_channels, *config.patch_size)
+                )
+                dynamic_axes = None
+                if use_dynamic_axes:
+                    dynamic_axes = {
+                        "input": {0: "batch_size"},
+                        "output": {0: "batch_size"},
+                    }
+
+                with torch.no_grad():
+                    torch_output = network(rand_input)
+
+                torch.onnx.export(
+                    network,
+                    rand_input,
+                    onnx_path,
+                    export_params=True,
+                    verbose=verbose,
+                    input_names=["input"],
+                    output_names=["output"],
+                    dynamic_axes=dynamic_axes,
+                )
+
+                onnx_model = onnx.load(onnx_path)
+                onnx.checker.check_model(onnx_model)
+
+                ort_session = onnxruntime.InferenceSession(
+                    onnx_path, providers=["CPUExecutionProvider"]
+                )
+                ort_inputs = {ort_session.get_inputs()[0].name: rand_input.numpy()}
+                ort_outs = ort_session.run(None, ort_inputs)
+
+                try:
+                    np.testing.assert_allclose(
+                        torch_output.detach().cpu().numpy(),
+                        ort_outs[0],
+                        rtol=1e-03,
+                        atol=1e-05,
+                        verbose=True,
+                    )
+                except AssertionError as e:
+                    print("WARN: Differences found between torch and onnx:\n")
+                    print(e)
+                    print(
+                        "\nExport will continue, but please verify that your pipeline matches the original."
+                    )
+
+                print(f"Exported {onnx_path}")
+
+                config_dict = {
+                    "configuration": c,
+                    "fold": fold,
+                    "model_parameters": {
+                        "batch_size": "dynamic" if use_dynamic_axes else batch_size,
+                        "patch_size": config.patch_size,
+                        "spacing": config.spacing,
+                        "normalization_schemes": config.normalization_schemes,
+                        # These are mostly interesting for certification
+                        # uses, but they are also useful for debugging.
+                        "UNet_class_name": config.UNet_class_name,
+                        "UNet_base_num_features": config.UNet_base_num_features,
+                        "unet_max_num_features": config.unet_max_num_features,
+                        "conv_kernel_sizes": config.conv_kernel_sizes,
+                        "pool_op_kernel_sizes": config.pool_op_kernel_sizes,
+                        "num_pool_per_axis": config.num_pool_per_axis,
+                    },
+                    "dataset_parameters": {
+                        "dataset_name": dataset_name,
+                        "num_channels": len(dataset_json["channel_names"]),
+                        "channels": {
+                            k: {
+                                "name": v,
+                                # For when normalization is not Z-Score
+                                "foreground_properties": foreground_intensity_properties[k],
+                            }
+                            for k, v in dataset_json["channel_names"].items()
+                        },
+                        "num_classes": len(dataset_json["labels"]),
+                        "class_names": {
+                            v: k for k, v in dataset_json["labels"].items()
+                        },
+                    },
+                }
+                # Build the dict first so a failure cannot leave a truncated file.
+                with open(curr_output_dir / "config.json", "w") as f:
+                    json.dump(config_dict, f, indent=4)
diff --git a/pyproject.toml b/pyproject.toml
index 91bc31563..e7d6e1dec 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -74,6 +74,7 @@ nnUNetv2_plot_overlay_pngs = "nnunetv2.utilities.overlay_plots:entry_point_gener
 nnUNetv2_download_pretrained_model_by_url = "nnunetv2.model_sharing.entry_points:download_by_url"
 nnUNetv2_install_pretrained_model_from_zip = "nnunetv2.model_sharing.entry_points:install_from_zip_entry_point"
 nnUNetv2_export_model_to_zip = "nnunetv2.model_sharing.entry_points:export_pretrained_model_entry"
+nnUNetv2_export_model_to_onnx = "nnunetv2.model_sharing.entry_points:export_pretrained_model_onnx_entry"
 nnUNetv2_move_plans_between_datasets = "nnunetv2.experiment_planning.plans_for_pretraining.move_plans_between_datasets:entry_point_move_plans_between_datasets"
 nnUNetv2_evaluate_folder = "nnunetv2.evaluation.evaluate_predictions:evaluate_folder_entry_point"
 nnUNetv2_evaluate_simple = "nnunetv2.evaluation.evaluate_predictions:evaluate_simple_entry_point"
@@ -85,6 +86,11 @@ dev = [
     "ruff",
     "pre-commit"
 ]
+onnx-export = [
+    "onnx",
+    "onnxscript",
+    "onnxruntime",
+]
 
 [tool.codespell]
 skip = '.git,*.pdf,*.svg'