Skip to content

Commit

Permalink
Fix vista3d transpose bug (#8059)
Browse files Browse the repository at this point in the history
Fixes # .

### Description

A few sentences describing the changes proposed in this pull request.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: heyufan1995 <[email protected]>
Signed-off-by: YunLiu <[email protected]>
Signed-off-by: Yiheng Wang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <[email protected]>
Co-authored-by: Yiheng Wang <[email protected]>
Co-authored-by: Yiheng Wang <[email protected]>
  • Loading branch information
5 people authored Sep 2, 2024
1 parent 7219ee7 commit 6a0e1b0
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
2 changes: 1 addition & 1 deletion monai/apps/vista3d/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def point_based_window_inferer(
point_labels=point_labels,
class_vector=class_vector,
prompt_class=prompt_class,
patch_coords=unravel_slice,
patch_coords=[unravel_slice],
prev_mask=prev_mask,
**kwargs,
)
Expand Down
16 changes: 10 additions & 6 deletions monai/networks/nets/vista3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def set_auto_grad(self, auto_freeze: bool = False, point_freeze: bool = False):
def forward(
self,
input_images: torch.Tensor,
patch_coords: Sequence[slice] | None = None,
patch_coords: list[Sequence[slice]] | None = None,
point_coords: torch.Tensor | None = None,
point_labels: torch.Tensor | None = None,
class_vector: torch.Tensor | None = None,
Expand Down Expand Up @@ -364,8 +364,12 @@ def forward(
the points are for zero-shot or supported class. When class_vector and point_coords are both
provided, prompt_class is the same as class_vector. For prompt_class[b] > 512, point_coords[b]
will be considered novel class.
patch_coords: a sequence of the python slice objects representing the patch coordinates during sliding window inference.
This value is passed from sliding_window_inferer. This is an indicator for training phase or validation phase.
patch_coords: a list of sequence of the python slice objects representing the patch coordinates during sliding window
inference. This value is passed from sliding_window_inferer.
This is an indicator for training phase or validation phase.
Notice for sliding window batch size > 1 (only supported by automatic segmentation), patch_coords will inlcude
coordinates of multiple patches. If point prompts are included, the batch size can only be one and all the
functions using patch_coords will by default use patch_coords[0].
labels: [1, 1, H, W, D], the groundtruth label tensor, only used for point-only evaluation
label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID,
this label_set should be global mapped index. If labels are not mapped to global index, e.g. in zero-shot
Expand Down Expand Up @@ -395,14 +399,14 @@ def forward(
if val_point_sampler is None:
# TODO: think about how to refactor this part.
val_point_sampler = self.sample_points_patch_val
point_coords, point_labels, prompt_class = val_point_sampler(labels, patch_coords, label_set)
point_coords, point_labels, prompt_class = val_point_sampler(labels, patch_coords[0], label_set)
if prompt_class[0].item() == 0: # type: ignore
point_labels[0] = -1 # type: ignore
labels, prev_mask = None, None
elif point_coords is not None:
# If not performing patch-based point only validation, use user provided click points for inference.
# the point clicks is in original image space, convert it to current patch-coordinate space.
point_coords, point_labels = self.update_point_to_patch(patch_coords, point_coords, point_labels) # type: ignore
point_coords, point_labels = self.update_point_to_patch(patch_coords[0], point_coords, point_labels) # type: ignore

if point_coords is not None and point_labels is not None:
# remove points that used for padding purposes (point_label = -1)
Expand Down Expand Up @@ -455,7 +459,7 @@ def forward(
logits[mapping_index] = self.point_head(out, point_coords, point_labels, class_vector=prompt_class)
if prev_mask is not None and patch_coords is not None:
logits = self.connected_components_combine(
prev_mask[patch_coords].transpose(1, 0).to(logits.device),
prev_mask[patch_coords[0]].transpose(1, 0).to(logits.device),
logits[mapping_index],
point_coords, # type: ignore
point_labels, # type: ignore
Expand Down

0 comments on commit 6a0e1b0

Please sign in to comment.