diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py
index 33f7b8a53c..fdf4997e58 100644
--- a/monai/data/meta_tensor.py
+++ b/monai/data/meta_tensor.py
@@ -202,7 +202,7 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
             :py:func:`MetaTensor._copy_meta`).
         """
         out = []
-        metas = None
+        metas = None  # optional output metadicts for each return value in `rets`
         is_batch = any(x.is_batch for x in MetaObj.flatten_meta_objs(args, kwargs.values()) if hasattr(x, "is_batch"))
         for idx, ret in enumerate(rets):
             # if not `MetaTensor`, nothing to do.
@@ -219,55 +219,61 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
                 # the following is not implemented but the network arch may run into this case:
                 # if func == torch.cat and any(m.is_batch if hasattr(m, "is_batch") else False for m in meta_args):
                 #     raise NotImplementedError("torch.cat is not implemented for batch of MetaTensors.")
-
-                # If we have a batch of data, then we need to be careful if a slice of
-                # the data is returned. Depending on how the data are indexed, we return
-                # some or all of the metadata, and the return object may or may not be a
-                # batch of data (e.g., `batch[:,-1]` versus `batch[0]`).
                 if is_batch:
-                    # if indexing e.g., `batch[0]`
-                    if func == torch.Tensor.__getitem__:
-                        batch_idx = args[1]
-                        if isinstance(batch_idx, Sequence):
-                            batch_idx = batch_idx[0]
-                        # if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
-                        # first element will be `slice(None, None, None)` and `Ellipsis`,
-                        # respectively. Don't need to do anything with the metadata.
-                        if batch_idx not in (slice(None, None, None), Ellipsis, None) and idx == 0:
-                            ret_meta = decollate_batch(args[0], detach=False)[batch_idx]
-                            if isinstance(ret_meta, list) and ret_meta:  # e.g. batch[0:2], re-collate
-                                try:
-                                    ret_meta = list_data_collate(ret_meta)
-                                except (TypeError, ValueError, RuntimeError, IndexError) as e:
-                                    raise ValueError(
-                                        "Inconsistent batched metadata dicts when slicing a batch of MetaTensors, "
-                                        "please convert it into a torch Tensor using `x.as_tensor()` or "
-                                        "a numpy array using `x.array`."
-                                    ) from e
-                            elif isinstance(ret_meta, MetaObj):  # e.g. `batch[0]` or `batch[0, 1]`, batch_idx is int
-                                ret_meta.is_batch = False
-                            if hasattr(ret_meta, "__dict__"):
-                                ret.__dict__ = ret_meta.__dict__.copy()
-                    # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
-                    # But we only want to split the batch if the `unbind` is along the 0th
-                    # dimension.
-                    elif func == torch.Tensor.unbind:
-                        if len(args) > 1:
-                            dim = args[1]
-                        elif "dim" in kwargs:
-                            dim = kwargs["dim"]
-                        else:
-                            dim = 0
-                        if dim == 0:
-                            if metas is None:
-                                metas = decollate_batch(args[0], detach=False)
-                            ret.__dict__ = metas[idx].__dict__.copy()
-                            ret.is_batch = False
-
+                    ret = MetaTensor._handle_batched(ret, idx, metas, func, args, kwargs)
             out.append(ret)
         # if the input was a tuple, then return it as a tuple
         return tuple(out) if isinstance(rets, tuple) else out
 
+    @classmethod
+    def _handle_batched(cls, ret, idx, metas, func, args, kwargs):
+        """utility function to handle batched MetaTensors."""
+        # If we have a batch of data, then we need to be careful if a slice of
+        # the data is returned. Depending on how the data are indexed, we return
+        # some or all of the metadata, and the return object may or may not be a
+        # batch of data (e.g., `batch[:,-1]` versus `batch[0]`).
+        # if indexing e.g., `batch[0]`
+        if func == torch.Tensor.__getitem__:
+            if idx > 0 or len(args) < 2 or len(args[0]) < 1:
+                return ret
+            batch_idx = args[1][0] if isinstance(args[1], Sequence) else args[1]
+            # if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
+            # first element will be `slice(None, None, None)` and `Ellipsis`,
+            # respectively. Don't need to do anything with the metadata.
+            if batch_idx in (slice(None, None, None), Ellipsis, None) or isinstance(batch_idx, torch.Tensor):
+                return ret
+            dec_batch = decollate_batch(args[0], detach=False)
+            ret_meta = dec_batch[batch_idx]
+            if isinstance(ret_meta, list) and ret_meta:  # e.g. batch[0:2], re-collate
+                try:
+                    ret_meta = list_data_collate(ret_meta)
+                except (TypeError, ValueError, RuntimeError, IndexError) as e:
+                    raise ValueError(
+                        "Inconsistent batched metadata dicts when slicing a batch of MetaTensors, "
+                        "please consider converting it into a torch Tensor using `x.as_tensor()` or "
+                        "a numpy array using `x.array`."
+                    ) from e
+            elif isinstance(ret_meta, MetaObj):  # e.g. `batch[0]` or `batch[0, 1]`, batch_idx is int
+                ret_meta.is_batch = False
+            if hasattr(ret_meta, "__dict__"):
+                ret.__dict__ = ret_meta.__dict__.copy()
+        # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
+        # But we only want to split the batch if the `unbind` is along the 0th dimension.
+        elif func == torch.Tensor.unbind:
+            if len(args) > 1:
+                dim = args[1]
+            elif "dim" in kwargs:
+                dim = kwargs["dim"]
+            else:
+                dim = 0
+            if dim == 0:
+                if metas is None:
+                    metas = decollate_batch(args[0], detach=False)
+                if hasattr(metas[idx], "__dict__"):
+                    ret.__dict__ = metas[idx].__dict__.copy()
+                ret.is_batch = False
+        return ret
+
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None) -> Any:
         """Wraps all torch functions."""
diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py
index 739955ea67..0cd0522036 100644
--- a/tests/test_meta_tensor.py
+++ b/tests/test_meta_tensor.py
@@ -413,6 +413,10 @@ def test_slicing(self):
         x.is_batch = True
         with self.assertRaises(ValueError):
             x[slice(0, 8)]
+        x = MetaTensor(np.zeros((3, 3, 4)))
+        x.is_batch = True
+        self.assertEqual(x[torch.tensor([True, False, True])].shape, (2, 3, 4))
+        self.assertEqual(x[[True, False, True]].shape, (2, 3, 4))
 
     @parameterized.expand(DTYPES)
     @SkipIfBeforePyTorchVersion((1, 8))
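For reference, a minimal sketch of the behaviour this patch covers (not part of the diff; it assumes a MONAI build that includes the `_handle_batched` change above, and the `"id"` meta key and shapes are purely illustrative):

```python
import torch

from monai.data import MetaTensor, list_data_collate

# collate two MetaTensors into a batch; `list_data_collate` marks the
# result with `is_batch=True` and collates the per-item metadata
a = MetaTensor(torch.zeros(1, 3, 4), meta={"id": "a"})
b = MetaTensor(torch.ones(1, 3, 4), meta={"id": "b"})
batch = list_data_collate([a, b])
assert batch.is_batch

# integer indexing decollates: the element gets its own metadata back
first = batch[0]
assert not first.is_batch and first.meta["id"] == "a"

# boolean-mask indexing now returns early (the mask is a `torch.Tensor`),
# keeping the batched metadata instead of trying to decollate it
subset = batch[torch.tensor([True, False])]
assert subset.shape == (1, 1, 3, 4)

# `unbind` along dim 0 (used by `next(iter(batch))` and `decollate_batch`)
# splits the batch and restores per-item metadata
for item, expected in zip(torch.unbind(batch, dim=0), ("a", "b")):
    assert not item.is_batch and item.meta["id"] == expected
```

The `isinstance(batch_idx, torch.Tensor)` early return is what the new test exercises: a boolean mask can select an arbitrary subset of the batch, so instead of decollating and re-collating the metadata (as is done for `int` and `slice` indices), the collated metadata is passed through unchanged and the result remains a batch.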