Skip to content

Commit

Permalink
Add MedNext implementation (#8004)
Browse files Browse the repository at this point in the history
Fixes #7786 

### Description
Added MedNext architectures implementation for MONAI.

Since a lot of the code is heavily sourced from the original MedNext
repo, https://github.com/MIC-DKFZ/MedNeXt, I wanted to check if there is
an attribution policy with regarded to borrowed source code. I've added
a derivative notice bellow the monai copyright comment. Let me know if
this needs to be changed.

The blocks have been taken almost as is but the network implementation
has been changed largely to allow flexible blocks and follow MONAI
segresnet styling.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Suraj Pai <[email protected]>
Signed-off-by: Robin CREMESE <[email protected]>
Co-authored-by: Robin CREMESE <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <[email protected]>
  • Loading branch information
4 people authored Nov 13, 2024
1 parent b6663b9 commit 941e739
Show file tree
Hide file tree
Showing 5 changed files with 805 additions and 0 deletions.
1 change: 1 addition & 0 deletions monai/networks/blocks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .fcn import FCN, GCN, MCFCN, Refine
from .feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool, LastLevelP6P7
from .localnet_block import LocalNetDownSampleBlock, LocalNetFeatureExtractorBlock, LocalNetUpSampleBlock
from .mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtOutBlock, MedNeXtUpBlock
from .mlp import MLPBlock
from .patchembedding import PatchEmbed, PatchEmbeddingBlock
from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock
Expand Down
309 changes: 309 additions & 0 deletions monai/networks/blocks/mednext_block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,309 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Portions of this code are derived from the original repository at:
# https://github.com/MIC-DKFZ/MedNeXt
# and are used under the terms of the Apache License, Version 2.0.

from __future__ import annotations

import torch
import torch.nn as nn

all = ["MedNeXtBlock", "MedNeXtDownBlock", "MedNeXtUpBlock", "MedNeXtOutBlock"]


def get_conv_layer(spatial_dim: int = 3, transpose: bool = False):
if spatial_dim == 2:
return nn.ConvTranspose2d if transpose else nn.Conv2d
else: # spatial_dim == 3
return nn.ConvTranspose3d if transpose else nn.Conv3d


class MedNeXtBlock(nn.Module):
"""
MedNeXtBlock class for the MedNeXt model.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
expansion_ratio (int): Expansion ratio for the block. Defaults to 4.
kernel_size (int): Kernel size for convolutions. Defaults to 7.
use_residual_connection (int): Whether to use residual connection. Defaults to True.
norm_type (str): Type of normalization to use. Defaults to "group".
dim (str): Dimension of the input. Can be "2d" or "3d". Defaults to "3d".
global_resp_norm (bool): Whether to use global response normalization. Defaults to False.
"""

def __init__(
self,
in_channels: int,
out_channels: int,
expansion_ratio: int = 4,
kernel_size: int = 7,
use_residual_connection: int = True,
norm_type: str = "group",
dim="3d",
global_resp_norm=False,
):

super().__init__()

self.do_res = use_residual_connection

self.dim = dim
conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3)
global_resp_norm_param_shape = (1,) * (2 if dim == "2d" else 3)
# First convolution layer with DepthWise Convolutions
self.conv1 = conv(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=kernel_size,
stride=1,
padding=kernel_size // 2,
groups=in_channels,
)

# Normalization Layer. GroupNorm is used by default.
if norm_type == "group":
self.norm = nn.GroupNorm(num_groups=in_channels, num_channels=in_channels) # type: ignore
elif norm_type == "layer":
self.norm = nn.LayerNorm(
normalized_shape=[in_channels] + [kernel_size] * (2 if dim == "2d" else 3) # type: ignore
)
# Second convolution (Expansion) layer with Conv3D 1x1x1
self.conv2 = conv(
in_channels=in_channels, out_channels=expansion_ratio * in_channels, kernel_size=1, stride=1, padding=0
)

# GeLU activations
self.act = nn.GELU()

# Third convolution (Compression) layer with Conv3D 1x1x1
self.conv3 = conv(
in_channels=expansion_ratio * in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
)

self.global_resp_norm = global_resp_norm
if self.global_resp_norm:
global_resp_norm_param_shape = (1, expansion_ratio * in_channels) + global_resp_norm_param_shape
self.global_resp_beta = nn.Parameter(torch.zeros(global_resp_norm_param_shape), requires_grad=True)
self.global_resp_gamma = nn.Parameter(torch.zeros(global_resp_norm_param_shape), requires_grad=True)

def forward(self, x):
"""
Forward pass of the MedNeXtBlock.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor.
"""
x1 = x
x1 = self.conv1(x1)
x1 = self.act(self.conv2(self.norm(x1)))

if self.global_resp_norm:
# gamma, beta: learnable affine transform parameters
# X: input of shape (N,C,H,W,D)
if self.dim == "2d":
gx = torch.norm(x1, p=2, dim=(-2, -1), keepdim=True)
else:
gx = torch.norm(x1, p=2, dim=(-3, -2, -1), keepdim=True)
nx = gx / (gx.mean(dim=1, keepdim=True) + 1e-6)
x1 = self.global_resp_gamma * (x1 * nx) + self.global_resp_beta + x1
x1 = self.conv3(x1)
if self.do_res:
x1 = x + x1
return x1


class MedNeXtDownBlock(MedNeXtBlock):
"""
MedNeXtDownBlock class for downsampling in the MedNeXt model.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
expansion_ratio (int): Expansion ratio for the block. Defaults to 4.
kernel_size (int): Kernel size for convolutions. Defaults to 7.
use_residual_connection (bool): Whether to use residual connection. Defaults to False.
norm_type (str): Type of normalization to use. Defaults to "group".
dim (str): Dimension of the input. Can be "2d" or "3d". Defaults to "3d".
global_resp_norm (bool): Whether to use global response normalization. Defaults to False.
"""

def __init__(
self,
in_channels: int,
out_channels: int,
expansion_ratio: int = 4,
kernel_size: int = 7,
use_residual_connection: bool = False,
norm_type: str = "group",
dim: str = "3d",
global_resp_norm: bool = False,
):

super().__init__(
in_channels,
out_channels,
expansion_ratio,
kernel_size,
use_residual_connection=False,
norm_type=norm_type,
dim=dim,
global_resp_norm=global_resp_norm,
)

conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3)
self.resample_do_res = use_residual_connection
if use_residual_connection:
self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2)

self.conv1 = conv(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=kernel_size,
stride=2,
padding=kernel_size // 2,
groups=in_channels,
)

def forward(self, x):
"""
Forward pass of the MedNeXtDownBlock.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor.
"""
x1 = super().forward(x)

if self.resample_do_res:
res = self.res_conv(x)
x1 = x1 + res

return x1


class MedNeXtUpBlock(MedNeXtBlock):
"""
MedNeXtUpBlock class for upsampling in the MedNeXt model.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
expansion_ratio (int): Expansion ratio for the block. Defaults to 4.
kernel_size (int): Kernel size for convolutions. Defaults to 7.
use_residual_connection (bool): Whether to use residual connection. Defaults to False.
norm_type (str): Type of normalization to use. Defaults to "group".
dim (str): Dimension of the input. Can be "2d" or "3d". Defaults to "3d".
global_resp_norm (bool): Whether to use global response normalization. Defaults to False.
"""

def __init__(
self,
in_channels: int,
out_channels: int,
expansion_ratio: int = 4,
kernel_size: int = 7,
use_residual_connection: bool = False,
norm_type: str = "group",
dim: str = "3d",
global_resp_norm: bool = False,
):
super().__init__(
in_channels,
out_channels,
expansion_ratio,
kernel_size,
use_residual_connection=False,
norm_type=norm_type,
dim=dim,
global_resp_norm=global_resp_norm,
)

self.resample_do_res = use_residual_connection

self.dim = dim
conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3, transpose=True)
if use_residual_connection:
self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2)

self.conv1 = conv(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=kernel_size,
stride=2,
padding=kernel_size // 2,
groups=in_channels,
)

def forward(self, x):
"""
Forward pass of the MedNeXtUpBlock.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor.
"""
x1 = super().forward(x)
# Asymmetry but necessary to match shape

if self.dim == "2d":
x1 = torch.nn.functional.pad(x1, (1, 0, 1, 0))
else:
x1 = torch.nn.functional.pad(x1, (1, 0, 1, 0, 1, 0))

if self.resample_do_res:
res = self.res_conv(x)
if self.dim == "2d":
res = torch.nn.functional.pad(res, (1, 0, 1, 0))
else:
res = torch.nn.functional.pad(res, (1, 0, 1, 0, 1, 0))
x1 = x1 + res

return x1


class MedNeXtOutBlock(nn.Module):
"""
MedNeXtOutBlock class for the output block in the MedNeXt model.
Args:
in_channels (int): Number of input channels.
n_classes (int): Number of output classes.
dim (str): Dimension of the input. Can be "2d" or "3d".
"""

def __init__(self, in_channels, n_classes, dim):
super().__init__()

conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3, transpose=True)
self.conv_out = conv(in_channels, n_classes, kernel_size=1)

def forward(self, x):
"""
Forward pass of the MedNeXtOutBlock.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor.
"""
return self.conv_out(x)
19 changes: 19 additions & 0 deletions monai/networks/nets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,25 @@
from .generator import Generator
from .highresnet import HighResBlock, HighResNet
from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet
from .mednext import (
MedNeXt,
MedNext,
MedNextB,
MedNeXtB,
MedNextBase,
MedNextL,
MedNeXtL,
MedNeXtLarge,
MedNextLarge,
MedNextM,
MedNeXtM,
MedNeXtMedium,
MedNextMedium,
MedNextS,
MedNeXtS,
MedNeXtSmall,
MedNextSmall,
)
from .milmodel import MILModel
from .netadapter import NetAdapter
from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator
Expand Down
Loading

0 comments on commit 941e739

Please sign in to comment.