Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fast inference for large files with several classes #2540 #2545

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 33 additions & 16 deletions nnunetv2/inference/export_prediction.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import os
from copy import deepcopy
import time
from typing import Union, List

import numpy as np
import torch
from acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice
from batchgenerators.utilities.file_and_folder_operations import load_json, isfile, save_pickle
from batchgenerators.utilities.file_and_folder_operations import load_json, save_pickle

from nnunetv2.configuration import default_num_processes
from nnunetv2.utilities.label_handling.label_handling import LabelManager
Expand All @@ -21,30 +20,45 @@ def convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits
num_threads_torch: int = default_num_processes):
old_threads = torch.get_num_threads()
torch.set_num_threads(num_threads_torch)

# resample to original shape
spacing_transposed = [properties_dict['spacing'][i] for i in plans_manager.transpose_forward]
current_spacing = configuration_manager.spacing if \
len(configuration_manager.spacing) == \
len(properties_dict['shape_after_cropping_and_before_resampling']) else \
[spacing_transposed[0], *configuration_manager.spacing]
predicted_logits = configuration_manager.resampling_fn_probabilities(predicted_logits,
properties_dict['shape_after_cropping_and_before_resampling'],
current_spacing,
[properties_dict['spacing'][i] for i in plans_manager.transpose_forward])
# return value of resampling_fn_probabilities can be ndarray or Tensor but that does not matter because
# apply_inference_nonlin will convert to torch
predicted_probabilities = label_manager.apply_inference_nonlin(predicted_logits)
del predicted_logits
segmentation = label_manager.convert_probabilities_to_segmentation(predicted_probabilities)

if return_probabilities:
predicted_logits = configuration_manager.resampling_fn_probabilities(predicted_logits,
properties_dict[
'shape_after_cropping_and_before_resampling'],
current_spacing,
[properties_dict['spacing'][i] for i in
plans_manager.transpose_forward])
# return value of resampling_fn_probabilities can be ndarray or Tensor but that does not matter because
# apply_inference_nonlin will convert to torch
predicted_probabilities = label_manager.apply_inference_nonlin(predicted_logits)
del predicted_logits
segmentation = label_manager.convert_probabilities_to_segmentation(predicted_probabilities)
else:
predicted_probabilities = label_manager.apply_inference_nonlin(predicted_logits)
del predicted_logits
segmentation = label_manager.convert_probabilities_to_segmentation(predicted_probabilities)
segmentation = configuration_manager.resampling_fn_probabilities(segmentation.unsqueeze(0),
properties_dict[
'shape_after_cropping_and_before_resampling'],
current_spacing,
[properties_dict['spacing'][i] for i in
plans_manager.transpose_forward],
order=0
)
# segmentation may be torch.Tensor but we continue with numpy
if isinstance(segmentation, torch.Tensor):
segmentation = segmentation.cpu().numpy()

# put segmentation in bbox (revert cropping)
segmentation_reverted_cropping = np.zeros(properties_dict['shape_before_cropping'],
dtype=np.uint8 if len(label_manager.foreground_labels) < 255 else np.uint16)
dtype=np.uint8 if len(
label_manager.foreground_labels) < 255 else np.uint16)
slicer = bounding_box_to_slice(properties_dict['bbox_used_for_cropping'])
segmentation_reverted_cropping[slicer] = segmentation
del segmentation
Expand Down Expand Up @@ -81,7 +95,8 @@ def export_prediction_from_logits(predicted_array_or_file: Union[np.ndarray, tor
# elif predicted_array_or_file.endswith('.npz'):
# predicted_array_or_file = np.load(predicted_array_or_file)['softmax']
# os.remove(tmp)

print("[INFO] Start working on export_prediction_from_logits")
tic = time.time()
if isinstance(dataset_json_dict_or_file, str):
dataset_json_dict_or_file = load_json(dataset_json_dict_or_file)

Expand All @@ -105,6 +120,7 @@ def export_prediction_from_logits(predicted_array_or_file: Union[np.ndarray, tor
rw = plans_manager.image_reader_writer_class()
rw.write_seg(segmentation_final, output_file_truncated + dataset_json_dict_or_file['file_ending'],
properties_dict)
print(f"[INFO] Elapsed time for export_prediction_from_logits: {time.time() - tic}")


def resample_and_save(predicted: Union[torch.Tensor, np.ndarray], target_shape: List[int], output_file: str,
Expand All @@ -130,7 +146,8 @@ def resample_and_save(predicted: Union[torch.Tensor, np.ndarray], target_shape:
len(configuration_manager.spacing) == len(properties_dict['shape_after_cropping_and_before_resampling']) else \
[spacing_transposed[0], *configuration_manager.spacing]
target_spacing = configuration_manager.spacing if len(configuration_manager.spacing) == \
len(properties_dict['shape_after_cropping_and_before_resampling']) else \
len(properties_dict[
'shape_after_cropping_and_before_resampling']) else \
[spacing_transposed[0], *configuration_manager.spacing]
predicted_array_or_file = configuration_manager.resampling_fn_probabilities(predicted,
target_shape,
Expand Down
2 changes: 1 addition & 1 deletion nnunetv2/inference/predict_from_raw_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(self,
self.perform_everything_on_device = perform_everything_on_device

def initialize_from_trained_model_folder(self, model_training_output_dir: str,
use_folds: Union[Tuple[Union[int, str]], None],
use_folds: Union[Tuple[Union[int, str], ...], None],
checkpoint_name: str = 'checkpoint_final.pth'):
"""
This is used when making predictions with a trained model
Expand Down