Skip to content

Commit

Permalink
remove num_predictons variable
Browse files Browse the repository at this point in the history
  • Loading branch information
FabianIsensee committed Dec 5, 2023
1 parent 24f059a commit 15c3658
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions nnunetv2/inference/predict_from_raw_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,13 +549,12 @@ def _internal_maybe_mirror_and_predict(self, x: torch.Tensor) -> torch.Tensor:
# x should be 5d for 3d images and 4d for 2d. so the max value of mirror_axes cannot exceed len(x.shape) - 3
assert max(mirror_axes) <= x.ndim - 3, 'mirror_axes does not match the dimension of the input!'

num_predictons = 2 ** len(mirror_axes)
axes_combinations = [
c for i in range(len(mirror_axes)) for c in itertools.combinations([m + 2 for m in mirror_axes], i + 1)
]
for axes in axes_combinations:
prediction += torch.flip(self.network(torch.flip(x, (*axes,))), (*axes,))
prediction /= num_predictons
prediction /= (len(axes_combinations) + 1)
return prediction

def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \
Expand Down

0 comments on commit 15c3658

Please sign in to comment.