4757 update patch merging (#4758)
* update patch merging

Signed-off-by: Wenqi Li <[email protected]>

* fixes unit tests

Signed-off-by: Wenqi Li <[email protected]>
wyli authored Jul 25, 2022
1 parent 178e973 commit 356d2d2
Showing 4 changed files with 72 additions and 11 deletions.
2 changes: 1 addition & 1 deletion monai/networks/nets/__init__.py
@@ -81,7 +81,7 @@
     seresnext50,
     seresnext101,
 )
-from .swin_unetr import SwinUNETR
+from .swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR
 from .torchvision_fc import TorchVisionFCModel
 from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex
 from .unet import UNet, Unet
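With the updated re-export above, both merging layers become importable from the package level. A minimal sketch, assuming a MONAI build that includes this commit:

    # usage sketch only; assumes a MONAI version containing this patch
    from monai.networks.nets import PatchMerging, PatchMergingV2, SwinUNETR

    # PatchMerging now subclasses PatchMergingV2 (see swin_unetr.py below)
    assert issubclass(PatchMerging, PatchMergingV2)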
71 changes: 62 additions & 9 deletions monai/networks/nets/swin_unetr.py
@@ -9,7 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Sequence, Tuple, Type, Union
+from typing import Optional, Sequence, Tuple, Type, Union

 import numpy as np
 import torch
@@ -21,10 +21,23 @@
 from monai.networks.blocks import MLPBlock as Mlp
 from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock
 from monai.networks.layers import DropPath, trunc_normal_
-from monai.utils import ensure_tuple_rep, optional_import
+from monai.utils import ensure_tuple_rep, look_up_option, optional_import

 rearrange, _ = optional_import("einops", name="rearrange")

+__all__ = [
+    "SwinUNETR",
+    "window_partition",
+    "window_reverse",
+    "WindowAttention",
+    "SwinTransformerBlock",
+    "PatchMerging",
+    "PatchMergingV2",
+    "MERGING_MODE",
+    "BasicLayer",
+    "SwinTransformer",
+]
+

 class SwinUNETR(nn.Module):
     """
@@ -48,6 +61,7 @@ def __init__(
         normalize: bool = True,
         use_checkpoint: bool = False,
         spatial_dims: int = 3,
+        downsample="merging",
     ) -> None:
         """
         Args:
@@ -64,6 +78,9 @@
             normalize: normalize output intermediate features in each stage.
             use_checkpoint: use gradient checkpointing for reduced memory usage.
             spatial_dims: number of spatial dims.
+            downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
+                user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
+                The default is currently `"merging"` (the original version defined in v0.9.0).

         Examples::
@@ -121,6 +138,7 @@ def __init__(
             norm_layer=nn.LayerNorm,
             use_checkpoint=use_checkpoint,
             spatial_dims=spatial_dims,
+            downsample=look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample,
         )

         self.encoder1 = UnetrBasicBlock(
@@ -657,7 +675,7 @@ def forward(self, x, mask_matrix):
         return x


-class PatchMerging(nn.Module):
+class PatchMergingV2(nn.Module):
     """
     Patch merging layer based on: "Liu et al.,
     Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
@@ -695,8 +713,8 @@ def forward(self, x):
             x2 = x[:, 0::2, 1::2, 0::2, :]
             x3 = x[:, 0::2, 0::2, 1::2, :]
             x4 = x[:, 1::2, 0::2, 1::2, :]
-            x5 = x[:, 0::2, 1::2, 0::2, :]
-            x6 = x[:, 0::2, 0::2, 1::2, :]
+            x5 = x[:, 1::2, 1::2, 0::2, :]
+            x6 = x[:, 0::2, 1::2, 1::2, :]
             x7 = x[:, 1::2, 1::2, 1::2, :]
             x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1)

@@ -716,6 +734,36 @@ def forward(self, x):
         return x


+class PatchMerging(PatchMergingV2):
+    """The `PatchMerging` module previously defined in v0.9.0."""
+
+    def forward(self, x):
+        x_shape = x.size()
+        if len(x_shape) == 4:
+            return super().forward(x)
+        if len(x_shape) != 5:
+            raise ValueError(f"expecting 5D x, got {x.shape}.")
+        b, d, h, w, c = x_shape
+        pad_input = (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1)
+        if pad_input:
+            x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2, 0, d % 2))
+        x0 = x[:, 0::2, 0::2, 0::2, :]
+        x1 = x[:, 1::2, 0::2, 0::2, :]
+        x2 = x[:, 0::2, 1::2, 0::2, :]
+        x3 = x[:, 0::2, 0::2, 1::2, :]
+        x4 = x[:, 1::2, 0::2, 1::2, :]
+        x5 = x[:, 0::2, 1::2, 0::2, :]
+        x6 = x[:, 0::2, 0::2, 1::2, :]
+        x7 = x[:, 1::2, 1::2, 1::2, :]
+        x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1)
+        x = self.norm(x)
+        x = self.reduction(x)
+        return x
+
+
+MERGING_MODE = {"merging": PatchMerging, "mergingv2": PatchMergingV2}
+
+
 def compute_mask(dims, window_size, shift_size, device):
     """Computing region masks based on: "Liu et al.,
     Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
@@ -776,7 +824,7 @@ def __init__(
         drop: float = 0.0,
         attn_drop: float = 0.0,
         norm_layer: Type[LayerNorm] = nn.LayerNorm,
-        downsample: isinstance = None,  # type: ignore
+        downsample: Optional[nn.Module] = None,
         use_checkpoint: bool = False,
     ) -> None:
         """
@@ -791,7 +839,7 @@
             drop: dropout rate.
             attn_drop: attention dropout rate.
             norm_layer: normalization layer.
-            downsample: downsample layer at the end of the layer.
+            downsample: an optional downsampling layer at the end of the layer.
             use_checkpoint: use gradient checkpointing for reduced memory usage.
         """

@@ -820,7 +868,7 @@ def __init__(
             ]
         )
         self.downsample = downsample
-        if self.downsample is not None:
+        if callable(self.downsample):
             self.downsample = downsample(dim=dim, norm_layer=norm_layer, spatial_dims=len(self.window_size))

     def forward(self, x):
@@ -881,6 +929,7 @@ def __init__(
         patch_norm: bool = False,
         use_checkpoint: bool = False,
         spatial_dims: int = 3,
+        downsample="merging",
     ) -> None:
         """
         Args:
@@ -899,6 +948,9 @@
             patch_norm: add normalization after patch embedding.
             use_checkpoint: use gradient checkpointing for reduced memory usage.
             spatial_dims: spatial dimension.
+            downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
+                user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
+                The default is currently `"merging"` (the original version defined in v0.9.0).
         """

         super().__init__()
@@ -920,6 +972,7 @@ def __init__(
         self.layers2 = nn.ModuleList()
         self.layers3 = nn.ModuleList()
         self.layers4 = nn.ModuleList()
+        down_sample_mod = look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample
         for i_layer in range(self.num_layers):
             layer = BasicLayer(
                 dim=int(embed_dim * 2**i_layer),
@@ -932,7 +985,7 @@
                 drop=drop_rate,
                 attn_drop=attn_drop_rate,
                 norm_layer=norm_layer,
-                downsample=PatchMerging,
+                downsample=down_sample_mod,
                 use_checkpoint=use_checkpoint,
             )
             if i_layer == 0:
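The new `downsample` argument accepts either one of the strings registered in `MERGING_MODE` or a module class following the `PatchMerging` API. A rough usage sketch (constructor defaults and tensor shapes below are assumptions, not part of this diff):

    # illustrative only; requires einops and assumes a MONAI build containing this commit
    import torch
    from monai.networks.nets import SwinUNETR
    from monai.networks.nets.swin_unetr import PatchMergingV2

    # string option, resolved through look_up_option(downsample, MERGING_MODE)
    net_legacy = SwinUNETR(img_size=(96, 96, 96), in_channels=1, out_channels=2, downsample="merging")
    # the corrected merging layer, selected by name ("mergingv2") or by passing the class itself
    net_v2 = SwinUNETR(img_size=(96, 96, 96), in_channels=1, out_channels=2, downsample=PatchMergingV2)

    x = torch.randn(1, 1, 96, 96, 96)
    with torch.no_grad():
        print(net_legacy(x).shape, net_v2(x).shape)  # expected: (1, 2, 96, 96, 96) for both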
6 changes: 5 additions & 1 deletion tests/test_swin_unetr.py
@@ -16,12 +16,14 @@
 from parameterized import parameterized

 from monai.networks import eval_mode
-from monai.networks.nets.swin_unetr import PatchMerging, SwinUNETR
+from monai.networks.nets.swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR
 from monai.utils import optional_import

 einops, has_einops = optional_import("einops")

 TEST_CASE_SWIN_UNETR = []
+case_idx = 0
+test_merging_mode = ["mergingv2", "merging", PatchMerging, PatchMergingV2]
 for attn_drop_rate in [0.4]:
     for in_channels in [1]:
         for depth in [[2, 1, 1, 1], [1, 2, 1, 1]]:
@@ -39,10 +41,12 @@
                         "depths": depth,
                         "norm_name": norm_name,
                         "attn_drop_rate": attn_drop_rate,
+                        "downsample": test_merging_mode[case_idx % 4],
                     },
                     (2, in_channels, *img_size),
                     (2, out_channels, *img_size),
                 ]
+                case_idx += 1
                 TEST_CASE_SWIN_UNETR.append(test_case)


4 changes: 4 additions & 0 deletions tests/utils.py
@@ -17,6 +17,7 @@
 import operator
 import os
 import queue
+import ssl
 import sys
 import tempfile
 import time
@@ -123,6 +124,9 @@ def skip_if_downloading_fails():
         yield
     except (ContentTooShortError, HTTPError, ConnectionError) as e:
         raise unittest.SkipTest(f"error while downloading: {e}") from e
+    except ssl.SSLError as ssl_e:
+        if "decryption failed" in str(ssl_e):
+            raise unittest.SkipTest(f"SSL error while downloading: {ssl_e}") from ssl_e
     except RuntimeError as rt_e:
         if "unexpected EOF" in str(rt_e):
             raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e  # incomplete download
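For context, `skip_if_downloading_fails` is a context manager, so the new `ssl.SSLError` branch turns transient "decryption failed" failures into skipped tests rather than errors. A hedged usage sketch; the download helper and URL are illustrative placeholders, not taken from this diff:

    # sketch only; assumes monai.apps.download_url and the tests.utils helper shown above
    from monai.apps import download_url
    from tests.utils import skip_if_downloading_fails

    with skip_if_downloading_fails():
        # HTTP errors, "unexpected EOF", and the new SSL "decryption failed" case
        # all raise unittest.SkipTest here instead of failing the test
        download_url("https://example.com/weights.pt", filepath="weights.pt")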
