diff --git a/SwinMM/INSTALL.md b/SwinMM/INSTALL.md
new file mode 100644
index 00000000..16ceb026
--- /dev/null
+++ b/SwinMM/INSTALL.md
@@ -0,0 +1,101 @@
+# Installation
+
+We provide installation instructions here.
+
+## Setup
+
+### Using Docker
+
+The simplest way to use SwinMM is through our Docker image [`swinmm`](https://drive.google.com/file/d/1EGSoqN-HphyMV_gKUq-g7_BSwTTg35oA/view?usp=sharing), which contains all the required dependencies. Download `swinmm.tar` into the `SwinMM` directory and run the following commands:
+
+```bash
+cd SwinMM
+docker import - swinmm < swinmm.tar
+docker run --runtime=nvidia --gpus=all -m="800g" --shm-size="32g" -itd --name=swinmm -v ./:/volume swinmm /bin/bash
+docker exec -it swinmm /bin/bash
+conda activate SwinMM
+```
+
+To use docker, make sure you have installed `docker` and `nvidia-docker`.
+
+### Manual
+
+For fast dataset loading, we require the Redis database. On Ubuntu, for example: `sudo apt-get install redis`
+
+We also recommend installing PyTorch from the official website, choosing the build that matches your CUDA version.
+
+Two packages should be installed manually because of their more involved dependencies: [bagua==0.9.2](https://github.com/BaguaSys/bagua) and [monai==0.9.0](https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies).
+
+The remaining dependencies can be installed with `pip install -r requirements.txt`.
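+
+After installation, a short Python check (illustrative only; it simply prints versions) can confirm the environment is picked up correctly:
+
+```python
+# Prints the versions of the key dependencies and whether CUDA is visible.
+import torch
+import monai
+
+print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
+print("monai:", monai.__version__, "(this guide pins monai==0.9.0)")
+```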
+
+## Datasets
+
+Our pre-training dataset includes 5833 volumes from 8 public datasets:
+
+- [AbdomenCT-1K](https://github.com/JunMa11/AbdomenCT-1K)
+- [BTCV](https://www.synapse.org/#!Synapse:syn3193805/wiki/217789)
+- [MSD](http://medicaldecathlon.com/)
+- [TCIACovid19](https://wiki.cancerimagingarchive.net/display/Public/CT+Images+in+COVID-19/)
+- [WORD](https://github.com/HiLab-git/WORD)
+- [TCIA-Colon](https://wiki.cancerimagingarchive.net/display/Public/CT+COLONOGRAPHY/)
+- [LiDC](https://wiki.cancerimagingarchive.net/display/Public/LIDC-IDRI/)
+- [HNSCC](https://wiki.cancerimagingarchive.net/display/Public/HNSCC)
+
+We evaluate downstream segmentation performance on two popular datasets:
+
+- [WORD](https://github.com/HiLab-git/WORD) (The Whole abdominal Organ Dataset)
+- [ACDC](https://www.creatis.insa-lyon.fr/Challenge/acdc/#challenge/584e75606a3c77492fe91bba) (Automated Cardiac Diagnosis Challenge)
+
+The JSON files can be downloaded from [pretrain_jsons](https://drive.google.com/file/d/1gJThxBvnJnc2_N1nFX7xywjFWFw7DSEY/view?usp=sharing) and [word_jsons](https://drive.google.com/file/d/1Td4T_k2QlEcTETz9TERGsVdOyebD5ULv/view?usp=sharing).
+
+The datasets are organized as follows:
+
+```text
+SwinMM
+├── WORD
+│   └── dataset
+│       └── dataset12_WORD
+│           ├── imagesTr
+│           ├── imagesTs
+│           ├── imagesVal
+│           ├── labelsTr
+│           ├── labelsTs
+│           ├── labelsVal
+│           └── dataset12_WORD.json
+└── Pretrain
+    ├── dataset
+    │   ├── dataset00_BTCV
+    │   ├── dataset02_Heart
+    │   ├── dataset03_Liver
+    │   ├── dataset04_Hippocampus
+    │   ├── dataset06_Lung
+    │   ├── dataset07_Pancreas
+    │   ├── dataset08_HepaticVessel
+    │   ├── dataset09_Spleen
+    │   ├── dataset10_Colon
+    │   ├── dataset11_TCIAcovid19
+    │   ├── dataset12_WORD
+    │   ├── dataset13_AbdomenCT-1K
+    │   ├── dataset_HNSCC
+    │   ├── dataset_TCIAcolon
+    │   └── dataset_LIDC
+    └── jsons
+        ├── dataset00_BTCV.json
+        ├── dataset01_BrainTumour.json
+        ├── dataset02_Heart.json
+        ├── dataset03_Liver.json
+        ├── dataset04_Hippocampus.json
+        ├── dataset05_Prostate.json
+        ├── dataset06_Lung.json
+        ├── dataset07_Pancreas.json
+        ├── dataset08_HepaticVessel.json
+        ├── dataset09_Spleen.json
+        ├── dataset10_Colon.json
+        ├── dataset11_TCIAcovid19.json
+        ├── dataset12_WORD.json
+        ├── dataset13_AbdomenCT-1K.json
+        ├── dataset_HNSCC.json
+        ├── dataset_TCIAcolon.json
+        └── dataset_LIDC.json
+```
diff --git a/SwinMM/Pretrain/jsons/Download Pretrain Jsons Here b/SwinMM/Pretrain/jsons/Download Pretrain Jsons Here
new file mode 100644
index 00000000..e69de29b
diff --git a/SwinMM/Pretrain/jsons/__init__.py b/SwinMM/Pretrain/jsons/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/SwinMM/Pretrain/losses/loss.py b/SwinMM/Pretrain/losses/loss.py
new file mode 100644
index 00000000..8e0b5a88
--- /dev/null
+++ b/SwinMM/Pretrain/losses/loss.py
@@ -0,0 +1,95 @@
+# Copyright 2020 - 2022 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torch.nn import functional as F
+
+
+class ContrastLoss(torch.nn.Module):
+ def __init__(self, args, batch_size, temperature=0.5):
+ super().__init__()
+ device = torch.device(f"cuda:{args.local_rank}")
+ self.batch_size = batch_size
+        self.register_buffer("temp", torch.tensor(temperature).to(device))
+ self.register_buffer("neg_mask", (~torch.eye(batch_size * 2, batch_size * 2, dtype=bool).to(device)).float())
+
+ def forward(self, x_i, x_j):
+ z_i = F.normalize(x_i, dim=1)
+ z_j = F.normalize(x_j, dim=1)
+ z = torch.cat([z_i, z_j], dim=0)
+ sim = F.cosine_similarity(z.unsqueeze(1), z.unsqueeze(0), dim=2)
+ sim_ij = torch.diag(sim, self.batch_size)
+ sim_ji = torch.diag(sim, -self.batch_size)
+ pos = torch.cat([sim_ij, sim_ji], dim=0)
+ nom = torch.exp(pos / self.temp)
+ denom = self.neg_mask * torch.exp(sim / self.temp)
+ return torch.sum(-torch.log(nom / torch.sum(denom, dim=1))) / (2 * self.batch_size)
+
+
+class MutualLoss(torch.nn.Module):
+ def __init__(self, args):
+ super().__init__()
+ self.alpha = 1.0
+ self.mask_ratio = args.mask_ratio
+ self.recon_loss_2 = torch.nn.MSELoss().cuda()
+
+ def __call__(self, rec1, rec2, mask):
+ mask = mask.to(dtype=rec1.dtype)
+ rec1, rec2 = [val * mask for val in [rec1, rec2]]
+
+ recon_loss = self.recon_loss_2(rec1, rec2) / self.mask_ratio
+ return self.alpha * recon_loss
+
+
+class Loss(torch.nn.Module):
+ def __init__(self, batch_size, args):
+ super().__init__()
+ self.rot_loss = torch.nn.CrossEntropyLoss().cuda()
+ self.recon_loss = torch.nn.L1Loss().cuda()
+ self.recon_loss_2 = torch.nn.MSELoss().cuda()
+ self.contrast_loss = ContrastLoss(args, batch_size).cuda()
+ self.alpha1 = 1.0
+ self.alpha2 = 1.0
+ self.alpha3 = 1.0
+ self.norm_pix_loss = args.norm_pix_loss
+ self.mask_ratio = args.mask_ratio
+
+ def __call__(
+ self,
+ output_rot,
+ target_rot,
+ output_contrastive,
+ target_contrastive,
+ output_recons,
+ target_recons,
+ mask,
+ only_mae=False,
+ ):
+ B, C, H, W, D = output_recons.shape
+ target_recons = target_recons.reshape(B, C, -1)
+
+ if self.norm_pix_loss:
+ mean = target_recons.mean(dim=-1, keepdim=True)
+ var = target_recons.var(dim=-1, keepdim=True)
+ target_recons = (target_recons - mean) / (var + 1.0e-6) ** 0.5
+ target_recons = target_recons.reshape(B, C, H, W, D)
+ # masked voxels.
+ mask = mask.to(dtype=target_recons.dtype)[None, ...]
+ target_recons, output_recons = [val * mask for val in [target_recons, output_recons]]
+ recon_loss = self.recon_loss_2(output_recons, target_recons) / self.mask_ratio
+ recon_loss = self.alpha3 * recon_loss
+ if only_mae:
+ return recon_loss
+ contrast_loss = self.alpha2 * self.contrast_loss(output_contrastive, target_contrastive)
+ rot_loss = self.alpha1 * self.rot_loss(output_rot, target_rot)
+ total_loss = rot_loss + contrast_loss + recon_loss
+
+ return total_loss, (rot_loss, contrast_loss, recon_loss)
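+
+
+if __name__ == "__main__":
+    # Minimal smoke test (illustrative only, not part of the training pipeline).
+    # Shapes mirror Pretrain/main.py: two augmented views of a [B, 1, H, W, D] crop batch,
+    # reconstructions concatenated along the channel dimension, and a per-view voxel mask.
+    import argparse
+
+    if torch.cuda.is_available():
+        args = argparse.Namespace(local_rank=0, norm_pix_loss=False, mask_ratio=0.5)
+        B, H = 2, 32
+        loss_fn = Loss(B, args)
+        mutual_fn = MutualLoss(args)
+        rot_logits = torch.randn(2 * B, 4).cuda()
+        rot_targets = torch.randint(0, 4, (2 * B,)).cuda()
+        feat1, feat2 = torch.randn(B, 512).cuda(), torch.randn(B, 512).cuda()
+        recon = torch.randn(B, 2, H, H, H).cuda()
+        target = torch.randn(B, 2, H, H, H).cuda()
+        mask = torch.rand(2, H, H, H).cuda() > 0.5
+        total, (rot_l, con_l, rec_l) = loss_fn(rot_logits, rot_targets, feat1, feat2, recon, target, mask)
+        print(total.item(), rot_l.item(), con_l.item(), rec_l.item())
+        print(mutual_fn(recon, target, mask).item())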
diff --git a/SwinMM/Pretrain/main.py b/SwinMM/Pretrain/main.py
new file mode 100644
index 00000000..419b8e1c
--- /dev/null
+++ b/SwinMM/Pretrain/main.py
@@ -0,0 +1,317 @@
+# Copyright 2020 - 2022 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import random
+from time import time
+
+import timm.optim.optim_factory as optim_factory
+import torch
+import torch.distributed as dist
+import torch.optim as optim
+from losses.loss import Loss, MutualLoss
+from models.ssl_head import SSLHead
+from optimizers.lr_scheduler import WarmupCosineSchedule
+from torch.cuda.amp import GradScaler, autocast
+from torch.nn.parallel import DistributedDataParallel
+from utils import view_ops, view_transforms
+from utils.data_utils import get_loader
+from utils.dataset_in_memory import hijack_bagua_serialization
+from utils.ops import mask_rand_patch
+
+# torch
+torch.multiprocessing.set_sharing_strategy("file_system")
+
+
+def main():
+ def save_ckpt(state, checkpoint_dir):
+ torch.save(state, checkpoint_dir)
+
+ def train(args, global_step, train_loader, val_best, scaler):
+ model.train()
+
+ for _, batch in enumerate(train_loader):
+ t1 = time()
+ x = batch["image"].cuda()
+
+ x1, rot1 = view_ops.rot_rand(x)
+ x2, rot2 = view_ops.rot_rand(x)
+
+ window_sizes = tuple(args.window_size for _ in range(3))
+ input_sizes = (args.roi_x, args.roi_y, args.roi_z)
+ x1_masked, mask1 = mask_rand_patch(window_sizes, input_sizes, args.mask_ratio, x1)
+ x2_masked, mask2 = mask_rand_patch(window_sizes, input_sizes, args.mask_ratio, x2)
+
+ # NOTE(meijieru): x1, x2 may have different rot transform, so we
+ # allow same permute transform here.
+ permutations_candidates = set(view_transforms.permutation_transforms.keys()) - {0}
+ permutations = [random.choice(list(permutations_candidates)) for _ in range(2)]
+ x1_masked_permuted, x2_masked_permuted = [
+ view_transforms.permutation_transforms[vn](val) for vn, val in zip(permutations, [x1_masked, x2_masked])
+ ]
+
+ with autocast(enabled=args.amp):
+ rot1_p, contrastive1_p, rec_x1 = model(x1_masked)
+ rot2_p, contrastive2_p, rec_x2 = model(x2_masked)
+ _, contrastive3_p, rec_x3 = model(x1_masked_permuted)
+ _, contrastive4_p, rec_x4 = model(x2_masked_permuted)
+
+ # masked voxels: [2, H, W, D]
+ mask = torch.stack([mask1, mask2], dim=0)
+ rec_x3, rec_x4 = [
+ view_transforms.permutation_inverse_transforms[vn](val)
+ for vn, val in zip(permutations, [rec_x3, rec_x4])
+ ]
+
+ rot_p = torch.cat([rot1_p, rot2_p], dim=0)
+ rots = torch.cat([rot1, rot2], dim=0)
+ # [B, 2, H, W, D]
+ imgs_recon = torch.cat([rec_x1, rec_x2], dim=1)
+ imgs = torch.cat([x1, x2], dim=1)
+ loss1, losses_tasks1 = loss_function(
+ rot_p, rots, contrastive1_p, contrastive2_p, imgs_recon, imgs, mask
+ )
+
+ mutual_loss1 = mutual_loss_function(rec_x3, rec_x1, mask1)
+
+ imgs_recon = torch.cat([rec_x3, rec_x4], dim=1)
+ loss2 = loss_function(
+ rot_p, rots, contrastive3_p, contrastive4_p, imgs_recon, imgs, mask, only_mae=True
+ )
+
+ loss = loss1 + loss2 + mutual_loss1
+
+ mutual_loss2 = None
+ if args.mutual_learning_on_more_view:
+
+ def _align_rot(x, src_rot, dst_rot):
+ return view_transforms.rotation_transforms[dst_rot](
+ view_transforms.rotation_inverse_transforms[src_rot](x)
+ ).contiguous()
+
+ # [B, C, H, W, D]
+ rec_x4_aligned = torch.stack(
+ [
+ _align_rot(val, src_rot.item(), dst_rot.item())
+ for val, src_rot, dst_rot in zip(rec_x4, rot2, rot1)
+ ]
+ )
+ # [B, 1, H, W, D]
+ mask2_aligned = torch.concat(
+ [
+ _align_rot(mask2[None, None], src_rot.item(), dst_rot.item())
+ for src_rot, dst_rot in zip(rot2, rot1)
+ ]
+ )
+ mask_intersection = torch.logical_and(mask2_aligned, mask1)
+ # Rescale to the same scale of mutual_loss1
+ rescaler = mask1.sum() * mask2_aligned.size(0) / (mask2_aligned.sum() + 1e-6)
+ mutual_loss2 = mutual_loss_function(rec_x4_aligned, rec_x1, mask_intersection) * rescaler
+
+ loss = loss + mutual_loss2
+
+ if args.amp:
+ scaler.scale(loss).backward()
+ scaler.step(optimizer)
+ scaler.update()
+ else:
+ loss.backward()
+ if args.grad_clip:
+ torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+ optimizer.step()
+
+ if args.lrdecay:
+ scheduler.step()
+ optimizer.zero_grad()
+ if args.distributed:
+ if dist.get_rank() == 0:
+ rot_loss = losses_tasks1[0].item()
+ con_loss = losses_tasks1[1].item()
+ rec_loss = losses_tasks1[2].item() + loss2.item()
+ print(
+ "Step:{}/{}, Loss:{:.4f}, Rot:{:.4f}, Con:{:.4f}, Rec:{:.4f}, Time:{:.4f}".format(
+ global_step, args.num_steps, loss, rot_loss, con_loss, rec_loss, time() - t1
+ )
+ )
+ else:
+ print("Step:{}/{}, Loss:{:.4f}, Time:{:.4f}".format(global_step, args.num_steps, loss, time() - t1))
+
+ global_step += 1
+ if args.distributed:
+ val_cond = (dist.get_rank() == 0) and (global_step % args.eval_num == 0)
+ else:
+ val_cond = global_step % args.eval_num == 0
+
+ if val_cond and global_step % 1000 == 0:
+ checkpoint = {
+ "global_step": global_step,
+ "state_dict": model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ }
+ save_ckpt(checkpoint, logdir + "/model_{}.pt".format(global_step))
+ return global_step, loss, val_best
+
+ parser = argparse.ArgumentParser(description="PyTorch Training")
+ parser.add_argument("--logdir", default="test", type=str, help="directory to save the tensorboard logs")
+ parser.add_argument("--epochs", default=100, type=int, help="number of training epochs")
+ parser.add_argument("--num_steps", default=100000, type=int, help="number of training iterations")
+ parser.add_argument("--eval_num", default=100, type=int, help="evaluation frequency")
+ parser.add_argument("--warmup_steps", default=500, type=int, help="warmup steps")
+ parser.add_argument("--in_channels", default=1, type=int, help="number of input channels")
+ parser.add_argument("--feature_size", default=48, type=int, help="embedding size")
+ parser.add_argument("--dropout_path_rate", default=0.0, type=float, help="drop path rate")
+ parser.add_argument("--use_checkpoint", action="store_true", help="use gradient checkpointing to save memory")
+ parser.add_argument("--spatial_dims", default=3, type=int, help="spatial dimension of input data")
+ parser.add_argument("--a_min", default=-1000, type=float, help="a_min in ScaleIntensityRanged")
+ parser.add_argument("--a_max", default=1000, type=float, help="a_max in ScaleIntensityRanged")
+ parser.add_argument("--b_min", default=0.0, type=float, help="b_min in ScaleIntensityRanged")
+ parser.add_argument("--b_max", default=1.0, type=float, help="b_max in ScaleIntensityRanged")
+ parser.add_argument("--space_x", default=1.5, type=float, help="spacing in x direction")
+ parser.add_argument("--space_y", default=1.5, type=float, help="spacing in y direction")
+ parser.add_argument("--space_z", default=2.0, type=float, help="spacing in z direction")
+ parser.add_argument("--roi_x", default=64, type=int, help="roi size in x direction")
+ parser.add_argument("--roi_y", default=64, type=int, help="roi size in y direction")
+ parser.add_argument("--roi_z", default=64, type=int, help="roi size in z direction")
+ parser.add_argument("--mask_ratio", default=0.5, type=float, help="mask ratio for MAE pretraining")
+ parser.add_argument("--window_size", default=16, type=int, help="window size for MAE pretraining")
+ parser.add_argument("--batch_size", default=1, type=int, help="number of batch size")
+ parser.add_argument("--sw_batch_size", default=2, type=int, help="number of sliding window batch size")
+ parser.add_argument("--lr", default=4e-4, type=float, help="learning rate")
+ parser.add_argument("--decay", default=0.1, type=float, help="decay rate")
+ parser.add_argument("--momentum", default=0.9, type=float, help="momentum")
+ parser.add_argument("--lrdecay", action="store_true", help="enable learning rate decay")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="maximum gradient norm")
+ parser.add_argument("--loss_type", default="SSL", type=str)
+ parser.add_argument("--opt", default="adamw", type=str, help="optimization algorithm")
+ parser.add_argument("--lr_schedule", default="warmup_cosine", type=str)
+ parser.add_argument("--resume", default=None, type=str, help="resume training")
+ parser.add_argument("--local_rank", type=int, default=0, help="local rank")
+ parser.add_argument("--grad_clip", action="store_true", help="gradient clip")
+ parser.add_argument("--noamp", action="store_true", help="do NOT use amp for training")
+ parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training")
+ parser.add_argument("--norm_pix_loss", action="store_true", help="normalize before compute reconstruction loss")
+ parser.add_argument("--redis_ports", nargs="+", type=int, help="redis ports")
+ parser.add_argument("--redis_compression", type=str, default="lz4", help="compression method for redis.")
+ parser.add_argument("--use_normal_dataset", action="store_true", help="use monai Dataset class")
+ parser.add_argument(
+ "--nouse_multi_epochs_loader",
+ action="store_true",
+ help="not use the multi-epochs-loader to save time at the beginning of every epoch",
+ )
+ parser.add_argument(
+ "--mutual_learning_on_more_view", action="store_true", help="also use rotate for mutual learning"
+ )
+ parser.add_argument("--workers", default=16, type=int, help="number of workers")
+
+ args = parser.parse_args()
+ logdir = "./runs/" + args.logdir
+ args.amp = not args.noamp
+ args.lr = args.lr * args.batch_size / 2
+ torch.backends.cudnn.benchmark = True
+ args.distributed = False
+ if "WORLD_SIZE" in os.environ:
+ args.distributed = int(os.environ["WORLD_SIZE"]) > 1
+ args.device = "cuda:0"
+ args.world_size = 1
+ args.rank = 0
+
+ if args.distributed:
+ args.device = "cuda:%d" % args.local_rank
+ torch.cuda.set_device(args.local_rank)
+ dist.init_process_group(backend="nccl", init_method=args.dist_url)
+ args.world_size = dist.get_world_size()
+ args.rank = dist.get_rank()
+ print(
+ "Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d."
+ % (args.rank, args.world_size)
+ )
+ else:
+        print("Training with a single process on 1 GPU.")
+ assert args.rank >= 0
+
+ if args.redis_compression is not None:
+ hijack_bagua_serialization(args.redis_compression)
+
+ if args.rank == 0:
+ os.makedirs(logdir, exist_ok=True)
+
+ model = SSLHead(args)
+ model.cuda()
+ model_without_ddp = model
+
+ param_groups = optim_factory.param_groups_weight_decay(
+ model_without_ddp, weight_decay=args.decay, no_weight_decay_list=model_without_ddp.no_weight_decay()
+ )
+ if args.opt == "adam":
+ optimizer = optim.Adam(param_groups, lr=args.lr)
+
+ elif args.opt == "adamw":
+ optimizer = optim.AdamW(param_groups, lr=args.lr)
+
+ elif args.opt == "sgd":
+ optimizer = optim.SGD(param_groups, lr=args.lr, momentum=args.momentum)
+ else:
+        raise ValueError(f"Unknown optimizer: {args.opt}")
+
+ global_step = 0
+ if args.resume:
+ model_pth = args.resume
+ model_dict = torch.load(model_pth)
+        new_state = {}
+
+        # Strip the "module." prefix added by DistributedDataParallel when the
+        # checkpoint was saved from a wrapped model.
+        for k, v in model_dict["state_dict"].items():
+            new_name = k[7:]
+            new_state[new_name] = v
+
+        model.load_state_dict(new_state)
+        global_step = model_dict["global_step"]
+        optimizer.load_state_dict(model_dict["optimizer"])
+
+ if args.lrdecay:
+ if args.lr_schedule == "warmup_cosine":
+ scheduler = WarmupCosineSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=args.num_steps)
+
+ elif args.lr_schedule == "poly":
+
+ def lambdas(epoch):
+ return (1 - float(epoch) / float(args.epochs)) ** 0.9
+
+ scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambdas)
+
+ mutual_loss_function = MutualLoss(args)
+ loss_function = Loss(args.batch_size * args.sw_batch_size, args)
+ if args.distributed:
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ model = DistributedDataParallel(model, device_ids=[args.local_rank], broadcast_buffers=False)
+ model_without_ddp = model.module
+ train_loader, _ = get_loader(args)
+
+ best_val = 1e8
+ if args.amp:
+ scaler = GradScaler()
+ else:
+ scaler = None
+ while global_step < args.num_steps:
+ global_step, loss, best_val = train(args, global_step, train_loader, best_val, scaler)
+ checkpoint = {"epoch": args.epochs, "state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}
+
+ if args.distributed:
+ if dist.get_rank() == 0:
+            torch.save(model.state_dict(), logdir + "/final_model.pth")
+ dist.destroy_process_group()
+ else:
+        torch.save(model.state_dict(), logdir + "/final_model.pth")
+ save_ckpt(checkpoint, logdir + "/model_final_epoch.pt")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/SwinMM/Pretrain/models/ssl_head.py b/SwinMM/Pretrain/models/ssl_head.py
new file mode 100644
index 00000000..fc259418
--- /dev/null
+++ b/SwinMM/Pretrain/models/ssl_head.py
@@ -0,0 +1,99 @@
+# Copyright 2020 - 2022 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+
+from monai.networks.nets.swin_unetr import SwinTransformer as SwinViT
+from monai.utils import ensure_tuple_rep
+
+
+class SSLHead(nn.Module):
+ def __init__(self, args, upsample="vae", dim=768):
+ super(SSLHead, self).__init__()
+ patch_size = ensure_tuple_rep(2, args.spatial_dims)
+ window_size = ensure_tuple_rep(7, args.spatial_dims)
+ self.swinViT = SwinViT(
+ in_chans=args.in_channels,
+ embed_dim=args.feature_size,
+ window_size=window_size,
+ patch_size=patch_size,
+ depths=[2, 2, 2, 2],
+ num_heads=[3, 6, 12, 24],
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ drop_path_rate=args.dropout_path_rate,
+ norm_layer=torch.nn.LayerNorm,
+ use_checkpoint=args.use_checkpoint,
+ spatial_dims=args.spatial_dims,
+ )
+ self.rotation_pre = nn.Identity()
+ self.rotation_head = nn.Linear(dim, 4)
+ self.contrastive_pre = nn.Identity()
+ self.contrastive_head = nn.Linear(dim, 512)
+ if upsample == "large_kernel_deconv":
+ self.conv = nn.ConvTranspose3d(dim, args.in_channels, kernel_size=(32, 32, 32), stride=(32, 32, 32))
+ elif upsample == "deconv":
+ self.conv = nn.Sequential(
+ nn.ConvTranspose3d(dim, dim // 2, kernel_size=(2, 2, 2), stride=(2, 2, 2)),
+ nn.ConvTranspose3d(dim // 2, dim // 4, kernel_size=(2, 2, 2), stride=(2, 2, 2)),
+ nn.ConvTranspose3d(dim // 4, dim // 8, kernel_size=(2, 2, 2), stride=(2, 2, 2)),
+ nn.ConvTranspose3d(dim // 8, dim // 16, kernel_size=(2, 2, 2), stride=(2, 2, 2)),
+ nn.ConvTranspose3d(dim // 16, args.in_channels, kernel_size=(2, 2, 2), stride=(2, 2, 2)),
+ )
+ elif upsample == "vae":
+ self.conv = nn.Sequential(
+ nn.Conv3d(dim, dim // 2, kernel_size=3, stride=1, padding=1),
+ nn.InstanceNorm3d(dim // 2),
+ nn.LeakyReLU(),
+ nn.Upsample(scale_factor=2, mode="trilinear", align_corners=False),
+ nn.Conv3d(dim // 2, dim // 4, kernel_size=3, stride=1, padding=1),
+ nn.InstanceNorm3d(dim // 4),
+ nn.LeakyReLU(),
+ nn.Upsample(scale_factor=2, mode="trilinear", align_corners=False),
+ nn.Conv3d(dim // 4, dim // 8, kernel_size=3, stride=1, padding=1),
+ nn.InstanceNorm3d(dim // 8),
+ nn.LeakyReLU(),
+ nn.Upsample(scale_factor=2, mode="trilinear", align_corners=False),
+ nn.Conv3d(dim // 8, dim // 16, kernel_size=3, stride=1, padding=1),
+ nn.InstanceNorm3d(dim // 16),
+ nn.LeakyReLU(),
+ nn.Upsample(scale_factor=2, mode="trilinear", align_corners=False),
+ nn.Conv3d(dim // 16, dim // 16, kernel_size=3, stride=1, padding=1),
+ nn.InstanceNorm3d(dim // 16),
+ nn.LeakyReLU(),
+ nn.Upsample(scale_factor=2, mode="trilinear", align_corners=False),
+ nn.Conv3d(dim // 16, args.in_channels, kernel_size=1, stride=1),
+ )
+
+ def forward(self, x):
+ x_out = self.swinViT(x.contiguous())[4]
+ _, c, h, w, d = x_out.shape
+ x4_reshape = x_out.flatten(start_dim=2, end_dim=4)
+ x4_reshape = x4_reshape.transpose(1, 2)
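+        # The first two token embeddings of the flattened bottleneck feature map feed the
+        # rotation-prediction and contrastive heads; the full feature map is decoded back
+        # to a volume by `self.conv` for the reconstruction task.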
+ x_rot = self.rotation_pre(x4_reshape[:, 0])
+ x_rot = self.rotation_head(x_rot)
+ x_contrastive = self.contrastive_pre(x4_reshape[:, 1])
+ x_contrastive = self.contrastive_head(x_contrastive)
+ x_rec = x_out.flatten(start_dim=2, end_dim=4)
+ x_rec = x_rec.view(-1, c, h, w, d)
+ x_rec = self.conv(x_rec)
+ return x_rot, x_contrastive, x_rec
+
+ def no_weight_decay(self):
+ """Disable weight_decay on specific weights."""
+ nwd = {"swinViT.absolute_pos_embed"}
+ for n, _ in self.named_parameters():
+ if "relative_position_bias_table" in n:
+ nwd.add(n)
+ return nwd
diff --git a/SwinMM/Pretrain/optimizers/__init__.py b/SwinMM/Pretrain/optimizers/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/SwinMM/Pretrain/optimizers/lr_scheduler.py b/SwinMM/Pretrain/optimizers/lr_scheduler.py
new file mode 100644
index 00000000..0c352927
--- /dev/null
+++ b/SwinMM/Pretrain/optimizers/lr_scheduler.py
@@ -0,0 +1,172 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import warnings
+from typing import List
+
+from torch import nn as nn
+from torch.optim import Adam, Optimizer
+from torch.optim.lr_scheduler import LambdaLR, _LRScheduler
+
+__all__ = ["LinearLR", "ExponentialLR", "WarmupCosineSchedule", "LinearWarmupCosineAnnealingLR"]
+
+
+class _LRSchedulerMONAI(_LRScheduler):
+ """Base class for increasing the learning rate between two boundaries over a number
+ of iterations"""
+
+ def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1) -> None:
+ """
+ Args:
+ optimizer: wrapped optimizer.
+ end_lr: the final learning rate.
+ num_iter: the number of iterations over which the test occurs.
+ last_epoch: the index of last epoch.
+ Returns:
+ None
+ """
+ self.end_lr = end_lr
+ self.num_iter = num_iter
+ super(_LRSchedulerMONAI, self).__init__(optimizer, last_epoch)
+
+
+class LinearLR(_LRSchedulerMONAI):
+ """Linearly increases the learning rate between two boundaries over a number of
+ iterations.
+ """
+
+ def get_lr(self):
+ r = self.last_epoch / (self.num_iter - 1)
+ return [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs]
+
+
+class ExponentialLR(_LRSchedulerMONAI):
+ """Exponentially increases the learning rate between two boundaries over a number of
+ iterations.
+ """
+
+ def get_lr(self):
+ r = self.last_epoch / (self.num_iter - 1)
+ return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs]
+
+
+class WarmupCosineSchedule(LambdaLR):
+ """Linear warmup and then cosine decay.
+ Based on https://huggingface.co/ implementation.
+ """
+
+ def __init__(
+ self, optimizer: Optimizer, warmup_steps: int, t_total: int, cycles: float = 0.5, last_epoch: int = -1
+ ) -> None:
+ """
+ Args:
+ optimizer: wrapped optimizer.
+ warmup_steps: number of warmup iterations.
+ t_total: total number of training iterations.
+ cycles: cosine cycles parameter.
+ last_epoch: the index of last epoch.
+ Returns:
+ None
+ """
+ self.warmup_steps = warmup_steps
+ self.t_total = t_total
+ self.cycles = cycles
+ super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch)
+
+ def lr_lambda(self, step):
+ if step < self.warmup_steps:
+ return float(step) / float(max(1.0, self.warmup_steps))
+ progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))
+
+
+class LinearWarmupCosineAnnealingLR(_LRScheduler):
+ def __init__(
+ self,
+ optimizer: Optimizer,
+ warmup_epochs: int,
+ max_epochs: int,
+ warmup_start_lr: float = 0.0,
+ eta_min: float = 0.0,
+ last_epoch: int = -1,
+ ) -> None:
+ """
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ warmup_epochs (int): Maximum number of iterations for linear warmup
+ max_epochs (int): Maximum number of iterations
+ warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0.
+ eta_min (float): Minimum learning rate. Default: 0.
+ last_epoch (int): The index of last epoch. Default: -1.
+ """
+ self.warmup_epochs = warmup_epochs
+ self.max_epochs = max_epochs
+ self.warmup_start_lr = warmup_start_lr
+ self.eta_min = eta_min
+
+ super(LinearWarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch)
+
+ def get_lr(self) -> List[float]:
+ """
+ Compute learning rate using chainable form of the scheduler
+ """
+ if not self._get_lr_called_within_step:
+ warnings.warn(
+ "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", UserWarning
+ )
+
+ if self.last_epoch == 0:
+ return [self.warmup_start_lr] * len(self.base_lrs)
+ elif self.last_epoch < self.warmup_epochs:
+ return [
+ group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
+ for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
+ ]
+ elif self.last_epoch == self.warmup_epochs:
+ return self.base_lrs
+ elif (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0:
+ return [
+ group["lr"]
+ + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2
+ for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
+ ]
+
+ return [
+ (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
+ / (
+ 1
+ + math.cos(
+ math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs)
+ )
+ )
+ * (group["lr"] - self.eta_min)
+ + self.eta_min
+ for group in self.optimizer.param_groups
+ ]
+
+ def _get_closed_form_lr(self) -> List[float]:
+ """
+ Called when epoch is passed as a param to the `step` function of the scheduler.
+ """
+ if self.last_epoch < self.warmup_epochs:
+ return [
+ self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
+ for base_lr in self.base_lrs
+ ]
+
+ return [
+ self.eta_min
+ + 0.5
+ * (base_lr - self.eta_min)
+ * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
+ for base_lr in self.base_lrs
+ ]
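+
+
+if __name__ == "__main__":
+    # Minimal CPU sanity check (illustrative only, not part of the original code):
+    # warm up the learning rate for 5 steps, then cosine-decay it over 20 total steps.
+    import torch
+
+    param = nn.Parameter(torch.zeros(1))
+    optimizer = Adam([param], lr=1e-3)
+    scheduler = WarmupCosineSchedule(optimizer, warmup_steps=5, t_total=20)
+    for step in range(20):
+        optimizer.step()
+        scheduler.step()
+        print(step, round(optimizer.param_groups[0]["lr"], 6))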
diff --git a/SwinMM/Pretrain/run.sh b/SwinMM/Pretrain/run.sh
new file mode 100644
index 00000000..2f1883a3
--- /dev/null
+++ b/SwinMM/Pretrain/run.sh
@@ -0,0 +1,10 @@
+python -m torch.distributed.launch --nproc_per_node=8 --master_port=11223 main.py \
+ --batch_size=2 \
+ --num_steps=30000 \
+ --lrdecay \
+ --eval_num=500 \
+ --lr=5e-4 \
+ --decay=0.1 \
+ --norm_pix_loss \
+ --redis_ports 39996 39997 39998 39999 \
+ --redis_compression zlib
diff --git a/SwinMM/Pretrain/utils/__init__.py b/SwinMM/Pretrain/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/SwinMM/Pretrain/utils/data_utils.py b/SwinMM/Pretrain/utils/data_utils.py
new file mode 100644
index 00000000..35fab4d1
--- /dev/null
+++ b/SwinMM/Pretrain/utils/data_utils.py
@@ -0,0 +1,176 @@
+# Copyright 2020 - 2022 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import timm.data
+from utils import dataset_in_memory
+
+from monai.data import DataLoader, Dataset, DistributedSampler, load_decathlon_datalist
+from monai.data.utils import list_data_collate
+from monai.transforms import (
+ AddChanneld,
+ Compose,
+ LoadImaged,
+ Orientationd,
+ RandSpatialCropSamplesd,
+ ScaleIntensityRanged,
+ SpatialPadd,
+ ToTensord,
+)
+
+
+def get_loader(args):
+ splits0 = "/dataset00_BTCV.json"
+ # splits1 = "/dataset01_BrainTumour.json"
+ splits2 = "/dataset02_Heart.json"
+ splits3 = "/dataset03_Liver.json"
+ splits4 = "/dataset04_Hippocampus.json"
+ # splits5 = "/dataset05_Prostate.json"
+ splits6 = "/dataset06_Lung.json"
+ splits7 = "/dataset07_Pancreas.json"
+ splits8 = "/dataset08_HepaticVessel.json"
+ splits9 = "/dataset09_Spleen.json"
+ splits10 = "/dataset10_Colon.json"
+ splits11 = "/dataset11_TCIAcovid19.json"
+ splits12 = "/dataset12_WORD.json"
+ splits13 = "/dataset13_AbdomenCT-1K.json"
+ splits14 = "/dataset_HNSCC.json"
+ splits15 = "/dataset_TCIAcolon.json"
+ splits16 = "/dataset_LIDC.json"
+
+ list_dir = "./jsons"
+ jsonlist0 = list_dir + splits0
+ # jsonlist1 = list_dir + splits1
+ jsonlist2 = list_dir + splits2
+ jsonlist3 = list_dir + splits3
+ jsonlist4 = list_dir + splits4
+ # jsonlist5 = list_dir + splits5
+ jsonlist6 = list_dir + splits6
+ jsonlist7 = list_dir + splits7
+ jsonlist8 = list_dir + splits8
+ jsonlist9 = list_dir + splits9
+ jsonlist10 = list_dir + splits10
+ jsonlist11 = list_dir + splits11
+ jsonlist12 = list_dir + splits12
+ jsonlist13 = list_dir + splits13
+ jsonlist14 = list_dir + splits14
+ jsonlist15 = list_dir + splits15
+ jsonlist16 = list_dir + splits16
+
+ datadir0 = "./dataset/dataset00_BTCV"
+ # datadir1 = "./dataset/dataset01_BrainTumour"
+ datadir2 = "./dataset/dataset02_Heart"
+ datadir3 = "./dataset/dataset03_Liver"
+ datadir4 = "./dataset/dataset04_Hippocampus"
+ # datadir5 = "./dataset/dataset05_Prostate"
+ datadir6 = "./dataset/dataset06_Lung"
+ datadir7 = "./dataset/dataset07_Pancreas"
+ datadir8 = "./dataset/dataset08_HepaticVessel"
+ datadir9 = "./dataset/dataset09_Spleen"
+ datadir10 = "./dataset/dataset10_Colon"
+ datadir11 = "./dataset/dataset11_TCIAcovid19"
+ datadir12 = "./dataset/dataset12_WORD"
+ datadir13 = "./dataset/dataset13_AbdomenCT-1K"
+ datadir14 = "./dataset/dataset_HNSCC"
+ datadir15 = "./dataset/dataset_TCIAcolon"
+ datadir16 = "./dataset/dataset_LIDC"
+
+ datalist = []
+ for json_path, base_dir in zip(
+ [
+ jsonlist0,
+ jsonlist2,
+ jsonlist3,
+ jsonlist4,
+ jsonlist6,
+ jsonlist7,
+ jsonlist8,
+ jsonlist9,
+ jsonlist10,
+ jsonlist11,
+ jsonlist12,
+ jsonlist13,
+ # jsonlist14,
+ # jsonlist15,
+ # jsonlist16,
+ ],
+ [
+ datadir0,
+ datadir2,
+ datadir3,
+ datadir4,
+ datadir6,
+ datadir7,
+ datadir8,
+ datadir9,
+ datadir10,
+ datadir11,
+ datadir12,
+ datadir13,
+ # datadir14,
+ # datadir15,
+ # datadir16,
+ ],
+ ):
+ datalist_i = load_decathlon_datalist(json_path, False, "training", base_dir=base_dir)
+ datalist.extend([{"image": item["image"]} for item in datalist_i])
+
+ print("Dataset all training: number of data: {}".format(len(datalist)))
+
+ train_transforms = Compose(
+ [
+ LoadImaged(keys=["image"]),
+ AddChanneld(keys=["image"]),
+ Orientationd(keys=["image"], axcodes="RAS"),
+ ScaleIntensityRanged(
+ keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True
+ ),
+ SpatialPadd(keys="image", spatial_size=[args.roi_x, args.roi_y, args.roi_z]),
+ RandSpatialCropSamplesd(
+ keys=["image"],
+ roi_size=[args.roi_x, args.roi_y, args.roi_z],
+ num_samples=args.sw_batch_size,
+ random_center=True,
+ random_size=False,
+ ),
+ ToTensord(keys=["image"]),
+ ]
+ )
+
+ if args.use_normal_dataset:
+ train_ds = Dataset(data=datalist, transform=train_transforms)
+ else:
+ train_ds = dataset_in_memory.CachedDataset(
+ data=datalist,
+ transform=train_transforms,
+ dataset_name="pretrain_train",
+ hosts=[{"host": "localhost", "port": str(port)} for port in args.redis_ports],
+ cluster_mode=True,
+ capacity_per_node=200 * 1024 * 1024 * 1024,
+ writer_buffer_size=0, # Disable write buffer
+ )
+
+ if args.distributed:
+ train_sampler = DistributedSampler(dataset=train_ds, even_divisible=True, shuffle=True)
+ else:
+ train_sampler = None
+ loader_class = DataLoader
+ if not args.nouse_multi_epochs_loader:
+ loader_class = timm.data.loader.MultiEpochsDataLoader
+ train_loader = loader_class(
+ train_ds,
+ batch_size=args.batch_size,
+ num_workers=args.workers,
+ sampler=train_sampler,
+ drop_last=True,
+ collate_fn=list_data_collate,
+ )
+
+ return train_loader, None
diff --git a/SwinMM/Pretrain/utils/dataset_in_memory.py b/SwinMM/Pretrain/utils/dataset_in_memory.py
new file mode 100644
index 00000000..c2d83497
--- /dev/null
+++ b/SwinMM/Pretrain/utils/dataset_in_memory.py
@@ -0,0 +1 @@
+# ../../WORD/utils/dataset_in_memory.py
diff --git a/SwinMM/Pretrain/utils/ops.py b/SwinMM/Pretrain/utils/ops.py
new file mode 100644
index 00000000..133a5ac7
--- /dev/null
+++ b/SwinMM/Pretrain/utils/ops.py
@@ -0,0 +1,30 @@
+from typing import Tuple
+
+import numpy as np
+import torch
+
+
+def mask_rand_patch(
+ window_sizes: Tuple[int, int, int], input_sizes: Tuple[int, int, int], mask_ratio: float, samples: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Patch-wise random masking."""
+ if len(window_sizes) != len(input_sizes) or any(
+ [input_size % window_size != 0 for window_size, input_size in zip(window_sizes, input_sizes)]
+ ):
+ raise ValueError(f"{window_sizes} & {input_sizes} is not compatible.")
+
+ mask_shape = [input_size // window_size for input_size, window_size in zip(input_sizes, window_sizes)]
+ num_patches = np.prod(mask_shape).item()
+ mask = np.ones(num_patches, dtype=bool)
+ indices = np.random.choice(num_patches, round(num_patches * mask_ratio), replace=False)
+ mask[indices] = False
+ mask = mask.reshape(mask_shape)
+ wh, ww, wd = window_sizes
+ mask = np.logical_or(mask[:, None, :, None, :, None], np.zeros([1, wh, 1, ww, 1, wd], dtype=bool)).reshape(
+ input_sizes
+ )
+ mask = torch.from_numpy(mask).to(samples.device)
+
+ res = samples.detach().clone()
+ res[:, :, mask] = 0
+ return res, mask
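+
+
+if __name__ == "__main__":
+    # Quick CPU illustration (not part of the training pipeline): mask 16^3 windows of a
+    # 64^3 volume with a 0.5 mask ratio; `mask` is True at the voxels that were zeroed out.
+    x = torch.randn(1, 1, 64, 64, 64)
+    x_masked, mask = mask_rand_patch((16, 16, 16), (64, 64, 64), 0.5, x)
+    print(x_masked.shape, mask.float().mean().item())  # fraction of zeroed voxels, 0.5 here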
diff --git a/SwinMM/Pretrain/utils/view_ops.py b/SwinMM/Pretrain/utils/view_ops.py
new file mode 100644
index 00000000..6f106bd3
--- /dev/null
+++ b/SwinMM/Pretrain/utils/view_ops.py
@@ -0,0 +1,18 @@
+"""View operations."""
+
+from typing import Tuple
+
+import numpy as np
+import torch
+from utils import view_transforms
+
+
+def rot_rand(xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ img_n = xs.size()[0]
+ x_aug = xs.detach().clone()
+ x_rot = torch.zeros(img_n, dtype=torch.int64, device=xs.device)
+ for i in range(img_n):
+ orientation = np.random.randint(0, 4)
+ x_aug[i] = view_transforms.rotation_transforms[orientation](xs[i].unsqueeze(0))
+ x_rot[i] = orientation
+ return x_aug, x_rot
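+
+
+if __name__ == "__main__":
+    # Small example (illustrative only): each volume in the batch is transformed by one of
+    # four randomly chosen rotation views; the view index is the label for the rotation
+    # proxy task during pre-training.
+    xs = torch.randn(2, 1, 8, 8, 8)
+    x_aug, x_rot = rot_rand(xs)
+    print(x_aug.shape, x_rot.tolist())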
diff --git a/SwinMM/Pretrain/utils/view_transforms.py b/SwinMM/Pretrain/utils/view_transforms.py
new file mode 100644
index 00000000..caa29114
--- /dev/null
+++ b/SwinMM/Pretrain/utils/view_transforms.py
@@ -0,0 +1 @@
+# ../../WORD/utils/view_transforms.py
diff --git a/SwinMM/README.md b/SwinMM/README.md
new file mode 100644
index 00000000..374703c3
--- /dev/null
+++ b/SwinMM/README.md
@@ -0,0 +1,84 @@
+# SwinMM: Masked Multi-view with Swin Transformers for 3D Medical Image Segmentation
+
+
+
+
+
+## What is SwinMM?
+
+Masked Multi-view with Swin Transformers, dubbed SwinMM, is the first comprehensive multi-view pipeline for self-supervised medical image analysis. SwinMM yields competitive performance, significantly lower training costs, and higher data efficiency compared to recent state-of-the-art models. SwinMM consists of two key components.
+
+### Pretrain
+
+In the pre-training stage, we introduce a masked multi-view encoder that simultaneously trains masked multi-view observations with a diverse set of proxy tasks. These tasks include image reconstruction, rotation, contrastive learning, and a mutual learning paradigm that comprehensively leverages hidden multi-view information from 3D medical data by maximizing the consistency between predictions from different views.
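+
+As a minimal, self-contained sketch (illustrative toy code, not the repository implementation), the mutual-learning term pulls together reconstructions decoded from two different views on the masked voxels only, mirroring `MutualLoss` in `Pretrain/losses/loss.py`:
+
+```python
+import torch
+import torch.nn.functional as F
+
+# Toy reconstructions decoded from two views of the same masked volume (already re-aligned).
+rec_view1 = torch.randn(2, 1, 16, 16, 16)
+rec_view2 = torch.randn(2, 1, 16, 16, 16)
+mask = torch.rand(16, 16, 16) > 0.5  # True at the masked voxels
+mask_ratio = 0.5
+
+# Consistency is enforced only on the masked region and rescaled by the mask ratio.
+mutual_loss = F.mse_loss(rec_view1 * mask, rec_view2 * mask) / mask_ratio
+print(mutual_loss.item())
+```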
+
+
+
+
+
+### Finetune
+
+In the fine-tuning stage, a cross-view decoder is developed to aggregate the multi-view information using a novel cross-view attention block.
+
+
+
+
+
+## Pre-trained Models
+
+We present two checkpoints here:
+
+- [pretrained_ckpt.pt](https://drive.google.com/file/d/1VFT1Oz5UGjAaXLdWAAdeD0mVeLyCQ0ry/view?usp=sharing)
+- [WORD_finetuned_ckpt.pt](https://drive.google.com/file/d/1VFT1Oz5UGjAaXLdWAAdeD0mVeLyCQ0ry/view?usp=sharing)
+
+Here is a sample testing result on WORD:
+
+
+
+
+
+## Installation
+
+Please check [INSTALL.md](INSTALL.md) for installation instructions.
+
+## Evaluation
+
+Testing can be done with the following script. Set `pretrained_dir` and `pretrained_model_name` to the checkpoint you would like to test, and set `data_dir` and `json_list` according to your dataset.
+
+```bash
+cd WORD
+python test_parrallel.py --pretrained_dir ./runs/multiview_101616/ \
+ --pretrained_model_name model.pt \
+ --distributed \
+ --data_dir ./dataset/dataset12_WORD/ \
+ --json_list dataset12_WORD.json
+```
+
+## Training
+
+Please check [TRAINING.md](TRAINING.md) for training instructions.
+
+## Acknowledgment
+
+This work is partially supported by the Google Cloud Research Credits program.
+This repo is based on [SwinUNETR](https://github.com/Project-MONAI/research-contributions/tree/main/SwinUNETR), [MONAI](https://monai.io/), and [bagua](https://github.com/BaguaSys/bagua).
+
+## Citation
+
+If you find this repository helpful, please consider citing:
+
+```
+@inproceedings{wang2023SwinMM,
+ title = {SwinMM: Masked Multi-view with Swin Transformers for 3D Medical Image Segmentation},
+ author = {Wang, Yiqing and Li, Zihan and Mei, Jieru and Wei, Zihao and Liu, Li and Wang, Chen and Sang, Shengtian and Yuille, Alan and Xie, Cihang and Zhou, Yuyin},
+ booktitle = {MICCAI},
+ year = {2023}
+}
+
+@article{cardoso2022monai,
+ title={Monai: An open-source framework for deep learning in healthcare},
+ author={Cardoso, M Jorge and Li, Wenqi and Brown, Richard and Ma, Nic and Kerfoot, Eric and Wang, Yiheng and Murrey, Benjamin and Myronenko, Andriy and Zhao, Can and Yang, Dong and others},
+ journal={arXiv preprint arXiv:2211.02701},
+ year={2022}
+}
+```
diff --git a/SwinMM/TRAINING.md b/SwinMM/TRAINING.md
new file mode 100644
index 00000000..58d7e4ef
--- /dev/null
+++ b/SwinMM/TRAINING.md
@@ -0,0 +1,33 @@
+# Training
+
+## Launch Redis
+
+Launch the in-memory database (this only needs to be done once):
+
+```bash
+# It launches redis at ports 39996-39999.
+bash ./scripts/start_redis.sh
+```
+
+**NOTE**
+
+- If the **data or preprocessing** has changed, run `pkill redis-server` before further experiments.
+- Try `--workers` values between 8 and 32 for the best performance.
+- The first epoch after launching the server can be slow; later epochs should be much faster.
+- Set `--redis_ports` according to your Redis setup.
+
+## Pre-training
+
+```bash
+cd Pretrain
+bash run.sh
+```
+
+## Finetuning
+
+- Prepare pretrained models: copy the pretrained checkpoint into the `pretrained_models` directory under `WORD`.
+
+```bash
+cd WORD
+bash run.sh
+```
diff --git a/SwinMM/WORD/README.md b/SwinMM/WORD/README.md
new file mode 100644
index 00000000..2089dbb1
--- /dev/null
+++ b/SwinMM/WORD/README.md
@@ -0,0 +1,42 @@
+# Note for SwinMM Finetuning
+
+## Training
+
+### FIXME(outdated)
+
+```bash
+python main.py \
+    --feature_size=48 \
+    --batch_size=1 \
+    --logdir="swin_mm_test/" \
+    --roi_x=64 \
+    --roi_y=64 \
+    --roi_z=64 \
+    --optim_lr=1e-4 \
+    --lrschedule="warmup_cosine" \
+    --infer_overlap=0.5 \
+    --save_checkpoint \
+    --data_dir="/dataset/dataset0/" \
+    --distributed \
+    --use_ssl_pretrained \
+    --pretrained_dir="./pretrained_models/" \
+    --pretrained_model_name="model_bestValRMSE.pt"
+```
+
+## Testing
+
+### FIXME(outdated)
+
+```bash
+python test.py \
+    --feature_size=48 \
+    --batch_size=1 \
+    --exp_name="swin_mm_test/" \
+    --roi_x=64 \
+    --roi_y=64 \
+    --roi_z=64 \
+    --infer_overlap=0.5 \
+    --data_dir="/dataset/dataset0/" \
+    --pretrained_dir="./runs/multiview_101021/" \
+    --pretrained_model_name="model.pt"
+```
diff --git a/SwinMM/WORD/dataset/Download Finetune Jsons Here b/SwinMM/WORD/dataset/Download Finetune Jsons Here
new file mode 100644
index 00000000..e69de29b
diff --git a/SwinMM/WORD/dataset/__init__.py b/SwinMM/WORD/dataset/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/SwinMM/WORD/inferers.py b/SwinMM/WORD/inferers.py
new file mode 100644
index 00000000..c943ca13
--- /dev/null
+++ b/SwinMM/WORD/inferers.py
@@ -0,0 +1,292 @@
+"""Multiview inferer."""
+
+import warnings
+from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from utils import view_ops, view_transforms
+
+from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size
+from monai.inferers.utils import _get_scan_interval
+from monai.transforms import Resize
+from monai.utils import (
+ BlendMode,
+ PytorchPadMode,
+ convert_data_type,
+ ensure_tuple,
+ fall_back_tuple,
+ look_up_option,
+ optional_import,
+)
+
+tqdm, _ = optional_import("tqdm", name="tqdm")
+
+
+def double_sliding_window_inference(
+ inputs: torch.Tensor,
+ view: int,
+ roi_size: Union[Sequence[int], int],
+ sw_batch_size: int,
+ predictor: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]]],
+ overlap: float = 0.25,
+ mode: Union[BlendMode, str] = BlendMode.CONSTANT,
+ sigma_scale: Union[Sequence[float], float] = 0.125,
+ padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
+ cval: float = 0.0,
+ sw_device: Union[torch.device, str, None] = None,
+ device: Union[torch.device, str, None] = None,
+ progress: bool = False,
+ roi_weight_map: Union[torch.Tensor, None] = None,
+ *args: Any,
+ **kwargs: Any,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]:
+ """
+    Sliding window inference that runs `predictor` on two permuted views of `inputs`.
+
+ The outputs of `predictor` could be a tensor, a tuple, or a dictionary of tensors.
+ Each output in the tuple or dict value is allowed to have different resolutions with respect to the input.
+ e.g., the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes
+ could be ([128,64,256], [64,32,128]).
+ In this case, the parameter `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is still
+ an integer. If the predictor's input and output spatial sizes are not equal, we recommend choosing the parameters
+ so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension).
+
+    When roi_size is larger than the inputs' spatial size, the input image is padded during inference.
+ To maintain the same spatial sizes, the output image will be cropped to the original input size.
+
+ Args:
+        inputs: input image to be processed (assuming NCHW[D])
+        view: index of the first permutation view; each window is also processed under the
+            next view in ``view_transforms.permutation_transforms``.
+ roi_size: the spatial window size for inferences.
+ When its components have None or non-positives, the corresponding inputs dimension will be used.
+ if the components of the `roi_size` are non-positive values, the transform will use the
+ corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
+ to `(32, 64)` if the second spatial dimension size of img is `64`.
+ sw_batch_size: the batch size to run window slices.
+ predictor: given input tensor ``patch_data`` in shape NCHW[D],
+ The outputs of the function call ``predictor(patch_data)`` should be a tensor, a tuple, or a dictionary
+ with Tensor values. Each output in the tuple or dict value should have the same batch_size, i.e. NM'H'W'[D'];
+ where H'W'[D'] represents the output patch's spatial size, M is the number of output channels,
+ N is `sw_batch_size`, e.g., the input shape is (7, 1, 128,128,128),
+ the output could be a tuple of two tensors, with shapes: ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)).
+ In this case, the parameter `overlap` and `roi_size` need to be carefully chosen
+ to ensure the scaled output ROI sizes are still integers.
+ If the `predictor`'s input and output spatial sizes are different,
+ we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension.
+ overlap: Amount of overlap between scans.
+ mode: {``"constant"``, ``"gaussian"``}
+ How to blend output of overlapping windows. Defaults to ``"constant"``.
+
+ - ``"constant``": gives equal weight to all predictions.
+ - ``"gaussian``": gives less weight to predictions on edges of windows.
+
+ sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
+ Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
+ When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
+ spatial dimensions.
+ padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
+ Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"``
+ See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
+ cval: fill value for 'constant' padding mode. Default: 0
+ sw_device: device for the window data.
+ By default the device (and accordingly the memory) of the `inputs` is used.
+ Normally `sw_device` should be consistent with the device where `predictor` is defined.
+ device: device for the stitched output prediction.
+ By default the device (and accordingly the memory) of the `inputs` is used. If for example
+ set to device=torch.device('cpu') the gpu memory consumption is less and independent of the
+ `inputs` and `roi_size`. Output is on the `device`.
+ progress: whether to print a `tqdm` progress bar.
+ roi_weight_map: pre-computed (non-negative) weight map for each ROI.
+ If not given, and ``mode`` is not `constant`, this map will be computed on the fly.
+ args: optional args to be passed to ``predictor``.
+ kwargs: optional keyword args to be passed to ``predictor``.
+
+ Note:
+ - input must be channel-first and have a batch dim, supports N-D sliding window.
+
+ """
+ compute_dtype = inputs.dtype
+ num_spatial_dims = len(inputs.shape) - 2
+ if overlap < 0 or overlap >= 1:
+ raise ValueError("overlap must be >= 0 and < 1.")
+
+ # determine image spatial size and batch size
+ # Note: all input images must have the same image size and batch size
+ batch_size, _, *image_size_ = inputs.shape
+
+ if device is None:
+ device = inputs.device
+ if sw_device is None:
+ sw_device = inputs.device
+
+ roi_size = fall_back_tuple(roi_size, image_size_)
+ # in case that image size is smaller than roi size
+ image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
+ pad_size = []
+ for k in range(len(inputs.shape) - 1, 1, -1):
+ diff = max(roi_size[k - 2] - inputs.shape[k], 0)
+ half = diff // 2
+ pad_size.extend([half, diff - half])
+ inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode).value, value=cval)
+ # inputs2 = F.pad(inputs2, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode).value, value=cval)
+
+ scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)
+
+ # Store all slices in list
+ slices = dense_patch_slices(image_size, roi_size, scan_interval)
+ num_win = len(slices) # number of windows per image
+ total_slices = num_win * batch_size # total number of windows
+
+ # Create window-level importance map
+ valid_patch_size = get_valid_patch_size(image_size, roi_size)
+ if valid_patch_size == roi_size and (roi_weight_map is not None):
+ importance_map = roi_weight_map
+ else:
+ try:
+ importance_map = compute_importance_map(valid_patch_size, mode=mode, sigma_scale=sigma_scale, device=device)
+ except BaseException as e:
+ raise RuntimeError(
+ "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'."
+ ) from e
+ importance_map = convert_data_type(importance_map, torch.Tensor, device, compute_dtype)[0] # type: ignore
+ # handle non-positive weights
+ min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3)
+ importance_map = torch.clamp(importance_map.to(torch.float32), min=min_non_zero).to(compute_dtype)
+
+ # Perform predictions
+ dict_key, output_image_list_1, output_image_list_2, count_map_list = None, [], [], []
+ _initialized_ss = -1
+ is_tensor_output = True # whether the predictor's output is a tensor (instead of dict/tuple)
+
+ # for each patch
+ for slice_g in tqdm(range(0, total_slices, sw_batch_size)) if progress else range(0, total_slices, sw_batch_size):
+ slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices))
+ unravel_slice = [
+ [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + list(slices[idx % num_win])
+ for idx in slice_range
+ ]
+ window_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device)
+ view_list = [view, (view + 1) % len(view_transforms.permutation_transforms)]
+ window_data_list = [view_ops.get_permute_transform(0, dst)(window_data) for dst in view_list]
+ # window_data_2 = torch.cat([inputs2[win_slice] for win_slice in unravel_slice]).to(sw_device)
+ seg_prob_out_1, seg_prob_out_2 = predictor(
+ window_data_list[0], window_data_list[1], view_list, *args, **kwargs
+ ) # batched patch segmentation
+ seg_prob_out_1, seg_prob_out_2 = view_ops.permute_inverse([seg_prob_out_1, seg_prob_out_2], view_list)
+
+ # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory.
+ seg_prob_tuple_1: Tuple[torch.Tensor, ...]
+ seg_prob_tuple_2: Tuple[torch.Tensor, ...]
+ if isinstance(seg_prob_out_1, torch.Tensor):
+ seg_prob_tuple_1 = (seg_prob_out_1,)
+ seg_prob_tuple_2 = (seg_prob_out_2,)
+ elif isinstance(seg_prob_out_1, Mapping):
+ if dict_key is None:
+ dict_key = sorted(seg_prob_out_1.keys()) # track predictor's output keys
+ seg_prob_tuple_1 = tuple(seg_prob_out_1[k] for k in dict_key)
+ seg_prob_tuple_2 = tuple(seg_prob_out_2[k] for k in dict_key)
+ is_tensor_output = False
+ else:
+ seg_prob_tuple_1 = ensure_tuple(seg_prob_out_1)
+ seg_prob_tuple_2 = ensure_tuple(seg_prob_out_2)
+ is_tensor_output = False
+
+ # for each output in multi-output list
+ for ss in range(len(seg_prob_tuple_1)):
+ seg_prob_1 = seg_prob_tuple_1[ss].to(device) # BxCxMxNxP or BxCxMxN
+ seg_prob_2 = seg_prob_tuple_2[ss].to(device)
+
+ # compute zoom scale: out_roi_size/in_roi_size
+ zoom_scale = []
+ for axis, (img_s_i, out_w_i, in_w_i) in enumerate(
+ zip(image_size, seg_prob_1.shape[2:], window_data.shape[2:])
+ ):
+ _scale = out_w_i / float(in_w_i)
+ if not (img_s_i * _scale).is_integer():
+ warnings.warn(
+ f"For spatial axis: {axis}, output[{ss}] will have non-integer shape. Spatial "
+ f"zoom_scale between output[{ss}] and input is {_scale}. Please pad inputs."
+ )
+ zoom_scale.append(_scale)
+
+ if _initialized_ss < ss: # init. the ss-th buffer at the first iteration
+ # construct multi-resolution outputs
+ output_classes = seg_prob_1.shape[1]
+ output_shape = [batch_size, output_classes] + [
+ int(image_size_d * zoom_scale_d) for image_size_d, zoom_scale_d in zip(image_size, zoom_scale)
+ ]
+ # allocate memory to store the full output and the count for overlapping parts
+ output_image_list_1.append(torch.zeros(output_shape, dtype=compute_dtype, device=device))
+ output_image_list_2.append(torch.zeros(output_shape, dtype=compute_dtype, device=device))
+ count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device))
+ _initialized_ss += 1
+
+ # resizing the importance_map
+ resizer = Resize(spatial_size=seg_prob_1.shape[2:], mode="nearest", anti_aliasing=False)
+
+ # store the result in the proper location of the full output. Apply weights from importance map.
+ for idx, original_idx in zip(slice_range, unravel_slice):
+ # zoom roi
+ original_idx_zoom = list(original_idx) # 4D for 2D image, 5D for 3D image
+ for axis in range(2, len(original_idx_zoom)):
+ zoomed_start = original_idx[axis].start * zoom_scale[axis - 2]
+ zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2]
+ if not zoomed_start.is_integer() or (not zoomed_end.is_integer()):
+ warnings.warn(
+ f"For axis-{axis-2} of output[{ss}], the output roi range is not int. "
+ f"Input roi range is ({original_idx[axis].start}, {original_idx[axis].stop}). "
+ f"Spatial zoom_scale between output[{ss}] and input is {zoom_scale[axis - 2]}. "
+ f"Corresponding output roi range is ({zoomed_start}, {zoomed_end}).\n"
+ f"Please change overlap ({overlap}) or roi_size ({roi_size[axis-2]}) for axis-{axis-2}. "
+ "Tips: if overlap*roi_size*zoom_scale is an integer, it usually works."
+ )
+ original_idx_zoom[axis] = slice(int(zoomed_start), int(zoomed_end), None)
+ importance_map_zoom = resizer(importance_map.unsqueeze(0))[0].to(compute_dtype)
+ # store results and weights
+ output_image_list_1[ss][original_idx_zoom] += importance_map_zoom * seg_prob_1[idx - slice_g]
+ output_image_list_2[ss][original_idx_zoom] += importance_map_zoom * seg_prob_2[idx - slice_g]
+ count_map_list[ss][original_idx_zoom] += (
+ importance_map_zoom.unsqueeze(0).unsqueeze(0).expand(count_map_list[ss][original_idx_zoom].shape)
+ )
+
+ # account for any overlapping sections
+ for ss in range(len(output_image_list_1)):
+ count_map_pop = count_map_list.pop(0)
+ output_image_list_1[ss] = (output_image_list_1[ss] / count_map_pop).to(compute_dtype)
+ output_image_list_2[ss] = (output_image_list_2[ss] / count_map_pop).to(compute_dtype)
+
+ # remove padding if image_size smaller than roi_size
+ for ss in range(len(output_image_list_1)):
+ output_i_1, output_i_2 = output_image_list_1[ss], output_image_list_2[ss]
+ if torch.isnan(output_i_1).any() or torch.isinf(output_i_1).any():
+ warnings.warn("Sliding window inference results contain NaN or Inf.")
+ if torch.isnan(output_i_2).any() or torch.isinf(output_i_2).any():
+ warnings.warn("Sliding window inference results contain NaN or Inf.")
+
+ zoom_scale = [
+ seg_prob_map_shape_d / roi_size_d
+ for seg_prob_map_shape_d, roi_size_d in zip(output_i_1.shape[2:], roi_size)
+ ]
+
+ final_slicing: List[slice] = []
+ for sp in range(num_spatial_dims):
+ slice_dim = slice(pad_size[sp * 2], image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2])
+ slice_dim = slice(
+ int(round(slice_dim.start * zoom_scale[num_spatial_dims - sp - 1])),
+ int(round(slice_dim.stop * zoom_scale[num_spatial_dims - sp - 1])),
+ )
+ final_slicing.insert(0, slice_dim)
+ while len(final_slicing) < len(output_i_1.shape):
+ final_slicing.insert(0, slice(None))
+ output_image_list_1[ss] = output_i_1[final_slicing]
+ output_image_list_2[ss] = output_i_2[final_slicing]
+
+ if dict_key is not None: # if output of predictor is a dict
+ final_output_1 = dict(zip(dict_key, output_image_list_1))
+ final_output_2 = dict(zip(dict_key, output_image_list_2))
+ else:
+ final_output_1 = tuple(output_image_list_1) # type: ignore
+ final_output_2 = tuple(output_image_list_2) # type: ignore
+ final_output_1 = final_output_1[0] if is_tensor_output else final_output_1 # type: ignore
+ final_output_2 = final_output_2[0] if is_tensor_output else final_output_2 # type: ignore
+ return final_output_1, final_output_2
diff --git a/SwinMM/WORD/main.py b/SwinMM/WORD/main.py
new file mode 100644
index 00000000..6d5872b4
--- /dev/null
+++ b/SwinMM/WORD/main.py
@@ -0,0 +1,292 @@
+# Copyright 2020 - 2022 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import os
+from functools import partial
+
+import numpy as np
+import timm.optim.optim_factory as optim_factory
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torch.nn.parallel
+import torch.utils.data.distributed
+from inferers import double_sliding_window_inference
+from models import SwinUNETR
+from optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
+from timm.utils import setup_default_logging
+from torch.nn import KLDivLoss
+from trainer import run_training
+from utils.data_utils import get_loader
+from utils.dataset_in_memory import hijack_bagua_serialization
+
+from monai.losses import DiceCELoss
+from monai.metrics import DiceMetric
+from monai.transforms import AsDiscrete
+from monai.utils.enums import MetricReduction
+
+parser = argparse.ArgumentParser(description="Swin UNETR segmentation pipeline")
+parser.add_argument("--checkpoint", default=None, help="start training from saved checkpoint")
+parser.add_argument("--logdir", default="multiview_101616/", type=str, help="directory to save the tensorboard logs")
+parser.add_argument(
+ "--pretrained_dir", default="./pretrained_models/", type=str, help="pretrained checkpoint directory"
+)
+parser.add_argument("--data_dir", default="./dataset/dataset12_WORD/", type=str, help="dataset directory")
+parser.add_argument("--json_list", default="dataset12_WORD.json", type=str, help="dataset json file")
+parser.add_argument("--pretrained_model_name", default="model_bestValRMSE.pt", type=str, help="pretrained model name")
+parser.add_argument("--save_checkpoint", action="store_true", help="save checkpoint during training")
+parser.add_argument("--max_epochs", default=1500, type=int, help="max number of training epochs")
+parser.add_argument("--batch_size", default=1, type=int, help="number of batch size")
+parser.add_argument("--sw_batch_size", default=16, type=int, help="number of sliding window batch size")
+parser.add_argument("--optim_lr", default=1e-4, type=float, help="optimization learning rate")
+parser.add_argument("--optim_name", default="adamw", type=str, help="optimization algorithm")
+parser.add_argument("--reg_weight", default=1e-5, type=float, help="regularization weight")
+parser.add_argument("--layer_decay", default=1.0, type=float, help="layer-wise learning rate decay")
+parser.add_argument("--momentum", default=0.99, type=float, help="momentum")
+parser.add_argument("--noamp", action="store_true", help="do NOT use amp for training")
+parser.add_argument("--val_every", default=100, type=int, help="validation frequency")
+parser.add_argument("--val_start", default=1000, type=int, help="val start from epoch")
+parser.add_argument("--unsuper_every", default=1, type=int, help="unsupervised training frequency")
+parser.add_argument("--unsuper_start", default=100, type=int, help="unsupervised training frequency")
+parser.add_argument("--unsupervised", action="store_true", help="start unsupervised training")
+parser.add_argument("--distributed", action="store_true", help="start distributed training")
+parser.add_argument("--world_size", default=1, type=int, help="number of nodes for distributed training")
+parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training")
+parser.add_argument("--dist-url", default="tcp://127.0.0.1:23456", type=str, help="distributed url")
+parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
+parser.add_argument("--workers", default=4, type=int, help="number of workers")
+parser.add_argument("--feature_size", default=48, type=int, help="feature size")
+parser.add_argument("--in_channels", default=1, type=int, help="number of input channels")
+parser.add_argument("--out_channels", default=17, type=int, help="number of output channels")
+parser.add_argument("--use_normal_dataset", action="store_true", help="use monai Dataset class")
+parser.add_argument("--use_normal_dataset_val", action="store_true", help="use monai Dataset class for val")
+parser.add_argument(
+ "--nouse_multi_epochs_loader",
+ action="store_true",
+ help="not use the multi-epochs-loader to save time at the beginning of every epoch",
+)
+parser.add_argument("--a_min", default=-175.0, type=float, help="a_min in ScaleIntensityRanged")
+parser.add_argument("--a_max", default=250.0, type=float, help="a_max in ScaleIntensityRanged")
+parser.add_argument("--b_min", default=0.0, type=float, help="b_min in ScaleIntensityRanged")
+parser.add_argument("--b_max", default=1.0, type=float, help="b_max in ScaleIntensityRanged")
+parser.add_argument("--space_x", default=1.5, type=float, help="spacing in x direction")
+parser.add_argument("--space_y", default=1.5, type=float, help="spacing in y direction")
+parser.add_argument("--space_z", default=2.0, type=float, help="spacing in z direction")
+parser.add_argument("--roi_x", default=64, type=int, help="roi size in x direction")
+parser.add_argument("--roi_y", default=64, type=int, help="roi size in y direction")
+parser.add_argument("--roi_z", default=64, type=int, help="roi size in z direction")
+parser.add_argument("--dropout_rate", default=0.0, type=float, help="dropout rate")
+parser.add_argument("--dropout_path_rate", default=0.0, type=float, help="drop path rate")
+parser.add_argument("--RandFlipd_prob", default=0.2, type=float, help="RandFlipd aug probability")
+parser.add_argument("--RandRotate90d_prob", default=0.2, type=float, help="RandRotate90d aug probability")
+parser.add_argument("--RandScaleIntensityd_prob", default=0.1, type=float, help="RandScaleIntensityd aug probability")
+parser.add_argument("--RandShiftIntensityd_prob", default=0.1, type=float, help="RandShiftIntensityd aug probability")
+parser.add_argument("--infer_overlap", default=0.5, type=float, help="sliding window inference overlap")
+parser.add_argument("--lrschedule", default="warmup_cosine", type=str, help="type of learning rate scheduler")
+parser.add_argument("--warmup_epochs", default=50, type=int, help="number of warmup epochs")
+parser.add_argument("--resume_ckpt", action="store_true", help="resume training from pretrained checkpoint")
+parser.add_argument("--smooth_dr", default=1e-6, type=float, help="constant added to dice denominator to avoid nan")
+parser.add_argument("--smooth_nr", default=0.0, type=float, help="constant added to dice numerator to avoid zero")
+parser.add_argument("--use_checkpoint", action="store_true", help="use gradient checkpointing to save memory")
+parser.add_argument("--use_ssl_pretrained", action="store_true", help="use self-supervised pretrained weights")
+parser.add_argument("--spatial_dims", default=3, type=int, help="spatial dimension of input data")
+parser.add_argument("--squared_dice", action="store_true", help="use squared Dice")
+parser.add_argument("--norm_name", default="batch", help="multi gpu use")
+parser.add_argument(
+ "--cross_attention_in_origin_view", action="store_true", help="Whether compute cross attention in original view"
+)
+parser.add_argument("--redis_ports", nargs="+", type=int, help="redis ports")
+parser.add_argument("--redis_compression", type=str, default=None, help="compression method for redis.")
+
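+# Example invocation (illustrative only; paths, ports, and flag values are placeholders):
+#   python main.py --distributed --batch_size=2 --save_checkpoint --use_ssl_pretrained \
+#       --pretrained_model_name=model_bestValRMSE.pt --redis_ports 39996 39997 39998 39999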
+
+def main():
+ args = parser.parse_args()
+ args.amp = not args.noamp
+ args.logdir = "./runs/" + args.logdir
+ if args.distributed:
+ args.ngpus_per_node = torch.cuda.device_count()
+ print("Found total gpus", args.ngpus_per_node)
+ args.world_size = args.ngpus_per_node * args.world_size
+ mp.spawn(main_worker, nprocs=args.ngpus_per_node, args=(args,))
+ else:
+ main_worker(gpu=0, args=args)
+
+
+def main_worker(gpu, args):
+ if args.redis_compression is not None:
+ hijack_bagua_serialization(args.redis_compression)
+
+ if args.distributed:
+ torch.multiprocessing.set_start_method("fork", force=True)
+ np.set_printoptions(formatter={"float": "{: 0.3f}".format}, suppress=True)
+ args.gpu = gpu
+ if args.distributed:
+ args.rank = args.rank * args.ngpus_per_node + gpu
+ dist.init_process_group(
+ backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
+ )
+ torch.cuda.set_device(args.gpu)
+ torch.backends.cudnn.benchmark = True
+ args.test_mode = False
+ loader = get_loader(args)
+ print(args.rank, " gpu", args.gpu)
+ if args.rank == 0:
+ setup_default_logging()
+ logging.info(f"Batch size is: {args.batch_size}, epochs: {args.max_epochs}")
+ inf_size = [args.roi_x, args.roi_y, args.roi_z]
+
+ pretrained_dir = args.pretrained_dir
+ model = SwinUNETR(
+ img_size=(args.roi_x, args.roi_y, args.roi_z),
+ in_channels=args.in_channels,
+ out_channels=args.out_channels,
+ feature_size=args.feature_size,
+ fusion_depths=(1, 1, 1, 1, 1, 1),
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ dropout_path_rate=args.dropout_path_rate,
+ use_checkpoint=args.use_checkpoint,
+ cross_attention_in_origin_view=args.cross_attention_in_origin_view,
+ )
+
+ if args.resume_ckpt:
+ model_dict = torch.load(os.path.join(pretrained_dir, args.pretrained_model_name), map_location="cpu")[
+ "state_dict"
+ ]
+ model.load_state_dict(model_dict)
+ logging.info("Use pretrained weights")
+
+ if args.use_ssl_pretrained:
+ try:
+ model_dict = torch.load(os.path.join(pretrained_dir, args.pretrained_model_name), map_location="cpu")
+ state_dict = model_dict["state_dict"]
+ # fix potential differences in state dict keys from pre-training to
+ # fine-tuning
+ if "module." in list(state_dict.keys())[0]:
+ logging.info("Tag 'module.' found in state dict - fixing!")
+ for key in list(state_dict.keys()):
+ state_dict[key.replace("module.", "")] = state_dict.pop(key)
+ if "swin_vit" in list(state_dict.keys())[0]:
+ logging.info("Tag 'swin_vit' found in state dict - fixing!")
+ for key in list(state_dict.keys()):
+ state_dict[key.replace("swin_vit", "swinViT")] = state_dict.pop(key)
+ # We now load model weights, setting param `strict` to False, i.e.:
+ # this loads the encoder weights (Swin-ViT, SSL pre-trained), but leaves
+ # the decoder weights untouched (CNN UNet decoder).
+ model.load_state_dict(state_dict, strict=False)
+ logging.info("Using pretrained self-supervised Swin UNETR backbone weights!")
+ except ValueError:
+ raise ValueError("Self-supervised pre-trained weights not available for " + str(args.pretrained_model_name))
+
+ if args.squared_dice:
+ dice_loss = DiceCELoss(
+ to_onehot_y=True, softmax=True, squared_pred=True, smooth_nr=args.smooth_nr, smooth_dr=args.smooth_dr
+ )
+ else:
+ dice_loss = DiceCELoss(to_onehot_y=True, softmax=True)
+ mutual_loss = KLDivLoss(reduction="mean") # CosineSimilarity(dim = 1)
+ post_label = AsDiscrete(to_onehot=True, n_classes=args.out_channels)
+ post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=args.out_channels)
+ dice_acc = DiceMetric(include_background=True, reduction=MetricReduction.MEAN, get_not_nans=True)
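+ # Wrap the dual-view sliding-window inferer so validation can call it as model_inferer(image, view_index).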
+ model_inferer = partial(
+ double_sliding_window_inference,
+ roi_size=inf_size,
+ sw_batch_size=args.sw_batch_size,
+ predictor=model,
+ overlap=args.infer_overlap,
+ )
+
+ pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ logging.info(f"Total parameters count: {pytorch_total_params}")
+
+ best_acc = 0
+ start_epoch = 0
+
+ if args.checkpoint is not None:
+ checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ from collections import OrderedDict
+
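+ # Strip any "backbone." prefix from parameter names so the checkpoint keys match this model.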
+ new_state_dict = OrderedDict()
+ for k, v in checkpoint["state_dict"].items():
+ new_state_dict[k.replace("backbone.", "")] = v
+ model.load_state_dict(new_state_dict, strict=False)
+ if "epoch" in checkpoint:
+ start_epoch = checkpoint["epoch"] + 1
+ if "best_acc" in checkpoint:
+ best_acc = checkpoint["best_acc"]
+ logging.info("=> loaded checkpoint '{}' (epoch {}) (bestacc {})".format(args.checkpoint, start_epoch, best_acc))
+
+ model.cuda(args.gpu)
+ model_without_ddp = model
+ if args.distributed:
+ torch.cuda.set_device(args.gpu)
+ if args.norm_name == "batch":
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ model.cuda(args.gpu)
+ model_without_ddp = model
+ model = torch.nn.parallel.DistributedDataParallel(
+ model, device_ids=[args.gpu], output_device=args.gpu, broadcast_buffers=False, find_unused_parameters=True
+ )
+
+ # build optimizer with layer-wise lr decay (lrd)
+ param_groups = optim_factory.param_groups_layer_decay(
+ model_without_ddp,
+ weight_decay=args.reg_weight,
+ no_weight_decay_list=model_without_ddp.no_weight_decay(),
+ layer_decay=args.layer_decay,
+ verbose=False,
+ )
+ if args.optim_name == "adam":
+ optimizer = torch.optim.Adam(param_groups, lr=args.optim_lr)
+ elif args.optim_name == "adamw":
+ optimizer = torch.optim.AdamW(param_groups, lr=args.optim_lr)
+ elif args.optim_name == "sgd":
+ optimizer = torch.optim.SGD(param_groups, lr=args.optim_lr, momentum=args.momentum, nesterov=True)
+ else:
+ raise ValueError("Unsupported Optimization Procedure: " + str(args.optim_name))
+
+ if args.lrschedule == "warmup_cosine":
+ scheduler = LinearWarmupCosineAnnealingLR(
+ optimizer, warmup_epochs=args.warmup_epochs, max_epochs=args.max_epochs
+ )
+ elif args.lrschedule == "cosine_anneal":
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epochs)
+ if args.checkpoint is not None:
+ scheduler.step(epoch=start_epoch)
+ else:
+ scheduler = None
+
+ unsupervised_loader = None
+ if args.unsupervised:
+ unsupervised_loader = loader[2]
+ accuracy = run_training(
+ model=model,
+ train_loader=loader[0],
+ val_loader=loader[1],
+ unsupervised_loader=unsupervised_loader,
+ optimizer=optimizer,
+ self_crit=dice_loss,
+ mutual_crit=mutual_loss,
+ acc_func=dice_acc,
+ args=args,
+ model_inferer=model_inferer,
+ scheduler=scheduler,
+ start_epoch=start_epoch,
+ post_label=post_label,
+ post_pred=post_pred,
+ )
+ return accuracy
+
+
+if __name__ == "__main__":
+ main()
diff --git a/SwinMM/WORD/models/__init__.py b/SwinMM/WORD/models/__init__.py
new file mode 100644
index 00000000..a639a968
--- /dev/null
+++ b/SwinMM/WORD/models/__init__.py
@@ -0,0 +1 @@
+from .swin_unetr import *
diff --git a/SwinMM/WORD/models/cross_attention.py b/SwinMM/WORD/models/cross_attention.py
new file mode 100644
index 00000000..61ca832f
--- /dev/null
+++ b/SwinMM/WORD/models/cross_attention.py
@@ -0,0 +1,161 @@
+"""TransFusion from TransFusion: Multi-view Divergent Fusion for Medical Image Segmentation with Transformers."""
+import copy
+import math
+from typing import Sequence, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from utils.view_ops import get_permute_transform, permute_inverse
+
+
+class Attention(nn.Module):
+ def __init__(self, num_heads=8, hidden_size=768, atte_dropout_rate=0.0):
+ super(Attention, self).__init__()
+ # self.vis = vis
+ self.num_attention_heads = num_heads
+ self.attention_head_size = int(hidden_size / self.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(hidden_size, self.all_head_size)
+ self.key = nn.Linear(hidden_size, self.all_head_size)
+ self.value = nn.Linear(hidden_size, self.all_head_size)
+
+ self.out = nn.Linear(hidden_size, hidden_size)
+ self.attn_dropout = nn.Dropout(atte_dropout_rate)
+ self.proj_dropout = nn.Dropout(atte_dropout_rate)
+
+ self.softmax = nn.Softmax(dim=-1)
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(self, x_1, x_2):
+ mixed_query_layer_1 = self.query(x_1)
+ mixed_key_layer_1 = self.key(x_1)
+ mixed_value_layer_1 = self.value(x_1)
+ query_layer_1 = self.transpose_for_scores(mixed_query_layer_1)
+ key_layer_1 = self.transpose_for_scores(mixed_key_layer_1)
+ value_layer_1 = self.transpose_for_scores(mixed_value_layer_1)
+ mixed_query_layer_2 = self.query(x_2)
+ mixed_key_layer_2 = self.key(x_2)
+ mixed_value_layer_2 = self.value(x_2)
+ query_layer_2 = self.transpose_for_scores(mixed_query_layer_2)
+ key_layer_2 = self.transpose_for_scores(mixed_key_layer_2)
+ value_layer_2 = self.transpose_for_scores(mixed_value_layer_2)
+
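+ # Bidirectional cross attention: queries from view 1 attend to keys/values from view 2 (and vice versa below),
+ # with the two branches sharing the same query/key/value/output projections.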
+ attention_scores_1 = torch.matmul(query_layer_1, key_layer_2.transpose(-1, -2))
+ attention_scores_1 = attention_scores_1 / math.sqrt(self.attention_head_size)
+ attention_probs_1 = self.softmax(attention_scores_1)
+ # weights_st = attention_probs_st if self.vis else None
+ attention_probs_1 = self.attn_dropout(attention_probs_1)
+ context_layer_1 = torch.matmul(attention_probs_1, value_layer_2)
+ context_layer_1 = context_layer_1.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape_1 = context_layer_1.size()[:-2] + (self.all_head_size,)
+ context_layer_1 = context_layer_1.view(*new_context_layer_shape_1)
+ attention_output_1 = self.out(context_layer_1)
+ attention_output_1 = self.proj_dropout(attention_output_1)
+
+ attention_scores_2 = torch.matmul(query_layer_2, key_layer_1.transpose(-1, -2))
+ attention_scores_2 = attention_scores_2 / math.sqrt(self.attention_head_size)
+ attention_probs_2 = self.softmax(attention_scores_2)
+ # weights_st = attention_probs_st if self.vis else None
+ attention_probs_2 = self.attn_dropout(attention_probs_2)
+ context_layer_2 = torch.matmul(attention_probs_2, value_layer_1)
+ context_layer_2 = context_layer_2.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape_2 = context_layer_2.size()[:-2] + (self.all_head_size,)
+ context_layer_2 = context_layer_2.view(*new_context_layer_shape_2)
+ attention_output_2 = self.out(context_layer_2)
+ attention_output_2 = self.proj_dropout(attention_output_2)
+
+ return attention_output_1, attention_output_2
+
+
+class Block(nn.Module):
+ def __init__(self, hidden_size=768, mlp_dim=1536, dropout_rate=0.5, num_heads=8, atte_dropout_rate=0.0):
+ super(Block, self).__init__()
+
+ del mlp_dim
+ del dropout_rate
+
+ self.hidden_size = hidden_size
+ self.attention_norm = nn.LayerNorm(hidden_size, eps=1e-6)
+ self.attn = Attention(num_heads=num_heads, hidden_size=hidden_size, atte_dropout_rate=atte_dropout_rate)
+
+ def forward(self, x_1, x_2):
+ x_1 = self.attention_norm(x_1)
+ x_2 = self.attention_norm(x_2)
+ x_1, x_2 = self.attn(x_1, x_2)
+ return x_1, x_2
+
+
+class TransFusion(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int = 768,
+ num_layers: int = 6,
+ mlp_dim: int = 1536,
+ dropout_rate: float = 0.5,
+ num_heads: int = 8,
+ atte_dropout_rate: float = 0.0,
+ roi_size: Union[Sequence[int], int] = (64, 64, 64),
+ scale: int = 16,
+ cross_attention_in_origin_view: bool = False,
+ ):
+ super().__init__()
+ if isinstance(roi_size, int):
+ roi_size = [roi_size for _ in range(3)]
+ self.cross_attention_in_origin_view = cross_attention_in_origin_view
+ patch_size = (1, 1, 1)
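+ # Number of bottleneck tokens: the ROI is downsampled by `scale` along each spatial axis (patch_size is 1).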
+ n_patches = (
+ (roi_size[0] // patch_size[0] // scale)
+ * (roi_size[1] // patch_size[1] // scale)
+ * (roi_size[2] // patch_size[2] // scale)
+ )
+ self.layer = nn.ModuleList()
+ self.encoder_norm = nn.LayerNorm(hidden_size, eps=1e-6)
+ self.patch_embeddings = nn.Conv3d(
+ in_channels=hidden_size, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size
+ )
+ self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, hidden_size))
+ self.dropout = nn.Dropout(dropout_rate)
+ for _ in range(num_layers):
+ layer = Block(
+ hidden_size=hidden_size,
+ mlp_dim=mlp_dim,
+ dropout_rate=dropout_rate,
+ num_heads=num_heads,
+ atte_dropout_rate=atte_dropout_rate,
+ )
+ self.layer.append(copy.deepcopy(layer))
+
+ def forward(self, x_1, x_2, view_list):
+ if self.cross_attention_in_origin_view:
+ x_1, x_2 = permute_inverse([x_1, x_2], view_list)
+ else:
+ # Align x_2 to x_1.
+ x_2 = get_permute_transform(*view_list[::-1])(x_2)
+ x_1 = self.patch_embeddings(x_1)
+ x_2 = self.patch_embeddings(x_2)
+ x_1 = x_1.flatten(2).transpose(-1, -2)
+ x_2 = x_2.flatten(2).transpose(-1, -2)
+ x_1 = x_1 + self.position_embeddings
+ x_2 = x_2 + self.position_embeddings
+ x_1 = self.dropout(x_1)
+ x_2 = self.dropout(x_2)
+ for layer_block in self.layer:
+ x_1, x_2 = layer_block(x_1, x_2)
+ x_1 = self.encoder_norm(x_1)
+ x_2 = self.encoder_norm(x_2)
+ B, n_patch, hidden = x_1.size() # reshape from (B, n_patch, hidden) to (B, hidden, l, h, w)
+ l, h, w = int(np.cbrt(n_patch)), int(np.cbrt(n_patch)), int(np.cbrt(n_patch))
+ x_1 = x_1.permute(0, 2, 1).contiguous().view(B, hidden, l, h, w)
+ x_2 = x_2.permute(0, 2, 1).contiguous().view(B, hidden, l, h, w)
+ if self.cross_attention_in_origin_view:
+ x_1, x_2 = permute_inverse([x_1, x_2], view_list)
+ else:
+ x_2 = get_permute_transform(*view_list)(x_2)
+
+ return x_1, x_2
diff --git a/SwinMM/WORD/models/swin_unetr.py b/SwinMM/WORD/models/swin_unetr.py
new file mode 100644
index 00000000..ea7f96fa
--- /dev/null
+++ b/SwinMM/WORD/models/swin_unetr.py
@@ -0,0 +1,128 @@
+"""SwinUNETR with cross attention."""
+from typing import MutableMapping, Sequence, Tuple, Union
+
+import torch
+from models import cross_attention
+
+from monai.networks import blocks
+from monai.networks.nets import swin_unetr
+
+__all__ = ["SwinUNETR"]
+
+FeaturesDictType = MutableMapping[str, torch.Tensor]
+
+
+class SwinUNETR(swin_unetr.SwinUNETR):
+ """SwinUNETR with cross attention."""
+
+ def __init__(
+ self,
+ img_size: Union[Sequence[int], int],
+ *args,
+ num_heads: Sequence[int] = (3, 6, 12, 24),
+ feature_size: int = 24,
+ norm_name: Union[Tuple, str] = "instance",
+ drop_rate: float = 0.0,
+ attn_drop_rate: float = 0.0,
+ spatial_dims: int = 3,
+ fusion_depths: Sequence[int] = (2, 2, 2, 2, 2, 2),
+ cross_attention_in_origin_view: bool = False,
+ **kwargs,
+ ) -> None:
+ """
+ Args:
+ fusion_depths: TODO(yiqing).
+ cross_attention_in_origin_view: A bool indicating whether to compute cross attention in the original view.
+ If not, compute cross attention in the view of the first input.
+
+ """
+ super().__init__(
+ img_size,
+ *args,
+ num_heads=num_heads,
+ feature_size=feature_size,
+ norm_name=norm_name,
+ spatial_dims=spatial_dims,
+ drop_rate=drop_rate,
+ attn_drop_rate=attn_drop_rate,
+ **kwargs,
+ )
+
+ self.encoder5 = blocks.UnetrBasicBlock(
+ spatial_dims=spatial_dims,
+ in_channels=8 * feature_size,
+ out_channels=8 * feature_size,
+ kernel_size=3,
+ stride=1,
+ norm_name=norm_name,
+ res_block=True,
+ )
+
+ self.cross_atte6 = cross_attention.TransFusion(
+ hidden_size=feature_size * 16,
+ num_layers=fusion_depths[5],
+ mlp_dim=feature_size * 32,
+ num_heads=num_heads[3],
+ dropout_rate=drop_rate,
+ atte_dropout_rate=attn_drop_rate,
+ roi_size=img_size,
+ scale=32,
+ cross_attention_in_origin_view=cross_attention_in_origin_view,
+ )
+
+ def forward_view_encoder(self, x):
+ """Encode features."""
+ x_hiddens = self.swinViT(x, self.normalize)
+ x_enc0 = self.encoder1(x)
+ x_enc1 = self.encoder2(x_hiddens[0])
+ x_enc2 = self.encoder3(x_hiddens[1])
+ x_enc3 = self.encoder4(x_hiddens[2])
+ x_enc4 = self.encoder5(x_hiddens[3]) # xa_hidden[3]
+ x_dec4 = self.encoder10(x_hiddens[4])
+ return {"enc0": x_enc0, "enc1": x_enc1, "enc2": x_enc2, "enc3": x_enc3, "enc4": x_enc4, "dec4": x_dec4}
+
+ def forward_view_decoder(self, x_encoded: FeaturesDictType) -> torch.Tensor:
+ """Decode features."""
+ x_dec3 = self.decoder5(x_encoded["dec4"], x_encoded["enc4"])
+ x_dec2 = self.decoder4(x_dec3, x_encoded["enc3"])
+ x_dec1 = self.decoder3(x_dec2, x_encoded["enc2"])
+ x_dec0 = self.decoder2(x_dec1, x_encoded["enc1"])
+ x_out = self.decoder1(x_dec0, x_encoded["enc0"])
+ x_logits = self.out(x_out)
+ return x_logits
+
+ def forward_view_cross_attention(
+ self, xa_encoded: FeaturesDictType, xb_encoded: FeaturesDictType, views: Sequence[int]
+ ) -> Tuple[FeaturesDictType, FeaturesDictType]:
+ """Inplace cross attention between views."""
+ xa_encoded["dec4"], xb_encoded["dec4"] = self.cross_atte6(xa_encoded["dec4"], xb_encoded["dec4"], views)
+ return xa_encoded, xb_encoded
+
+ def forward(self, xa: torch.Tensor, xb: torch.Tensor, views: Sequence[int]) -> Sequence[torch.Tensor]:
+ """Two views forward."""
+ xa_encoded = self.forward_view_encoder(xa)
+ xb_encoded = self.forward_view_encoder(xb)
+
+ xa_encoded, xb_encoded = self.forward_view_cross_attention(xa_encoded, xb_encoded, views)
+ return [self.forward_view_decoder(val) for val in [xa_encoded, xb_encoded]]
+
+ def no_weight_decay(self):
+ """Disable weight_decay on specific weights."""
+ nwd = {"swinViT.absolute_pos_embed"}
+ for n, _ in self.named_parameters():
+ if "relative_position_bias_table" in n:
+ nwd.add(n)
+ return nwd
+
+ def group_matcher(self, coarse=False):
+ """Layer counting helper, used by timm."""
+ return dict(
+ stem=r"^swinViT\.absolute_pos_embed|patch_embed", # stem and embed
+ blocks=r"^swinViT\.layers(\d+)\.0"
+ if coarse
+ else [
+ (r"^swinViT\.layers(\d+)\.0.downsample", (0,)),
+ (r"^swinViT\.layers(\d+)\.0\.\w+\.(\d+)", None),
+ (r"^swinViT\.norm", (99999,)),
+ ],
+ )
diff --git a/SwinMM/WORD/optimizers/__init__.py b/SwinMM/WORD/optimizers/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/SwinMM/WORD/optimizers/lr_scheduler.py b/SwinMM/WORD/optimizers/lr_scheduler.py
new file mode 100644
index 00000000..0c352927
--- /dev/null
+++ b/SwinMM/WORD/optimizers/lr_scheduler.py
@@ -0,0 +1,172 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import warnings
+from typing import List
+
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LambdaLR, _LRScheduler
+
+__all__ = ["LinearLR", "ExponentialLR"]
+
+
+class _LRSchedulerMONAI(_LRScheduler):
+ """Base class for increasing the learning rate between two boundaries over a number
+ of iterations"""
+
+ def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1) -> None:
+ """
+ Args:
+ optimizer: wrapped optimizer.
+ end_lr: the final learning rate.
+ num_iter: the number of iterations over which the test occurs.
+ last_epoch: the index of last epoch.
+ Returns:
+ None
+ """
+ self.end_lr = end_lr
+ self.num_iter = num_iter
+ super(_LRSchedulerMONAI, self).__init__(optimizer, last_epoch)
+
+
+class LinearLR(_LRSchedulerMONAI):
+ """Linearly increases the learning rate between two boundaries over a number of
+ iterations.
+ """
+
+ def get_lr(self):
+ r = self.last_epoch / (self.num_iter - 1)
+ return [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs]
+
+
+class ExponentialLR(_LRSchedulerMONAI):
+ """Exponentially increases the learning rate between two boundaries over a number of
+ iterations.
+ """
+
+ def get_lr(self):
+ r = self.last_epoch / (self.num_iter - 1)
+ return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs]
+
+
+class WarmupCosineSchedule(LambdaLR):
+ """Linear warmup and then cosine decay.
+ Based on https://huggingface.co/ implementation.
+ """
+
+ def __init__(
+ self, optimizer: Optimizer, warmup_steps: int, t_total: int, cycles: float = 0.5, last_epoch: int = -1
+ ) -> None:
+ """
+ Args:
+ optimizer: wrapped optimizer.
+ warmup_steps: number of warmup iterations.
+ t_total: total number of training iterations.
+ cycles: cosine cycles parameter.
+ last_epoch: the index of last epoch.
+ Returns:
+ None
+ """
+ self.warmup_steps = warmup_steps
+ self.t_total = t_total
+ self.cycles = cycles
+ super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch)
+
+ def lr_lambda(self, step):
+ if step < self.warmup_steps:
+ return float(step) / float(max(1.0, self.warmup_steps))
+ progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))
+
+
+class LinearWarmupCosineAnnealingLR(_LRScheduler):
+ def __init__(
+ self,
+ optimizer: Optimizer,
+ warmup_epochs: int,
+ max_epochs: int,
+ warmup_start_lr: float = 0.0,
+ eta_min: float = 0.0,
+ last_epoch: int = -1,
+ ) -> None:
+ """
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ warmup_epochs (int): Maximum number of iterations for linear warmup
+ max_epochs (int): Maximum number of iterations
+ warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0.
+ eta_min (float): Minimum learning rate. Default: 0.
+ last_epoch (int): The index of last epoch. Default: -1.
+ """
+ self.warmup_epochs = warmup_epochs
+ self.max_epochs = max_epochs
+ self.warmup_start_lr = warmup_start_lr
+ self.eta_min = eta_min
+
+ super(LinearWarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch)
+
+ def get_lr(self) -> List[float]:
+ """
+ Compute learning rate using chainable form of the scheduler
+ """
+ if not self._get_lr_called_within_step:
+ warnings.warn(
+ "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", UserWarning
+ )
+
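+ # Piecewise schedule: warmup_start_lr at epoch 0, linear ramp to base_lr over warmup_epochs,
+ # then cosine annealing toward eta_min.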
+ if self.last_epoch == 0:
+ return [self.warmup_start_lr] * len(self.base_lrs)
+ elif self.last_epoch < self.warmup_epochs:
+ return [
+ group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
+ for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
+ ]
+ elif self.last_epoch == self.warmup_epochs:
+ return self.base_lrs
+ elif (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0:
+ return [
+ group["lr"]
+ + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2
+ for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
+ ]
+
+ return [
+ (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
+ / (
+ 1
+ + math.cos(
+ math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs)
+ )
+ )
+ * (group["lr"] - self.eta_min)
+ + self.eta_min
+ for group in self.optimizer.param_groups
+ ]
+
+ def _get_closed_form_lr(self) -> List[float]:
+ """
+ Called when epoch is passed as a param to the `step` function of the scheduler.
+ """
+ if self.last_epoch < self.warmup_epochs:
+ return [
+ self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
+ for base_lr in self.base_lrs
+ ]
+
+ return [
+ self.eta_min
+ + 0.5
+ * (base_lr - self.eta_min)
+ * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
+ for base_lr in self.base_lrs
+ ]
diff --git a/SwinMM/WORD/outputs/__init__.py b/SwinMM/WORD/outputs/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/SwinMM/WORD/pretrained_models/__init__.py b/SwinMM/WORD/pretrained_models/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/SwinMM/WORD/requirements.txt b/SwinMM/WORD/requirements.txt
new file mode 100644
index 00000000..cdfb84e3
--- /dev/null
+++ b/SwinMM/WORD/requirements.txt
@@ -0,0 +1 @@
+timm>=0.6
diff --git a/SwinMM/WORD/run.sh b/SwinMM/WORD/run.sh
new file mode 100644
index 00000000..95847627
--- /dev/null
+++ b/SwinMM/WORD/run.sh
@@ -0,0 +1,9 @@
+python -m torch.distributed.launch --nproc_per_node=8 --master_port=11223 main.py --batch_size=2 \
+ --num_steps=30000 \
+ --lrdecay \
+ --eval_num=500 \
+ --lr=5e-4 \
+ --decay=0.1 \
+ --norm_pix_loss \
+ --redis_ports 39996 39997 39998 39999 \
+ --redis_compression zlib
diff --git a/SwinMM/WORD/runs/__init__.py b/SwinMM/WORD/runs/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/SwinMM/WORD/test.py b/SwinMM/WORD/test.py
new file mode 100644
index 00000000..98199d25
--- /dev/null
+++ b/SwinMM/WORD/test.py
@@ -0,0 +1,160 @@
+# Copyright 2020 - 2022 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+
+import nibabel as nib
+import numpy as np
+import torch
+from inferers import double_sliding_window_inference
+from models import SwinUNETR
+from utils.data_utils import get_loader
+from utils.misc import resample_3d
+
+from monai.metrics import compute_average_surface_distance, compute_hausdorff_distance, compute_meandice
+from monai.networks.utils import one_hot
+from monai.transforms import Spacing
+
+parser = argparse.ArgumentParser(description="Swin UNETR segmentation pipeline")
+parser.add_argument(
+ "--pretrained_dir", default="./runs/multiview_101616/", type=str, help="pretrained checkpoint directory"
+)
+parser.add_argument("--data_dir", default="./dataset/dataset12_WORD/", type=str, help="dataset directory")
+parser.add_argument("--exp_name", default="multiview_101616/", type=str, help="experiment name")
+parser.add_argument("--json_list", default="dataset12_WORD.json", type=str, help="dataset json file")
+parser.add_argument("--pretrained_model_name", default="model.pt", type=str, help="pretrained model name")
+parser.add_argument("--feature_size", default=48, type=int, help="feature size")
+parser.add_argument("--infer_overlap", default=0.7, type=float, help="sliding window inference overlap")
+parser.add_argument("--in_channels", default=1, type=int, help="number of input channels")
+parser.add_argument("--out_channels", default=17, type=int, help="number of output channels")
+parser.add_argument("--a_min", default=-175.0, type=float, help="a_min in ScaleIntensityRanged")
+parser.add_argument("--a_max", default=250.0, type=float, help="a_max in ScaleIntensityRanged")
+parser.add_argument("--b_min", default=0.0, type=float, help="b_min in ScaleIntensityRanged")
+parser.add_argument("--b_max", default=1.0, type=float, help="b_max in ScaleIntensityRanged")
+parser.add_argument("--space_x", default=1.5, type=float, help="spacing in x direction")
+parser.add_argument("--space_y", default=1.5, type=float, help="spacing in y direction")
+parser.add_argument("--space_z", default=2.0, type=float, help="spacing in z direction")
+parser.add_argument("--roi_x", default=64, type=int, help="roi size in x direction")
+parser.add_argument("--roi_y", default=64, type=int, help="roi size in y direction")
+parser.add_argument("--roi_z", default=64, type=int, help="roi size in z direction")
+parser.add_argument("--dropout_rate", default=0.0, type=float, help="dropout rate")
+parser.add_argument("--distributed", action="store_true", help="start distributed training")
+parser.add_argument("--workers", default=8, type=int, help="number of workers")
+parser.add_argument("--RandFlipd_prob", default=0.2, type=float, help="RandFlipd aug probability")
+parser.add_argument("--RandRotate90d_prob", default=0.2, type=float, help="RandRotate90d aug probability")
+parser.add_argument("--RandScaleIntensityd_prob", default=0.1, type=float, help="RandScaleIntensityd aug probability")
+parser.add_argument("--RandShiftIntensityd_prob", default=0.1, type=float, help="RandShiftIntensityd aug probability")
+parser.add_argument("--spatial_dims", default=3, type=int, help="spatial dimension of input data")
+parser.add_argument("--use_checkpoint", action="store_true", help="use gradient checkpointing to save memory")
+
+spacing = Spacing(pixdim=(1, 1, 1), mode="nearest")
+hd_per = 95
+view = ["Cor1", "Sag2", "Sag1", "Axi2", "Axi1", "Cor2", "Fuse"]
+
+
+def main():
+ args = parser.parse_args()
+ args.test_mode = True
+ output_directory = "./outputs/" + args.exp_name
+ if not os.path.exists(output_directory):
+ os.makedirs(output_directory)
+ val_loader = get_loader(args)
+ pretrained_dir = args.pretrained_dir
+ model_name = args.pretrained_model_name
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ pretrained_pth = os.path.join(pretrained_dir, model_name)
+ model = SwinUNETR(
+ img_size=(args.roi_x, args.roi_y, args.roi_z),
+ in_channels=args.in_channels,
+ out_channels=args.out_channels,
+ feature_size=args.feature_size,
+ fusion_depths=(1, 1, 1, 1, 1, 1),
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ use_checkpoint=args.use_checkpoint,
+ )
+
+ model_dict = torch.load(pretrained_pth)["state_dict"]
+ model.load_state_dict(model_dict)
+ model.eval()
+ model.to(device)
+ dice_out = np.zeros((len(val_loader), len(view), args.out_channels - 1))
+ hd_out = np.zeros((len(val_loader), len(view), args.out_channels - 1))
+ asd_out = np.zeros((len(val_loader), len(view), args.out_channels - 1))
+
+ with torch.no_grad():
+ for id, batch in enumerate(val_loader):
+ val_inputs, val_labels = (batch["image"].to(device), batch["label"].to(device))
+ original_affine = batch["label_meta_dict"]["affine"][0].numpy()
+ _, _, h, w, d = val_labels.shape
+ target_shape = (h, w, d)
+ val_labels = val_labels.cpu().numpy()[0, :, :, :, :]
+ img_name = batch["image_meta_dict"]["filename_or_obj"][0].split("/")[-1]
+ print("Inference on case {}".format(img_name))
+ torch.cuda.empty_cache()
+ output_list = []
+ val_fuse = 0
+
+ val_labels = spacing(val_labels, original_affine)[0]
+ val_labels = np.expand_dims(val_labels, axis=0)
+ val_labels = one_hot(torch.from_numpy(val_labels), num_classes=args.out_channels, dim=1)
+
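+ # Three permutation passes, each predicting two views; their softmax outputs are summed into a
+ # fused prediction as well, giving len(view) = 7 outputs per case (6 single views + "Fuse").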
+ for i in range(3):
+ val_outputs_1, val_outputs_2 = double_sliding_window_inference(
+ val_inputs,
+ i,
+ (args.roi_x, args.roi_y, args.roi_z),
+ 16,
+ model,
+ overlap=args.infer_overlap,
+ mode="gaussian",
+ )
+
+ val_outputs_1 = torch.softmax(val_outputs_1, 1).cpu().numpy()[0]
+ val_outputs_2 = torch.softmax(val_outputs_2, 1).cpu().numpy()[0]
+ val_fuse = val_fuse + val_outputs_1 + val_outputs_2
+ output_list.append(val_outputs_1)
+ output_list.append(val_outputs_2)
+ output_list.append(val_fuse)
+
+ for i, output in enumerate(output_list):
+ output = np.argmax(output, axis=0, keepdims=False)
+ output = resample_3d(output, target_shape)
+ target_ornt = nib.orientations.axcodes2ornt(tuple(nib.aff2axcodes(original_affine)))
+ out_ornt = [[0, 1], [1, 1], [2, 1]]
+ ornt_transf = nib.orientations.ornt_transform(out_ornt, target_ornt)
+ output = nib.orientations.apply_orientation(output, ornt_transf)
+ nib.save(
+ nib.Nifti1Image(output[::-1, ::-1, :].astype(np.uint8), affine=original_affine),
+ os.path.join(output_directory, view[i] + "_" + img_name),
+ )
+ output = np.expand_dims(spacing(np.expand_dims(output, axis=(0)), original_affine)[0], axis=0)
+ output = one_hot(torch.from_numpy(output), num_classes=args.out_channels, dim=1)
+ print(output.shape, val_labels.shape)
+ dice_ = compute_meandice(output, val_labels, include_background=False).numpy()[0]
+ hd_ = compute_hausdorff_distance(output, val_labels, percentile=hd_per).numpy()[0]
+ asd_ = compute_average_surface_distance(output, val_labels).numpy()[0]
+ print("{} View Mean Dice: {}".format(view[i], np.mean(dice_)))
+ print("{} View Mean HD: {}".format(view[i], np.mean(hd_)))
+ print("{} View Mean ASD: {}".format(view[i], np.mean(asd_)))
+ dice_out[id, i, :] = dice_
+ hd_out[id, i, :] = hd_
+ asd_out[id, i, :] = asd_
+
+ for i in range(len(view)):
+ print("Overall {} View Mean Dice: {}".format(view[i], np.mean(dice_out[:, i, :], axis=0)))
+ print("Overall {} View Mean HD: {}".format(view[i], np.mean(hd_out[:, i, :], axis=0)))
+ print("Overall {} View Mean ASD: {}".format(view[i], np.mean(asd_out[:, i, :], axis=0)))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/SwinMM/WORD/test_parallel.py b/SwinMM/WORD/test_parallel.py
new file mode 100644
index 00000000..a1e8a358
--- /dev/null
+++ b/SwinMM/WORD/test_parallel.py
@@ -0,0 +1,232 @@
+import argparse
+import logging
+import os
+
+import nibabel as nib
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from inferers import double_sliding_window_inference
+from models import SwinUNETR
+from timm.utils import setup_default_logging
+from utils.data_utils import get_loader
+from utils.misc import distributed_all_gather, resample_3d
+
+from monai.metrics import compute_average_surface_distance, compute_hausdorff_distance, compute_meandice
+from monai.networks.utils import one_hot
+from monai.transforms import Spacing
+
+parser = argparse.ArgumentParser(description="Swin UNETR segmentation pipeline")
+parser.add_argument(
+ "--pretrained_dir", default="./runs/multiview_101616/", type=str, help="pretrained checkpoint directory"
+)
+parser.add_argument("--data_dir", default="./dataset/dataset12_WORD/", type=str, help="dataset directory")
+parser.add_argument("--exp_name", default="multiview_101616/", type=str, help="experiment name")
+parser.add_argument("--json_list", default="dataset12_WORD.json", type=str, help="dataset json file")
+parser.add_argument("--pretrained_model_name", default="model.pt", type=str, help="pretrained model name")
+parser.add_argument("--feature_size", default=48, type=int, help="feature size")
+parser.add_argument("--infer_overlap", default=0.7, type=float, help="sliding window inference overlap")
+parser.add_argument("--in_channels", default=1, type=int, help="number of input channels")
+parser.add_argument("--out_channels", default=17, type=int, help="number of output channels")
+parser.add_argument("--a_min", default=-175.0, type=float, help="a_min in ScaleIntensityRanged")
+parser.add_argument("--a_max", default=250.0, type=float, help="a_max in ScaleIntensityRanged")
+parser.add_argument("--b_min", default=0.0, type=float, help="b_min in ScaleIntensityRanged")
+parser.add_argument("--b_max", default=1.0, type=float, help="b_max in ScaleIntensityRanged")
+parser.add_argument("--space_x", default=1.5, type=float, help="spacing in x direction")
+parser.add_argument("--space_y", default=1.5, type=float, help="spacing in y direction")
+parser.add_argument("--space_z", default=2.0, type=float, help="spacing in z direction")
+parser.add_argument("--roi_x", default=64, type=int, help="roi size in x direction")
+parser.add_argument("--roi_y", default=64, type=int, help="roi size in y direction")
+parser.add_argument("--roi_z", default=64, type=int, help="roi size in z direction")
+parser.add_argument("--dropout_rate", default=0.0, type=float, help="dropout rate")
+parser.add_argument("--distributed", action="store_true", help="start distributed training")
+parser.add_argument("--world_size", default=1, type=int, help="number of nodes for distributed training")
+parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training")
+parser.add_argument("--dist-url", default="tcp://127.0.0.1:23456", type=str, help="distributed url")
+parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
+parser.add_argument("--workers", default=8, type=int, help="number of workers")
+parser.add_argument("--use_normal_dataset", action="store_true", help="use monai Dataset class")
+parser.add_argument(
+ "--nouse_multi_epochs_loader",
+ action="store_true",
+ help="not use the multi-epochs-loader to save time at the beginning of every epoch",
+)
+parser.add_argument("--RandFlipd_prob", default=0.2, type=float, help="RandFlipd aug probability")
+parser.add_argument("--RandRotate90d_prob", default=0.2, type=float, help="RandRotate90d aug probability")
+parser.add_argument("--RandScaleIntensityd_prob", default=0.1, type=float, help="RandScaleIntensityd aug probability")
+parser.add_argument("--RandShiftIntensityd_prob", default=0.1, type=float, help="RandShiftIntensityd aug probability")
+parser.add_argument("--spatial_dims", default=3, type=int, help="spatial dimension of input data")
+parser.add_argument("--use_checkpoint", action="store_true", help="use gradient checkpointing to save memory")
+parser.add_argument(
+ "--cross_attention_in_origin_view", action="store_true", help="Whether compute cross attention in original view"
+)
+
+spacing = Spacing(pixdim=(1, 1, 1), mode="nearest")
+hd_per = 95
+view = ["Cor1", "Sag2", "Sag1", "Axi2", "Axi1", "Cor2", "Fuse"]
+
+
+def main():
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
+ args = parser.parse_args()
+ output_directory = "./outputs/" + args.exp_name
+ if not os.path.exists(output_directory):
+ os.makedirs(output_directory)
+ if args.distributed:
+ args.ngpus_per_node = torch.cuda.device_count()
+ print("Found total gpus", args.ngpus_per_node)
+ args.world_size = args.ngpus_per_node * args.world_size
+ mp.spawn(main_worker, nprocs=args.ngpus_per_node, args=(args,))
+ else:
+ main_worker(gpu=0, args=args)
+
+
+def main_worker(gpu, args):
+ output_directory = "./outputs/" + args.exp_name
+ if args.distributed:
+ torch.multiprocessing.set_start_method("fork", force=True)
+ # np.set_printoptions(formatter={"float": "{: 0.3f}".format}, suppress=True)
+ args.gpu = gpu
+ if args.distributed:
+ args.rank = args.rank * args.ngpus_per_node + gpu
+ dist.init_process_group(
+ backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
+ )
+ torch.cuda.set_device(args.gpu)
+ torch.backends.cudnn.benchmark = True
+ args.test_mode = True
+ val_loader = get_loader(args)
+ print(args.rank, " gpu", args.gpu)
+ if args.rank == 0:
+ setup_default_logging()
+ # logging.info(f"Batch size is: {args.batch_size}, epochs: {args.max_epochs}")
+
+ pretrained_dir = args.pretrained_dir
+ model_name = args.pretrained_model_name
+ pretrained_pth = os.path.join(pretrained_dir, model_name)
+ model = SwinUNETR(
+ img_size=(args.roi_x, args.roi_y, args.roi_z),
+ in_channels=args.in_channels,
+ out_channels=args.out_channels,
+ feature_size=args.feature_size,
+ fusion_depths=(1, 1, 1, 1, 1, 1),
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ use_checkpoint=args.use_checkpoint,
+ cross_attention_in_origin_view=args.cross_attention_in_origin_view,
+ )
+ model.load_state_dict(torch.load(pretrained_pth, map_location="cpu")["state_dict"])
+ model.cuda(args.gpu)
+ model.eval()
+ model_without_ddp = model
+ if args.distributed:
+ torch.cuda.set_device(args.gpu)
+ model.cuda(args.gpu)
+ model_without_ddp = model
+ model = torch.nn.parallel.DistributedDataParallel(
+ model, device_ids=[args.gpu], output_device=args.gpu, broadcast_buffers=False, find_unused_parameters=True
+ )
+
+ dice_all = [] # np.zeros((len(val_loader), len(view), args.out_channels - 1))
+ hd_all = [] # np.zeros((len(val_loader), len(view), args.out_channels - 1))
+ asd_all = [] # np.zeros((len(val_loader), len(view), args.out_channels - 1))
+
+ with torch.no_grad():
+ for id, batch in enumerate(val_loader):
+ val_inputs, val_labels = (batch["image"].cuda(args.gpu), batch["label"].cpu())
+ original_affine = batch["label_meta_dict"]["affine"][0].numpy()
+ _, _, h, w, d = val_labels.shape
+ target_shape = (h, w, d)
+ val_labels = val_labels.numpy()[0, :, :, :, :]
+ img_name = batch["image_meta_dict"]["filename_or_obj"][0].split("/")[-1]
+ print("Inference on case {}".format(img_name))
+ torch.cuda.empty_cache()
+ output_list = []
+ val_fuse = 0
+
+ val_labels = spacing(val_labels, original_affine)[0]
+ val_labels = np.expand_dims(val_labels, axis=0)
+ val_labels = one_hot(torch.from_numpy(val_labels), num_classes=args.out_channels, dim=1)
+
+ for i in range(3):
+ # j = i + 3
+ val_outputs_1, val_outputs_2 = double_sliding_window_inference(
+ val_inputs,
+ i,
+ (args.roi_x, args.roi_y, args.roi_z),
+ 16,
+ model,
+ overlap=args.infer_overlap,
+ mode="gaussian",
+ )
+
+ val_outputs_1 = torch.softmax(val_outputs_1, 1).cpu().numpy()[0]
+ val_outputs_2 = torch.softmax(val_outputs_2, 1).cpu().numpy()[0]
+ val_fuse = val_fuse + val_outputs_1 + val_outputs_2
+ output_list.append(val_outputs_1)
+ output_list.append(val_outputs_2)
+ output_list.append(val_fuse)
+ print("Inference finished on case {}".format(img_name))
+
+ dice = np.zeros((len(view), args.out_channels - 1))
+ hd = np.zeros((len(view), args.out_channels - 1))
+ asd = np.zeros((len(view), args.out_channels - 1))
+ for i, output in enumerate(output_list):
+ output = np.argmax(output, axis=0, keepdims=False)
+ output = resample_3d(output, target_shape)
+ target_ornt = nib.orientations.axcodes2ornt(tuple(nib.aff2axcodes(original_affine)))
+ out_ornt = [[0, 1], [1, 1], [2, 1]]
+ ornt_transf = nib.orientations.ornt_transform(out_ornt, target_ornt)
+ output = nib.orientations.apply_orientation(output, ornt_transf)
+ nib.save(
+ nib.Nifti1Image(output.astype(np.uint8), affine=original_affine),
+ os.path.join(output_directory, view[i] + "_" + img_name),
+ )
+ output = np.expand_dims(spacing(np.expand_dims(output, axis=(0)), original_affine)[0], axis=0)
+ output = one_hot(torch.from_numpy(output), num_classes=args.out_channels, dim=1)
+ # print(output.shape, val_labels.shape)
+ dice_ = compute_meandice(output, val_labels, include_background=False).numpy()[0]
+ hd_ = compute_hausdorff_distance(output, val_labels, percentile=hd_per).numpy()[0]
+ asd_ = compute_average_surface_distance(output, val_labels).numpy()[0]
+ print("{} {} View Mean Dice: {}".format(img_name, view[i], np.mean(dice_)))
+ print("{} {} View Mean HD: {}".format(img_name, view[i], np.mean(hd_)))
+ print("{} {} View Mean ASD: {}".format(img_name, view[i], np.mean(asd_)))
+ dice[i, :] = dice_
+ hd[i, :] = hd_
+ asd[i, :] = asd_
+ dice_all.append(dice)
+ hd_all.append(hd)
+ asd_all.append(asd)
+
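+ # Gather per-case metrics from all ranks and flatten to (total_cases, num_views, num_classes - 1)
+ # before averaging per view.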
+ dice_all = torch.tensor(np.stack(dice_all, axis=0)).cuda(args.gpu)
+ hd_all = torch.tensor(np.stack(hd_all, axis=0)).cuda(args.gpu)
+ asd_all = torch.tensor(np.stack(asd_all, axis=0)).cuda(args.gpu)
+ dice_list = distributed_all_gather([dice_all], out_numpy=False, is_valid=True)
+ hd_list = distributed_all_gather([hd_all], out_numpy=False, is_valid=True)
+ asd_list = distributed_all_gather([asd_all], out_numpy=False, is_valid=True)
+ dice_list = torch.flatten(torch.stack(dice_list[0], axis=0), start_dim=0, end_dim=1).cpu().numpy()
+ hd_list = torch.flatten(torch.stack(hd_list[0], axis=0), start_dim=0, end_dim=1).cpu().numpy()
+ asd_list = torch.flatten(torch.stack(asd_list[0], axis=0), start_dim=0, end_dim=1).cpu().numpy()
+
+ if args.rank == 0:
+ for i in range(len(view)):
+ print(dice_list.shape)
+ print("Overall {} View Mean Dice: {}".format(view[i], np.mean(dice_list[:, i, :], axis=0)))
+ print("Overall {} View Mean HD: {}".format(view[i], np.mean(hd_list[:, i, :], axis=0)))
+ print("Overall {} View Mean ASD: {}".format(view[i], np.mean(asd_list[:, i, :], axis=0)))
+ np.savetxt(
+ os.path.join(output_directory, view[i] + "Dice.txt"),
+ np.mean(dice_list[:, i, :], axis=0),
+ delimiter="\t",
+ )
+ np.savetxt(
+ os.path.join(output_directory, view[i] + "HD.txt"), np.mean(hd_list[:, i, :], axis=0), delimiter="\t"
+ )
+ np.savetxt(
+ os.path.join(output_directory, view[i] + "ASD.txt"), np.mean(asd_list[:, i, :], axis=0), delimiter="\t"
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/SwinMM/WORD/trainer.py b/SwinMM/WORD/trainer.py
new file mode 100644
index 00000000..5ea97b51
--- /dev/null
+++ b/SwinMM/WORD/trainer.py
@@ -0,0 +1,296 @@
+# Copyright 2020 - 2022 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+
+import numpy as np
+import torch
+import torch.distributed
+import torch.nn.functional as F
+import torch.nn.parallel
+import torch.utils.data.distributed
+from torch.cuda.amp import GradScaler, autocast
+from torch.utils.tensorboard import SummaryWriter
+from utils import view_ops
+from utils.misc import AverageMeter, distributed_all_gather
+
+from monai.data import decollate_batch
+
+
+def train_epoch(model, loader, optimizer, scaler, epoch, self_crit, mutual_crit, args):
+ model.train()
+ start_time = time.time()
+ run_loss = AverageMeter()
+ run_self_loss = AverageMeter()
+ run_mutual_loss = AverageMeter()
+
+ for idx, batch_data in enumerate(loader):
+ for param in model.parameters():
+ param.grad = None
+ if isinstance(batch_data, list):
+ data, target = batch_data
+ else:
+ data, target = batch_data["image"], batch_data["label"]
+ data = data.cuda(args.rank)
+ if not args.unsupervised:
+ target = target.cuda(args.rank)
+ data_list, view_list = view_ops.permute_rand(data)
+
+ loss = 0
+ self_loss_list, mutual_loss_list = [], []
+ with autocast(enabled=args.amp):
+ output1, output2 = model(data_list[0], data_list[1], view_list)
+ out_list = [output1, output2]
+ out_list = view_ops.permute_inverse(out_list, view_list)
+ if args.unsupervised:
+ target = torch.argmax(
+ (torch.softmax(out_list[0], dim=1) + torch.softmax(out_list[1], dim=1)) / 2, dim=1, keepdim=True
+ ).cuda(args.rank)
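+ # For unsupervised batches, the target above is a pseudo-label: the argmax of the two views' averaged softmax.
+ # Each view's loss combines a Dice-CE term against the target with the averaged KL divergence
+ # toward the other view's prediction (mutual learning between views).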
+ for i in range(len(out_list)):
+ self_loss = self_crit(out_list[i], target)
+ mutual_loss = 0
+ for j in range(len(out_list)): # KL divergence
+ if i != j:
+ mutual_end = mutual_crit(F.log_softmax(out_list[i], dim=1), F.softmax(out_list[j], dim=1))
+ mutual_loss += mutual_end
+ loss = loss + (self_loss + mutual_loss / (len(out_list) - 1)) / len(out_list)
+ self_loss_list.append(self_loss.item())
+ mutual_loss_list.append(mutual_loss.item())
+ self_loss = torch.mean(torch.tensor(self_loss_list)).cuda(args.rank)
+ mutual_loss = torch.mean(torch.tensor(mutual_loss_list)).cuda(args.rank)
+
+ if args.amp:
+ scaler.scale(loss).backward()
+ scaler.step(optimizer)
+ scaler.update()
+ else:
+ loss.backward()
+ optimizer.step()
+ if args.distributed:
+ is_valid = True
+ loss_list = distributed_all_gather([loss], out_numpy=True, is_valid=is_valid)
+ run_loss.update(
+ np.mean(np.mean(np.stack(loss_list, axis=0), axis=0), axis=0), n=args.batch_size * args.world_size
+ )
+ self_loss_list = distributed_all_gather([self_loss], out_numpy=True, is_valid=is_valid)
+ run_self_loss.update(
+ np.mean(np.mean(np.stack(self_loss_list, axis=0), axis=0), axis=0), n=args.batch_size * args.world_size
+ )
+ mutual_loss_list = distributed_all_gather([mutual_loss], out_numpy=True, is_valid=is_valid)
+ run_mutual_loss.update(
+ np.mean(np.mean(np.stack(mutual_loss_list, axis=0), axis=0), axis=0),
+ n=args.batch_size * args.world_size,
+ )
+ else:
+ run_loss.update(loss.item(), n=args.batch_size)
+ run_self_loss.update(self_loss.item(), n=args.batch_size)
+ run_mutual_loss.update(mutual_loss.item(), n=args.batch_size)
+ if args.rank == 0:
+ print(
+ "Epoch {}/{} {}/{}".format(epoch, args.max_epochs, idx, len(loader)),
+ "loss: {:.4f}".format(run_loss.avg),
+ "self_loss: {:.4f}".format(run_self_loss.avg),
+ "mutual_loss: {:.4f}".format(run_mutual_loss.avg),
+ "time {:.2f}s".format(time.time() - start_time),
+ )
+ start_time = time.time()
+ for param in model.parameters():
+ param.grad = None
+ return run_loss.avg, run_self_loss.avg, run_mutual_loss.avg
+
+
+def val_epoch(model, loader, epoch, acc_func, args, model_inferer=None, post_label=None, post_pred=None):
+ model.eval()
+ run_acc = AverageMeter()
+ start_time = time.time()
+ with torch.no_grad():
+ for idx, batch_data in enumerate(loader):
+ if isinstance(batch_data, list):
+ data, target = batch_data
+ else:
+ data, target = batch_data["image"], batch_data["label"]
+ data = data.cuda(args.rank)
+ torch.cuda.empty_cache()
+ with autocast(enabled=args.amp):
+ i = np.random.randint(0, 3)  # randomly pick one of the three permutation views
+ if model_inferer is not None:
+ output1, output2 = model_inferer(data, i)
+ else:
+ output1, output2 = model(data, i)
+ output1, output2, target = output1.cpu(), output2.cpu(), target.cpu()
+ val_labels_list = decollate_batch(target)
+ val_labels_convert = [post_label(val_label_tensor) for val_label_tensor in val_labels_list]
+ out = (output1 + output2) / 2
+ val_outputs_list = decollate_batch(out)
+ val_output_convert = [post_pred(val_pred_tensor) for val_pred_tensor in val_outputs_list]
+ acc_func.reset()
+ acc_func(y_pred=val_output_convert, y=val_labels_convert)
+ acc, not_nans = acc_func.aggregate()
+ acc, not_nans = acc.cuda(args.rank), not_nans.cuda(args.rank)
+
+ if args.distributed:
+ is_valid = True
+ acc_list, not_nans_list = distributed_all_gather([acc, not_nans], out_numpy=True, is_valid=is_valid)
+ for al, nl in zip(acc_list, not_nans_list):
+ run_acc.update(al, n=nl)
+
+ else:
+ run_acc.update(acc.cpu().numpy(), n=not_nans.cpu().numpy())
+
+ if args.rank == 0:
+ avg_acc = np.mean(run_acc.avg)
+ print(
+ "Val {}/{} {}/{}".format(epoch, args.max_epochs, idx, len(loader)),
+ "acc",
+ avg_acc,
+ "time {:.2f}s".format(time.time() - start_time),
+ )
+ start_time = time.time()
+ return run_acc.avg
+
+
+def save_checkpoint(model, epoch, args, filename="model.pt", best_acc=0, optimizer=None, scheduler=None):
+ state_dict = model.state_dict() if not args.distributed else model.module.state_dict()
+ save_dict = {"epoch": epoch, "best_acc": best_acc, "state_dict": state_dict}
+ if optimizer is not None:
+ save_dict["optimizer"] = optimizer.state_dict()
+ if scheduler is not None:
+ save_dict["scheduler"] = scheduler.state_dict()
+ filename = os.path.join(args.logdir, filename)
+ torch.save(save_dict, filename)
+ print("Saving checkpoint", filename)
+
+
+def run_training(
+ model,
+ train_loader,
+ val_loader,
+ unsupervised_loader,
+ optimizer,
+ self_crit,
+ mutual_crit,
+ acc_func,
+ args,
+ model_inferer=None,
+ scheduler=None,
+ start_epoch=0,
+ post_label=None,
+ post_pred=None,
+):
+ writer = None
+ if args.logdir is not None and args.rank == 0:
+ writer = SummaryWriter(log_dir=args.logdir)
+ if args.rank == 0:
+ print("Writing Tensorboard logs to ", args.logdir)
+ scaler = None
+ if args.amp:
+ scaler = GradScaler()
+ val_acc_max = 0.0
+ for epoch in range(start_epoch, args.max_epochs):
+ if args.distributed:
+ train_loader.sampler.set_epoch(epoch)
+ torch.distributed.barrier()
+ print(args.rank, time.ctime(), "Epoch:", epoch)
+ epoch_time = time.time()
+ train_loss, self_loss, mutual_loss = train_epoch(
+ model,
+ train_loader,
+ optimizer,
+ scaler=scaler,
+ epoch=epoch,
+ self_crit=self_crit,
+ mutual_crit=mutual_crit,
+ args=args,
+ )
+ if args.rank == 0:
+ print(
+ "Final training {}/{}".format(epoch, args.max_epochs - 1),
+ "loss: {:.4f}".format(train_loss),
+ "self loss: {:.4f}".format(self_loss),
+ "mutual loss: {:.4f}".format(mutual_loss),
+ "time {:.2f}s".format(time.time() - epoch_time),
+ )
+ if args.rank == 0 and writer is not None:
+ writer.add_scalar("train_loss", train_loss, epoch)
+ writer.add_scalar("self_loss", self_loss, epoch)
+ writer.add_scalar("mutual_loss", mutual_loss, epoch)
+
+ if args.unsupervised and (epoch + 1) % args.unsuper_every == 0:
+ if args.distributed:
+ unsupervised_loader.sampler.set_epoch(epoch)
+ torch.distributed.barrier()
+ print(args.rank, time.ctime(), "Epoch:", epoch)
+ epoch_time = time.time()
+ train_loss, self_loss, mutual_loss = train_epoch(
+ model,
+ unsupervised_loader,
+ optimizer,
+ scaler=scaler,
+ epoch=epoch,
+ self_crit=self_crit,
+ mutual_crit=mutual_crit,
+ args=args,
+ )
+ if args.rank == 0:
+ print(
+ "Final unsupervised training {}/{}".format(epoch, args.max_epochs - 1),
+ "loss: {:.4f}".format(train_loss),
+ "mutual loss: {:.4f}".format(mutual_loss),
+ "time {:.2f}s".format(time.time() - epoch_time),
+ )
+ if args.rank == 0 and writer is not None:
+ writer.add_scalar("train_unsupervised_loss", train_loss, epoch)
+ writer.add_scalar("unsupervised_self_loss", self_loss, epoch)
+ writer.add_scalar("unsupervised_mutual_loss", mutual_loss, epoch)
+
+ if epoch >= args.val_start and (epoch + 1) % args.val_every == 0:
+ if args.distributed:
+ torch.distributed.barrier()
+ epoch_time = time.time()
+ val_avg_acc = val_epoch(
+ model,
+ val_loader,
+ epoch=epoch,
+ acc_func=acc_func,
+ model_inferer=model_inferer,
+ args=args,
+ post_label=post_label,
+ post_pred=post_pred,
+ )
+
+ val_avg_acc = np.mean(val_avg_acc)
+
+ if args.rank == 0:
+ print(
+ "Final validation {}/{}".format(epoch, args.max_epochs - 1),
+ "acc",
+ val_avg_acc,
+ "time {:.2f}s".format(time.time() - epoch_time),
+ )
+ if writer is not None:
+ writer.add_scalar("val_acc", val_avg_acc, epoch)
+ if val_avg_acc > val_acc_max:
+ print("new best ({:.6f} --> {:.6f}). ".format(val_acc_max, val_avg_acc))
+ val_acc_max = val_avg_acc
+ if args.rank == 0 and args.logdir is not None and args.save_checkpoint:
+ save_checkpoint(
+ model, epoch, args, best_acc=val_acc_max, optimizer=optimizer, scheduler=scheduler
+ )
+ if args.rank == 0 and args.logdir is not None and args.save_checkpoint:
+ save_checkpoint(model, epoch, args, best_acc=val_acc_max, filename="model_final.pt")
+
+ if scheduler is not None:
+ scheduler.step()
+
+ print("Training Finished !, Best Accuracy: ", val_acc_max)
+
+ return val_acc_max
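The heart of `train_epoch` is the cross-view mutual-learning loss: each view is supervised by the label (or, in the unsupervised pass, by the averaged pseudo-label) and pulled towards the other view with a KL term. A minimal, self-contained sketch of that objective on dummy logits; the actual `self_crit`/`mutual_crit` are supplied by the caller, and `DiceCELoss` here is only an assumed stand-in:

```python
import torch
import torch.nn.functional as F
from monai.losses import DiceCELoss

self_crit = DiceCELoss(to_onehot_y=True, softmax=True)   # assumed per-view criterion
mutual_crit = torch.nn.KLDivLoss(reduction="batchmean")  # cross-view criterion

# Toy logits for two views and an integer label volume: [B, C, X, Y, Z] / [B, 1, X, Y, Z].
out_list = [torch.randn(1, 3, 8, 8, 8), torch.randn(1, 3, 8, 8, 8)]
target = torch.randint(0, 3, (1, 1, 8, 8, 8))

loss = 0
for i in range(len(out_list)):
    self_loss = self_crit(out_list[i], target)
    mutual_loss = 0
    for j in range(len(out_list)):
        if i != j:  # KL divergence towards every other view's softened prediction
            mutual_loss += mutual_crit(F.log_softmax(out_list[i], dim=1), F.softmax(out_list[j], dim=1))
    loss = loss + (self_loss + mutual_loss / (len(out_list) - 1)) / len(out_list)
print(loss.item())
```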
diff --git a/SwinMM/WORD/utils/__init__.py b/SwinMM/WORD/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/SwinMM/WORD/utils/data_utils.py b/SwinMM/WORD/utils/data_utils.py
new file mode 100644
index 00000000..2788cce5
--- /dev/null
+++ b/SwinMM/WORD/utils/data_utils.py
@@ -0,0 +1,176 @@
+# Copyright 2020 - 2022 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import Any, Callable, MutableMapping, Sequence
+
+import timm.data
+import torch
+import torch.utils.data
+from utils import dataset_in_memory
+
+from monai import data, transforms
+from monai.data import load_decathlon_datalist
+from monai.data.utils import list_data_collate
+
+
+def get_dataset_kwargs(dataset_name: str, stage: str, use_normal_dataset: bool, args) -> MutableMapping[str, Any]:
+ dataset_kwargs = {}
+ if not use_normal_dataset:
+ dataset_kwargs = dict(
+ dataset_name=f"{stage}_{dataset_name}",
+ hosts=[{"host": "localhost", "port": str(port)} for port in args.redis_ports],
+ cluster_mode=True,
+ capacity_per_node=200 * 1024 * 1024 * 1024,
+ writer_buffer_size=0, # Disable write buffer
+ )
+ return dataset_kwargs
+
+
+def create_dataloader(
+ data_files: Sequence[Any],
+ is_testing: bool,
+ transform: Callable,
+ num_workers: int,
+ is_distributed: bool,
+ with_cache: bool,
+ use_multi_epochs_loader: bool,
+ batch_size: int,
+ dataset_kwargs: MutableMapping[str, Any],
+) -> torch.utils.data.DataLoader:
+ if not with_cache:
+ dataset = data.Dataset(data=data_files, transform=transform, **dataset_kwargs)
+ else:
+ dataset = dataset_in_memory.CachedDataset(data=data_files, transform=transform, **dataset_kwargs)
+ sampler = torch.utils.data.DistributedSampler(dataset, shuffle=not is_testing) if is_distributed else None
+
+ loader_class = data.DataLoader
+ if use_multi_epochs_loader:
+ loader_class = timm.data.loader.MultiEpochsDataLoader
+ loader = loader_class(
+ dataset,
+ batch_size=batch_size,
+ shuffle=False if is_distributed or is_testing else True,
+ num_workers=num_workers,
+ sampler=sampler,
+ pin_memory=True,
+ persistent_workers=True,
+ # NOTE(meijieru): otherwise raises "too many open files"
+ collate_fn=list_data_collate,
+ )
+ return loader
+
+
+def get_loader(args):
+ data_dir = args.data_dir
+ datalist_json = os.path.join(data_dir, args.json_list)
+ train_transform = transforms.Compose(
+ [
+ transforms.LoadImaged(keys=["image", "label"]),
+ transforms.AddChanneld(keys=["image", "label"]),
+ transforms.Orientationd(keys=["image", "label"], axcodes="RAS"),
+ transforms.Spacingd(
+ keys=["image", "label"], pixdim=(args.space_x, args.space_y, args.space_z), mode=("bilinear", "nearest")
+ ),
+ transforms.ScaleIntensityRanged(
+ keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True
+ ),
+ # transforms.CropForegroundd(keys=["image", "label"], source_key="image"),
+ transforms.RandCropByPosNegLabeld(
+ keys=["image", "label"],
+ label_key="label",
+ spatial_size=(args.roi_x, args.roi_y, args.roi_z),
+ pos=1,
+ neg=1,
+ num_samples=4,
+ image_key="image",
+ image_threshold=0,
+ ),
+ transforms.RandFlipd(keys=["image", "label"], prob=args.RandFlipd_prob, spatial_axis=0),
+ transforms.RandFlipd(keys=["image", "label"], prob=args.RandFlipd_prob, spatial_axis=1),
+ transforms.RandFlipd(keys=["image", "label"], prob=args.RandFlipd_prob, spatial_axis=2),
+ transforms.RandRotate90d(keys=["image", "label"], prob=args.RandRotate90d_prob, max_k=3),
+ transforms.RandScaleIntensityd(keys="image", factors=0.1, prob=args.RandScaleIntensityd_prob),
+ transforms.RandShiftIntensityd(keys="image", offsets=0.1, prob=args.RandShiftIntensityd_prob),
+ transforms.ToTensord(keys=["image", "label"]),
+ ]
+ )
+ val_transform = transforms.Compose(
+ [
+ transforms.LoadImaged(keys=["image", "label"]),
+ transforms.AddChanneld(keys=["image", "label"]),
+ transforms.Orientationd(keys=["image", "label"], axcodes="RAS"),
+ transforms.Spacingd(
+ keys=["image", "label"], pixdim=(args.space_x, args.space_y, args.space_z), mode=("bilinear", "nearest")
+ ),
+ transforms.ScaleIntensityRanged(
+ keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True
+ ),
+ # transforms.CropForegroundd(keys=["image", "label"], source_key="image"),
+ transforms.ToTensord(keys=["image", "label"]),
+ ]
+ )
+
+ test_transform = transforms.Compose(
+ [
+ transforms.LoadImaged(keys=["image", "label"]),
+ transforms.AddChanneld(keys=["image", "label"]),
+ transforms.Orientationd(keys=["image", "label"], axcodes="RAS"),
+ transforms.Spacingd(keys="image", pixdim=(args.space_x, args.space_y, args.space_z), mode="bilinear"),
+ transforms.ScaleIntensityRanged(
+ keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True
+ ),
+ transforms.ToTensord(keys=["image", "label"]),
+ ]
+ )
+
+ def _get_subset_loader(subset_name: str, transform: Callable, is_testing: bool, use_normal_dataset: bool):
+ datalist = load_decathlon_datalist(datalist_json, True, subset_name, base_dir=data_dir)
+ dataset_kwargs = get_dataset_kwargs(subset_name, "finetune", use_normal_dataset, args)
+ loader = create_dataloader(
+ datalist,
+ is_testing,
+ transform,
+ args.workers,
+ args.distributed,
+ not use_normal_dataset,
+ not args.nouse_multi_epochs_loader,
+ args.batch_size,
+ dataset_kwargs,
+ )
+ return loader
+
+ if args.test_mode:
+ # Never cache: the test set is only traversed once.
+ test_files = load_decathlon_datalist(datalist_json, True, "testing", base_dir=data_dir)
+ test_ds = data.Dataset(data=test_files, transform=test_transform)
+ test_sampler = torch.utils.data.DistributedSampler(test_ds, shuffle=False) if args.distributed else None
+ test_loader = data.DataLoader(
+ test_ds,
+ batch_size=1,
+ shuffle=False,
+ num_workers=args.workers,
+ sampler=test_sampler,
+ pin_memory=True,
+ persistent_workers=True,
+ )
+ loader = test_loader
+ else:
+ train_loader = _get_subset_loader("training", train_transform, False, args.use_normal_dataset)
+ val_loader = _get_subset_loader("validation", val_transform, True, args.use_normal_dataset_val)
+
+ if args.unsupervised:
+ unsupervised_loader = _get_subset_loader("unsupervised", train_transform, False, args.use_normal_dataset)
+ loader = [train_loader, val_loader, unsupervised_loader]
+ else:
+ loader = [train_loader, val_loader]
+
+ return loader
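`get_loader` pulls everything it needs from an `args` namespace. A hedged sketch of the fields it reads, with placeholder values (the real defaults come from the training argparser), using plain MONAI datasets so no redis server is required:

```python
from argparse import Namespace

args = Namespace(
    data_dir="./dataset/dataset12_WORD",          # placeholder path
    json_list="dataset12_WORD.json",
    space_x=1.5, space_y=1.5, space_z=2.0,        # placeholder spacing/intensity values
    a_min=-175.0, a_max=250.0, b_min=0.0, b_max=1.0,
    roi_x=96, roi_y=96, roi_z=96,
    RandFlipd_prob=0.2, RandRotate90d_prob=0.2,
    RandScaleIntensityd_prob=0.1, RandShiftIntensityd_prob=0.1,
    workers=8, batch_size=1, distributed=False,
    use_normal_dataset=True, use_normal_dataset_val=True,  # skip the redis-backed cache
    nouse_multi_epochs_loader=True,
    redis_ports=[], test_mode=False, unsupervised=False,
)
# train_loader, val_loader = get_loader(args)
```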
diff --git a/SwinMM/WORD/utils/dataset_in_memory.py b/SwinMM/WORD/utils/dataset_in_memory.py
new file mode 100644
index 00000000..7ee751ab
--- /dev/null
+++ b/SwinMM/WORD/utils/dataset_in_memory.py
@@ -0,0 +1,107 @@
+"""Cache the data using redis.
+
+TODO(meijieru): zeromq may be better.
+"""
+
+from typing import Callable, MutableMapping, Optional, Sequence, Union
+
+import bagua.torch_api.contrib.cache_loader as bagua_cache_loader
+import torch
+import torch.utils.data.dataset as torch_dataset
+
+import monai.data as monai_data
+import monai.transforms as monai_transforms
+
+_ALL_DATASET_NAMES = set()
+_SERIALIZATION_HIJACKED = False
+
+
+def hijack_bagua_serialization(method: str):
+ """Replace bagua serialization."""
+ global _SERIALIZATION_HIJACKED
+ if _SERIALIZATION_HIJACKED:
+ raise RuntimeError("Already hijacked.")
+
+ import pickle
+
+ if method == "lz4":
+ import lz4
+
+ compress, decompress = lz4.frame.compress, lz4.frame.decompress
+ elif method == "lzma":
+ import pylzma as lzma
+
+ compress, decompress = lzma.compress, lzma.decompress
+ elif method == "zlib":
+ import zlib
+
+ compress, decompress = zlib.compress, zlib.decompress
+ else:
+ raise ValueError(f"Unknown compress method: {method}")
+
+ bagua_cache_loader.serialize = lambda val: compress(pickle.dumps(val))
+ bagua_cache_loader.deserialize = lambda val: pickle.loads(decompress(val))
+ _SERIALIZATION_HIJACKED = True
+
+
+def is_deterministic_transform(transform) -> bool:
+ return isinstance(transform, monai_transforms.Transform) and not isinstance(transform, monai_transforms.Randomizable)
+
+
+class CachedDataset(torch_dataset.Dataset):
+ def __init__(
+ self,
+ data: Sequence,
+ transform: Optional[Union[Sequence[Callable], Callable]] = None,
+ as_contiguous: bool = True,
+ backend: str = "redis",
+ hosts: Optional[Sequence[MutableMapping[str, str]]] = None,
+ dataset_name: str = "",
+ writer_buffer_size: int = 20,
+ **kwargs,
+ ) -> None:
+ super().__init__()
+
+ if hosts is None:
+ raise ValueError("We don't init bagua, have to manually launch redis")
+
+ # NOTE(meijieru): check that the dataset name is unique, to avoid
+ # potential conflicts.
+ if not dataset_name or dataset_name in _ALL_DATASET_NAMES:
+ raise ValueError("Each dataset must have a unique name.")
+ _ALL_DATASET_NAMES.add(dataset_name)
+
+ self._dataset = monai_data.Dataset(data=data)
+ self._cache_loader = bagua_cache_loader.CacheLoader(
+ backend, dataset_name, writer_buffer_size, hosts=hosts, **kwargs
+ )
+ self.transform = transform
+ self.as_contiguous = as_contiguous
+
+ def __len__(self):
+ return len(self._dataset)
+
+ def _apply_non_deterministic_transform(self, item):
+ for trans in self.transform.transforms: # type:ignore
+ if not is_deterministic_transform(trans):
+ item = monai_transforms.apply_transform(trans, item)
+ return item
+
+ def _apply_deterministic_transform(self, item):
+ for trans in self.transform.transforms: # type:ignore
+ # execute all the deterministic transforms
+ if not is_deterministic_transform(trans):
+ break
+ item = monai_transforms.apply_transform(trans, item)
+ if self.as_contiguous:
+ item = monai_transforms.convert_to_contiguous(item, memory_format=torch.contiguous_format)
+ return item
+
+ def _load_item(self, index: int):
+ return self._apply_deterministic_transform(self._dataset[index])
+
+ def __getitem__(self, item):
+ cached_item = self._cache_loader.get(item, self._load_item)
+ return self._apply_non_deterministic_transform(cached_item)
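A minimal usage sketch for `CachedDataset`, assuming a redis server is already listening on `localhost:39999` (e.g. started via `scripts/start_redis.sh`); deterministic transforms are cached after the first pass, while random ones are re-applied on every access:

```python
from monai import transforms
from utils import dataset_in_memory

# Optional: replace bagua's plain-pickle serialization with a compressed variant.
dataset_in_memory.hijack_bagua_serialization("lz4")

transform = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image", "label"]),  # deterministic -> cached in redis
        transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),  # random -> recomputed
    ]
)
files = [{"image": "case_0001.nii.gz", "label": "case_0001_seg.nii.gz"}]  # placeholder paths
ds = dataset_in_memory.CachedDataset(
    data=files,
    transform=transform,
    hosts=[{"host": "localhost", "port": "39999"}],
    dataset_name="finetune_training",  # must be unique per dataset
)
sample = ds[0]
```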
diff --git a/SwinMM/WORD/utils/misc.py b/SwinMM/WORD/utils/misc.py
new file mode 100644
index 00000000..6cb88985
--- /dev/null
+++ b/SwinMM/WORD/utils/misc.py
@@ -0,0 +1,78 @@
+# Copyright 2020 - 2022 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import scipy.ndimage as ndimage
+import torch
+
+
+def resample_3d(img, target_size):
+ imx, imy, imz = img.shape
+ tx, ty, tz = target_size
+ zoom_ratio = (float(tx) / float(imx), float(ty) / float(imy), float(tz) / float(imz))
+ img_resampled = ndimage.zoom(img, zoom_ratio, order=0, prefilter=False)
+ return img_resampled
+
+
+class AverageMeter(object):
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = np.where(self.count > 0, self.sum / self.count, self.sum)
+
+
+def distributed_all_gather(
+ tensor_list, valid_batch_size=None, out_numpy=False, world_size=None, no_barrier=False, is_valid=None
+):
+ if world_size is None:
+ world_size = torch.distributed.get_world_size()
+ if valid_batch_size is not None:
+ valid_batch_size = min(valid_batch_size, world_size)
+ elif is_valid is not None:
+ is_valid = torch.tensor(bool(is_valid), dtype=torch.bool, device=tensor_list[0].device)
+ if not no_barrier:
+ torch.distributed.barrier()
+ tensor_list_out = []
+ with torch.no_grad():
+ if is_valid is not None:
+ is_valid_list = [torch.zeros_like(is_valid) for _ in range(world_size)]
+ torch.distributed.all_gather(is_valid_list, is_valid)
+ is_valid = [x.item() for x in is_valid_list]
+ for tensor in tensor_list:
+ gather_list = [torch.zeros_like(tensor) for _ in range(world_size)]
+ torch.distributed.all_gather(gather_list, tensor)
+ if valid_batch_size is not None:
+ gather_list = gather_list[:valid_batch_size]
+ elif is_valid is not None:
+ gather_list = [g for g, v in zip(gather_list, is_valid_list) if v]
+ if out_numpy:
+ gather_list = [t.cpu().numpy() for t in gather_list]
+ tensor_list_out.append(gather_list)
+ return tensor_list_out
+
+
+def dice(x, y):
+ intersect = np.sum(np.sum(np.sum(x * y)))
+ y_sum = np.sum(np.sum(np.sum(y)))
+ if y_sum == 0:
+ return 0.0
+ x_sum = np.sum(np.sum(np.sum(x)))
+ return 2 * intersect / (x_sum + y_sum)
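The `dice` helper expects binary masks; a quick worked example on toy arrays (purely illustrative):

```python
import numpy as np

# One voxel overlaps out of 2 + 2 foreground voxels: Dice = 2 * 1 / (2 + 2) = 0.5.
pred = np.array([[1, 1], [0, 0]])
gt = np.array([[1, 0], [1, 0]])
print(dice(pred, gt))  # 0.5
```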
diff --git a/SwinMM/WORD/utils/test_view_transforms.py b/SwinMM/WORD/utils/test_view_transforms.py
new file mode 100644
index 00000000..a8cbeb44
--- /dev/null
+++ b/SwinMM/WORD/utils/test_view_transforms.py
@@ -0,0 +1,39 @@
+"""Unit test for view transforms."""
+
+import itertools
+import unittest
+
+import numpy as np
+import torch
+from utils import view_transforms
+
+
+class ViewTransformsTest(unittest.TestCase):
+ def test_len(self):
+ self.assertEqual(len(view_transforms.all_forward_transforms), len(view_transforms.all_backward_transforms))
+
+ def test_inverse_transforms(self):
+ x = np.random.uniform(size=(2, 6, 3, 4, 5))
+ x_torch = torch.from_numpy(x)
+ for group_name, transforms in view_transforms.all_forward_transforms.items():
+ inverse_transforms = view_transforms.all_backward_transforms[group_name]
+ self.assertEqual(len(transforms), len(inverse_transforms))
+ for key in transforms:
+ x_recon = inverse_transforms[key](transforms[key](x_torch)).numpy()
+ np.testing.assert_allclose(x, x_recon)
+
+ def test_get_transforms_func(self):
+ x = np.random.uniform(size=(2, 6, 3, 4, 5))
+ x_torch = torch.from_numpy(x)
+
+ for order in [view_transforms.DEFAULT_ORDER, view_transforms.DEFAULT_ORDER[::-1]]:
+ views_all = itertools.product(*[view_transforms.all_forward_transforms[gn].keys() for gn in order])
+ for views in views_all:
+ func, inv_func = [
+ view_transforms.get_transforms_func(views, order, inverse) for inverse in [False, True]
+ ]
+ np.testing.assert_allclose(x, inv_func(func(x_torch)).numpy())
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/SwinMM/WORD/utils/view_ops.py b/SwinMM/WORD/utils/view_ops.py
new file mode 100644
index 00000000..8852e7e6
--- /dev/null
+++ b/SwinMM/WORD/utils/view_ops.py
@@ -0,0 +1,34 @@
+"""View operations."""
+
+from typing import Sequence, Tuple
+
+import numpy as np
+import torch
+from utils import view_transforms
+
+PermuteType = view_transforms.PermuteType
+TransformFuncType = view_transforms.TransformFuncType
+
+
+def get_permute_transform(view_src: PermuteType, view_dst: PermuteType) -> TransformFuncType:
+ """Gets transform function from view src to view dst."""
+
+ def transform(x: torch.Tensor) -> torch.Tensor:
+ x_view_0 = view_transforms.permutation_inverse_transforms[view_src](x)
+ return view_transforms.permutation_transforms[view_dst](x_view_0).contiguous()
+
+ return transform
+
+
+def permute_inverse(xs: Sequence[torch.Tensor], views: Sequence[PermuteType]) -> Sequence[torch.Tensor]:
+ """Transforms data back to origin view."""
+ return [get_permute_transform(view, 0)(x) for x, view in zip(xs, views)]
+
+
+def permute_rand(x: torch.Tensor, num_samples: int = 2) -> Tuple[Sequence[torch.Tensor], Sequence[PermuteType]]:
+ """Samples different transforms of data."""
+ num_permutes = len(view_transforms.permutation_transforms)
+ if num_samples > num_permutes:
+ raise ValueError("Duplicate samples.")
+ view_dsts = np.random.permutation(num_permutes)[:num_samples].tolist()
+ return [get_permute_transform(0, view)(x) for view in view_dsts], view_dsts
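`permute_rand` samples distinct permutation views of a batch and `permute_inverse` maps them back to the original orientation; a small round-trip sketch on dummy data:

```python
import torch
from utils import view_ops

x = torch.arange(2 * 1 * 3 * 4 * 5, dtype=torch.float32).reshape(2, 1, 3, 4, 5)  # [B, C, X, Y, Z]
xs, views = view_ops.permute_rand(x, num_samples=2)  # two distinct axis permutations of x
recovered = view_ops.permute_inverse(xs, views)      # back to the original view
assert all(torch.equal(x, r) for r in recovered)
```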
diff --git a/SwinMM/WORD/utils/view_transforms.py b/SwinMM/WORD/utils/view_transforms.py
new file mode 100644
index 00000000..752ab5bd
--- /dev/null
+++ b/SwinMM/WORD/utils/view_transforms.py
@@ -0,0 +1,69 @@
+"""View operations.
+
+Input format: [B, C, X, Y, Z, ...]
+
+NOTE(meijieru): 0 is reserved for the identity transform.
+"""
+
+import enum
+from typing import Callable, Sequence, Union
+
+import torch
+
+RotateType = int
+PermuteType = int
+TransformFuncType = Callable[[torch.Tensor], torch.Tensor]
+# A composition of multiple view transforms.
+TransformsType = Sequence[Union[PermuteType, RotateType]]
+
+
+class GroupName(enum.Enum):
+ ROTATE = 1
+ PERMUTE = 2
+
+
+DEFAULT_ORDER = (GroupName.ROTATE, GroupName.PERMUTE)
+
+rotation_transforms = {
+ 0: lambda x: x,
+ 1: lambda x: x.rot90(1, (3, 4)),
+ 2: lambda x: x.rot90(2, (3, 4)),
+ 3: lambda x: x.rot90(3, (3, 4)),
+}
+rotation_inverse_transforms = {
+ 0: lambda x: x,
+ 1: lambda x: x.rot90(3, (3, 4)),
+ 2: lambda x: x.rot90(2, (3, 4)),
+ 3: lambda x: x.rot90(1, (3, 4)),
+}
+permutation_transforms = {0: lambda x: x, 1: lambda x: x.permute(0, 1, 3, 2, 4), 2: lambda x: x.permute(0, 1, 4, 3, 2)}
+permutation_inverse_transforms = {
+ 0: lambda x: x,
+ 1: lambda x: x.permute(0, 1, 3, 2, 4),
+ 2: lambda x: x.permute(0, 1, 4, 3, 2),
+}
+
+all_forward_transforms = {GroupName.ROTATE: rotation_transforms, GroupName.PERMUTE: permutation_transforms}
+all_backward_transforms = {
+ GroupName.ROTATE: rotation_inverse_transforms,
+ GroupName.PERMUTE: permutation_inverse_transforms,
+}
+
+
+def get_transforms_func(
+ views: TransformsType, orders: Sequence[GroupName] = DEFAULT_ORDER, inverse: bool = False
+) -> TransformFuncType:
+ """Gets sequential transform functions."""
+ if len(views) != len(orders):
+ raise ValueError(f"Expected {len(orders)} views, got {len(views)}.")
+
+ all_transforms = all_forward_transforms if not inverse else all_backward_transforms
+ funcs = [all_transforms[group_name][view] for view, group_name in zip(views, orders)]
+ funcs = funcs if not inverse else funcs[::-1]
+
+ def aux(val):
+ for func in funcs:
+ val = func(val)
+ return val
+
+ return aux
diff --git a/SwinMM/figures/ACDC.png b/SwinMM/figures/ACDC.png
new file mode 100644
index 00000000..af81301b
Binary files /dev/null and b/SwinMM/figures/ACDC.png differ
diff --git a/SwinMM/figures/Result.png b/SwinMM/figures/Result.png
new file mode 100644
index 00000000..278fe405
Binary files /dev/null and b/SwinMM/figures/Result.png differ
diff --git a/SwinMM/figures/SwinMMArch.png b/SwinMM/figures/SwinMMArch.png
new file mode 100644
index 00000000..08487b6f
Binary files /dev/null and b/SwinMM/figures/SwinMMArch.png differ
diff --git a/SwinMM/figures/finetune.png b/SwinMM/figures/finetune.png
new file mode 100644
index 00000000..1420a112
Binary files /dev/null and b/SwinMM/figures/finetune.png differ
diff --git a/SwinMM/figures/pretrain.png b/SwinMM/figures/pretrain.png
new file mode 100644
index 00000000..b2c71254
Binary files /dev/null and b/SwinMM/figures/pretrain.png differ
diff --git a/SwinMM/requirements.txt b/SwinMM/requirements.txt
new file mode 100644
index 00000000..047387a3
--- /dev/null
+++ b/SwinMM/requirements.txt
@@ -0,0 +1,33 @@
+apex==0.9.10dev
+batchgenerators==0.25
+caffe2==0.8.1
+colorama==0.4.6
+Flask==2.3.2
+gevent==23.7.0
+gorilla==0.4.0
+hypothesis==6.81.2
+lz4==4.3.2
+monai==1.2.0
+nibabel==5.1.0
+numba==0.57.1
+numpy==1.25.1
+pssh==2.3.1
+ptvsd==4.3.2
+pydantic==2.0.3
+pylzma==0.5.0
+pytest==7.4.0
+redis==4.6.0
+Requests==2.31.0
+scipy==1.11.1
+setuptools==65.6.3
+setuptools_rust==1.6.0
+tensorboardX==2.6.1
+tensorflow_datasets==4.9.2
+timm==0.9.2
+torchvision==0.15.2
+tqdm==4.65.0
+transformers==4.30.2
+typing_extensions==4.7.1
+urllib3==1.26.15
+xmlrunner==1.7.7
+xxhash==3.2.0
diff --git a/SwinMM/scripts/setup_env.sh b/SwinMM/scripts/setup_env.sh
new file mode 100644
index 00000000..fb4892a4
--- /dev/null
+++ b/SwinMM/scripts/setup_env.sh
@@ -0,0 +1,5 @@
+# Activate the corresponding conda env first
+conda install -c conda-forge redis redis-py
+pip install timm
+pip install bagua-cuda111
+pip install setuptools==59.5.0
diff --git a/SwinMM/scripts/start_redis.sh b/SwinMM/scripts/start_redis.sh
new file mode 100644
index 00000000..9f2bd012
--- /dev/null
+++ b/SwinMM/scripts/start_redis.sh
@@ -0,0 +1,13 @@
+ports="39999 39998 39997 39996"
+for port in ${ports}; do
+ echo "run redis at localhost:${port}"
+ redis-server \
+ --daemonize yes \
+ --port ${port} \
+ --maxclients 100000 \
+ --maxmemory 0 \
+ --maxmemory-policy noeviction \
+ --appendonly no \
+ --save "" \
+ --protected-mode no
+done
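The ports above must match the `redis_ports` argument consumed by `utils/data_utils.py`. A quick way to confirm all four servers are up, using the `redis` package already pinned in `requirements.txt`:

```python
import redis

for port in (39999, 39998, 39997, 39996):
    # ping() raises a ConnectionError if the server on this port is unreachable.
    assert redis.Redis(host="localhost", port=port).ping()
```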