diff --git a/SwinMM/INSTALL.md b/SwinMM/INSTALL.md
new file mode 100644
index 00000000..16ceb026
--- /dev/null
+++ b/SwinMM/INSTALL.md
@@ -0,0 +1,101 @@
+# Installation
+
+We provide installation instructions here.
+
+## Setup
+
+### Using Docker
+
+The simplest way to use SwinMM is through our docker image [`swinmm`](https://drive.google.com/file/d/1EGSoqN-HphyMV_gKUq-g7_BSwTTg35oA/view?usp=sharing), which contains all the needed dependencies. Download `swinmm.tar` into the `SwinMM` directory and run the following commands:
+
+```bash
+cd SwinMM
+docker import - swinmm < swinmm.tar
+docker run --runtime=nvidia --gpus=all -m="800g" --shm-size="32g" -itd --name swinmm -v ./:/volume swinmm /bin/bash
+docker exec -it swinmm /bin/bash
+conda activate SwinMM
+```
+
+To use Docker, make sure you have installed `docker` and `nvidia-docker`.
+
+### Manual
+
+For fast dataset loading, we require users to install the Redis database; for example, on Ubuntu: `sudo apt-get install redis`
+
+We also recommend installing PyTorch from the official website.
+
+Two packages should be installed manually because of their complicated dependencies: [bagua==0.9.2](https://github.com/BaguaSys/bagua) and [monai==0.9.0](https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies).
+
+The remaining packages can be installed with `pip install -r requirements.txt`.
+
+## Datasets
+
+Our pre-training dataset includes 5833 volumes from 8 public datasets:
+
+- [AbdomenCT-1K](https://github.com/JunMa11/AbdomenCT-1K)
+- [BTCV](https://www.synapse.org/#!Synapse:syn3193805/wiki/217789)
+- [MSD](http://medicaldecathlon.com/)
+- [TCIACovid19](https://wiki.cancerimagingarchive.net/display/Public/CT+Images+in+COVID-19/)
+- [WORD](https://github.com/HiLab-git/WORD)
+- [TCIA-Colon](https://wiki.cancerimagingarchive.net/display/Public/CT+COLONOGRAPHY/)
+- [LiDC](https://wiki.cancerimagingarchive.net/display/Public/LIDC-IDRI/)
+- [HNSCC](https://wiki.cancerimagingarchive.net/display/Public/HNSCC)
+
+We choose two popular datasets to test the downstream segmentation performance:
+
+- [WORD](https://github.com/HiLab-git/WORD) (The Whole abdominal Organ Dataset)
+- [ACDC](https://www.creatis.insa-lyon.fr/Challenge/acdc/#challenge/584e75606a3c77492fe91bba) (Automated Cardiac Diagnosis Challenge)
+
+The JSON files can be downloaded from [pretrain_jsons](https://drive.google.com/file/d/1gJThxBvnJnc2_N1nFX7xywjFWFw7DSEY/view?usp=sharing) and [word_jsons](https://drive.google.com/file/d/1Td4T_k2QlEcTETz9TERGsVdOyebD5ULv/view?usp=sharing).
+
+The datasets are organized as below:
+
+```text
+SwinMM
+├── WORD
+│   └── dataset
+│       └── dataset12_WORD
+│           ├── imagesTr
+│           ├── imagesTs
+│           ├── imagesVal
+│           ├── labelsTr
+│           ├── labelsTs
+│           ├── labelsVal
+│           └── dataset12_WORD.json
+└── Pretrain
+    ├── dataset
+    │   ├── dataset00_BTCV
+    │   ├── dataset02_Heart
+    │   ├── dataset03_Liver
+    │   ├── dataset04_Hippocampus
+    │   ├── dataset06_Lung
+    │   ├── dataset07_Pancreas
+    │   ├── dataset08_HepaticVessel
+    │   ├── dataset09_Spleen
+    │   ├── dataset10_Colon
+    │   ├── dataset11_TCIAcovid19
+    │   ├── dataset12_WORD
+    │   ├── dataset13_AbdomenCT-1K
+    │   ├── dataset_HNSCC
+    │   ├── dataset_TCIAcolon
+    │   └── dataset_LIDC
+    └── jsons
+        ├── dataset00_BTCV.json
+        ├── dataset01_BrainTumour.json
+        ├── dataset02_Heart.json
+        ├── dataset03_Liver.json
+        ├── dataset04_Hippocampus.json
+        ├── dataset05_Prostate.json
+        ├── dataset06_Lung.json
+        ├── dataset07_Pancreas.json
+        ├── dataset08_HepaticVessel.json
+        ├── dataset09_Spleen.json
+        ├── 
dataset10_Colon.json + ├── dataset11_TCIAcovid19.json + ├── dataset12_WORD.json + ├── dataset13_AbdomenCT-1K.json + ├── dataset_HNSCC.json + ├── dataset_TCIAcolon.json + └── dataset_LIDC.json + +``` diff --git a/SwinMM/Pretrain/jsons/Download Pretrain Jsons Here b/SwinMM/Pretrain/jsons/Download Pretrain Jsons Here new file mode 100644 index 00000000..e69de29b diff --git a/SwinMM/Pretrain/jsons/__init__.py b/SwinMM/Pretrain/jsons/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/SwinMM/Pretrain/losses/loss.py b/SwinMM/Pretrain/losses/loss.py new file mode 100644 index 00000000..8e0b5a88 --- /dev/null +++ b/SwinMM/Pretrain/losses/loss.py @@ -0,0 +1,95 @@ +# Copyright 2020 - 2022 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch.nn import functional as F + + +class ContrastLoss(torch.nn.Module): + def __init__(self, args, batch_size, temperature=0.5): + super().__init__() + device = torch.device(f"cuda:{args.local_rank}") + self.batch_size = batch_size + self.register_buffer("temp", torch.tensor(temperature).to(torch.device(f"cuda:{args.local_rank}"))) + self.register_buffer("neg_mask", (~torch.eye(batch_size * 2, batch_size * 2, dtype=bool).to(device)).float()) + + def forward(self, x_i, x_j): + z_i = F.normalize(x_i, dim=1) + z_j = F.normalize(x_j, dim=1) + z = torch.cat([z_i, z_j], dim=0) + sim = F.cosine_similarity(z.unsqueeze(1), z.unsqueeze(0), dim=2) + sim_ij = torch.diag(sim, self.batch_size) + sim_ji = torch.diag(sim, -self.batch_size) + pos = torch.cat([sim_ij, sim_ji], dim=0) + nom = torch.exp(pos / self.temp) + denom = self.neg_mask * torch.exp(sim / self.temp) + return torch.sum(-torch.log(nom / torch.sum(denom, dim=1))) / (2 * self.batch_size) + + +class MutualLoss(torch.nn.Module): + def __init__(self, args): + super().__init__() + self.alpha = 1.0 + self.mask_ratio = args.mask_ratio + self.recon_loss_2 = torch.nn.MSELoss().cuda() + + def __call__(self, rec1, rec2, mask): + mask = mask.to(dtype=rec1.dtype) + rec1, rec2 = [val * mask for val in [rec1, rec2]] + + recon_loss = self.recon_loss_2(rec1, rec2) / self.mask_ratio + return self.alpha * recon_loss + + +class Loss(torch.nn.Module): + def __init__(self, batch_size, args): + super().__init__() + self.rot_loss = torch.nn.CrossEntropyLoss().cuda() + self.recon_loss = torch.nn.L1Loss().cuda() + self.recon_loss_2 = torch.nn.MSELoss().cuda() + self.contrast_loss = ContrastLoss(args, batch_size).cuda() + self.alpha1 = 1.0 + self.alpha2 = 1.0 + self.alpha3 = 1.0 + self.norm_pix_loss = args.norm_pix_loss + self.mask_ratio = args.mask_ratio + + def __call__( + self, + output_rot, + target_rot, + output_contrastive, + target_contrastive, + output_recons, + target_recons, + mask, + only_mae=False, + ): + B, C, H, W, D = output_recons.shape + target_recons = target_recons.reshape(B, C, -1) + + if self.norm_pix_loss: + mean = target_recons.mean(dim=-1, keepdim=True) + var = target_recons.var(dim=-1, keepdim=True) + target_recons = (target_recons - mean) / (var + 1.0e-6) ** 0.5 + 
target_recons = target_recons.reshape(B, C, H, W, D) + # masked voxels. + mask = mask.to(dtype=target_recons.dtype)[None, ...] + target_recons, output_recons = [val * mask for val in [target_recons, output_recons]] + recon_loss = self.recon_loss_2(output_recons, target_recons) / self.mask_ratio + recon_loss = self.alpha3 * recon_loss + if only_mae: + return recon_loss + contrast_loss = self.alpha2 * self.contrast_loss(output_contrastive, target_contrastive) + rot_loss = self.alpha1 * self.rot_loss(output_rot, target_rot) + total_loss = rot_loss + contrast_loss + recon_loss + + return total_loss, (rot_loss, contrast_loss, recon_loss) diff --git a/SwinMM/Pretrain/main.py b/SwinMM/Pretrain/main.py new file mode 100644 index 00000000..419b8e1c --- /dev/null +++ b/SwinMM/Pretrain/main.py @@ -0,0 +1,317 @@ +# Copyright 2020 - 2022 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import random +from time import time + +import timm.optim.optim_factory as optim_factory +import torch +import torch.distributed as dist +import torch.optim as optim +from losses.loss import Loss, MutualLoss +from models.ssl_head import SSLHead +from optimizers.lr_scheduler import WarmupCosineSchedule +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel +from utils import view_ops, view_transforms +from utils.data_utils import get_loader +from utils.dataset_in_memory import hijack_bagua_serialization +from utils.ops import mask_rand_patch + +# torch +torch.multiprocessing.set_sharing_strategy("file_system") + + +def main(): + def save_ckpt(state, checkpoint_dir): + torch.save(state, checkpoint_dir) + + def train(args, global_step, train_loader, val_best, scaler): + model.train() + + for _, batch in enumerate(train_loader): + t1 = time() + x = batch["image"].cuda() + + x1, rot1 = view_ops.rot_rand(x) + x2, rot2 = view_ops.rot_rand(x) + + window_sizes = tuple(args.window_size for _ in range(3)) + input_sizes = (args.roi_x, args.roi_y, args.roi_z) + x1_masked, mask1 = mask_rand_patch(window_sizes, input_sizes, args.mask_ratio, x1) + x2_masked, mask2 = mask_rand_patch(window_sizes, input_sizes, args.mask_ratio, x2) + + # NOTE(meijieru): x1, x2 may have different rot transform, so we + # allow same permute transform here. 
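+            # Each rotated volume also gets a randomly permuted (axis-transposed) copy below,
+            # so the encoder sees four masked views of every sample in a single step.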
+ permutations_candidates = set(view_transforms.permutation_transforms.keys()) - {0} + permutations = [random.choice(list(permutations_candidates)) for _ in range(2)] + x1_masked_permuted, x2_masked_permuted = [ + view_transforms.permutation_transforms[vn](val) for vn, val in zip(permutations, [x1_masked, x2_masked]) + ] + + with autocast(enabled=args.amp): + rot1_p, contrastive1_p, rec_x1 = model(x1_masked) + rot2_p, contrastive2_p, rec_x2 = model(x2_masked) + _, contrastive3_p, rec_x3 = model(x1_masked_permuted) + _, contrastive4_p, rec_x4 = model(x2_masked_permuted) + + # masked voxels: [2, H, W, D] + mask = torch.stack([mask1, mask2], dim=0) + rec_x3, rec_x4 = [ + view_transforms.permutation_inverse_transforms[vn](val) + for vn, val in zip(permutations, [rec_x3, rec_x4]) + ] + + rot_p = torch.cat([rot1_p, rot2_p], dim=0) + rots = torch.cat([rot1, rot2], dim=0) + # [B, 2, H, W, D] + imgs_recon = torch.cat([rec_x1, rec_x2], dim=1) + imgs = torch.cat([x1, x2], dim=1) + loss1, losses_tasks1 = loss_function( + rot_p, rots, contrastive1_p, contrastive2_p, imgs_recon, imgs, mask + ) + + mutual_loss1 = mutual_loss_function(rec_x3, rec_x1, mask1) + + imgs_recon = torch.cat([rec_x3, rec_x4], dim=1) + loss2 = loss_function( + rot_p, rots, contrastive3_p, contrastive4_p, imgs_recon, imgs, mask, only_mae=True + ) + + loss = loss1 + loss2 + mutual_loss1 + + mutual_loss2 = None + if args.mutual_learning_on_more_view: + + def _align_rot(x, src_rot, dst_rot): + return view_transforms.rotation_transforms[dst_rot]( + view_transforms.rotation_inverse_transforms[src_rot](x) + ).contiguous() + + # [B, C, H, W, D] + rec_x4_aligned = torch.stack( + [ + _align_rot(val, src_rot.item(), dst_rot.item()) + for val, src_rot, dst_rot in zip(rec_x4, rot2, rot1) + ] + ) + # [B, 1, H, W, D] + mask2_aligned = torch.concat( + [ + _align_rot(mask2[None, None], src_rot.item(), dst_rot.item()) + for src_rot, dst_rot in zip(rot2, rot1) + ] + ) + mask_intersection = torch.logical_and(mask2_aligned, mask1) + # Rescale to the same scale of mutual_loss1 + rescaler = mask1.sum() * mask2_aligned.size(0) / (mask2_aligned.sum() + 1e-6) + mutual_loss2 = mutual_loss_function(rec_x4_aligned, rec_x1, mask_intersection) * rescaler + + loss = loss + mutual_loss2 + + if args.amp: + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + else: + loss.backward() + if args.grad_clip: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) + optimizer.step() + + if args.lrdecay: + scheduler.step() + optimizer.zero_grad() + if args.distributed: + if dist.get_rank() == 0: + rot_loss = losses_tasks1[0].item() + con_loss = losses_tasks1[1].item() + rec_loss = losses_tasks1[2].item() + loss2.item() + print( + "Step:{}/{}, Loss:{:.4f}, Rot:{:.4f}, Con:{:.4f}, Rec:{:.4f}, Time:{:.4f}".format( + global_step, args.num_steps, loss, rot_loss, con_loss, rec_loss, time() - t1 + ) + ) + else: + print("Step:{}/{}, Loss:{:.4f}, Time:{:.4f}".format(global_step, args.num_steps, loss, time() - t1)) + + global_step += 1 + if args.distributed: + val_cond = (dist.get_rank() == 0) and (global_step % args.eval_num == 0) + else: + val_cond = global_step % args.eval_num == 0 + + if val_cond and global_step % 1000 == 0: + checkpoint = { + "global_step": global_step, + "state_dict": model.state_dict(), + "optimizer": optimizer.state_dict(), + } + save_ckpt(checkpoint, logdir + "/model_{}.pt".format(global_step)) + return global_step, loss, val_best + + parser = argparse.ArgumentParser(description="PyTorch Training") + 
parser.add_argument("--logdir", default="test", type=str, help="directory to save the tensorboard logs") + parser.add_argument("--epochs", default=100, type=int, help="number of training epochs") + parser.add_argument("--num_steps", default=100000, type=int, help="number of training iterations") + parser.add_argument("--eval_num", default=100, type=int, help="evaluation frequency") + parser.add_argument("--warmup_steps", default=500, type=int, help="warmup steps") + parser.add_argument("--in_channels", default=1, type=int, help="number of input channels") + parser.add_argument("--feature_size", default=48, type=int, help="embedding size") + parser.add_argument("--dropout_path_rate", default=0.0, type=float, help="drop path rate") + parser.add_argument("--use_checkpoint", action="store_true", help="use gradient checkpointing to save memory") + parser.add_argument("--spatial_dims", default=3, type=int, help="spatial dimension of input data") + parser.add_argument("--a_min", default=-1000, type=float, help="a_min in ScaleIntensityRanged") + parser.add_argument("--a_max", default=1000, type=float, help="a_max in ScaleIntensityRanged") + parser.add_argument("--b_min", default=0.0, type=float, help="b_min in ScaleIntensityRanged") + parser.add_argument("--b_max", default=1.0, type=float, help="b_max in ScaleIntensityRanged") + parser.add_argument("--space_x", default=1.5, type=float, help="spacing in x direction") + parser.add_argument("--space_y", default=1.5, type=float, help="spacing in y direction") + parser.add_argument("--space_z", default=2.0, type=float, help="spacing in z direction") + parser.add_argument("--roi_x", default=64, type=int, help="roi size in x direction") + parser.add_argument("--roi_y", default=64, type=int, help="roi size in y direction") + parser.add_argument("--roi_z", default=64, type=int, help="roi size in z direction") + parser.add_argument("--mask_ratio", default=0.5, type=float, help="mask ratio for MAE pretraining") + parser.add_argument("--window_size", default=16, type=int, help="window size for MAE pretraining") + parser.add_argument("--batch_size", default=1, type=int, help="number of batch size") + parser.add_argument("--sw_batch_size", default=2, type=int, help="number of sliding window batch size") + parser.add_argument("--lr", default=4e-4, type=float, help="learning rate") + parser.add_argument("--decay", default=0.1, type=float, help="decay rate") + parser.add_argument("--momentum", default=0.9, type=float, help="momentum") + parser.add_argument("--lrdecay", action="store_true", help="enable learning rate decay") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="maximum gradient norm") + parser.add_argument("--loss_type", default="SSL", type=str) + parser.add_argument("--opt", default="adamw", type=str, help="optimization algorithm") + parser.add_argument("--lr_schedule", default="warmup_cosine", type=str) + parser.add_argument("--resume", default=None, type=str, help="resume training") + parser.add_argument("--local_rank", type=int, default=0, help="local rank") + parser.add_argument("--grad_clip", action="store_true", help="gradient clip") + parser.add_argument("--noamp", action="store_true", help="do NOT use amp for training") + parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training") + parser.add_argument("--norm_pix_loss", action="store_true", help="normalize before compute reconstruction loss") + parser.add_argument("--redis_ports", nargs="+", type=int, help="redis ports") + 
parser.add_argument("--redis_compression", type=str, default="lz4", help="compression method for redis.") + parser.add_argument("--use_normal_dataset", action="store_true", help="use monai Dataset class") + parser.add_argument( + "--nouse_multi_epochs_loader", + action="store_true", + help="not use the multi-epochs-loader to save time at the beginning of every epoch", + ) + parser.add_argument( + "--mutual_learning_on_more_view", action="store_true", help="also use rotate for mutual learning" + ) + parser.add_argument("--workers", default=16, type=int, help="number of workers") + + args = parser.parse_args() + logdir = "./runs/" + args.logdir + args.amp = not args.noamp + args.lr = args.lr * args.batch_size / 2 + torch.backends.cudnn.benchmark = True + args.distributed = False + if "WORLD_SIZE" in os.environ: + args.distributed = int(os.environ["WORLD_SIZE"]) > 1 + args.device = "cuda:0" + args.world_size = 1 + args.rank = 0 + + if args.distributed: + args.device = "cuda:%d" % args.local_rank + torch.cuda.set_device(args.local_rank) + dist.init_process_group(backend="nccl", init_method=args.dist_url) + args.world_size = dist.get_world_size() + args.rank = dist.get_rank() + print( + "Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d." + % (args.rank, args.world_size) + ) + else: + print("Training with a single process on 1 GPUs.") + assert args.rank >= 0 + + if args.redis_compression is not None: + hijack_bagua_serialization(args.redis_compression) + + if args.rank == 0: + os.makedirs(logdir, exist_ok=True) + + model = SSLHead(args) + model.cuda() + model_without_ddp = model + + param_groups = optim_factory.param_groups_weight_decay( + model_without_ddp, weight_decay=args.decay, no_weight_decay_list=model_without_ddp.no_weight_decay() + ) + if args.opt == "adam": + optimizer = optim.Adam(param_groups, lr=args.lr) + + elif args.opt == "adamw": + optimizer = optim.AdamW(param_groups, lr=args.lr) + + elif args.opt == "sgd": + optimizer = optim.SGD(param_groups, lr=args.lr, momentum=args.momentum) + else: + raise ValueError(f"Unknown optimizer: {args.opt})") + + global_step = 0 + if args.resume: + model_pth = args.resume + model_dict = torch.load(model_pth) + new_state = {} + + for k, v in model_dict["state_dict"].items(): + new_name = k[7:] + new_state[new_name] = v + + model.load_state_dict(new_state) + global_step = model_dict["global_step"] + model.optimizer = model_dict["optimizer"] + + if args.lrdecay: + if args.lr_schedule == "warmup_cosine": + scheduler = WarmupCosineSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=args.num_steps) + + elif args.lr_schedule == "poly": + + def lambdas(epoch): + return (1 - float(epoch) / float(args.epochs)) ** 0.9 + + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambdas) + + mutual_loss_function = MutualLoss(args) + loss_function = Loss(args.batch_size * args.sw_batch_size, args) + if args.distributed: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = DistributedDataParallel(model, device_ids=[args.local_rank], broadcast_buffers=False) + model_without_ddp = model.module + train_loader, _ = get_loader(args) + + best_val = 1e8 + if args.amp: + scaler = GradScaler() + else: + scaler = None + while global_step < args.num_steps: + global_step, loss, best_val = train(args, global_step, train_loader, best_val, scaler) + checkpoint = {"epoch": args.epochs, "state_dict": model.state_dict(), "optimizer": optimizer.state_dict()} + + if args.distributed: + if 
dist.get_rank() == 0: + torch.save(model.state_dict(), logdir + "final_model.pth") + dist.destroy_process_group() + else: + torch.save(model.state_dict(), logdir + "final_model.pth") + save_ckpt(checkpoint, logdir + "/model_final_epoch.pt") + + +if __name__ == "__main__": + main() diff --git a/SwinMM/Pretrain/models/ssl_head.py b/SwinMM/Pretrain/models/ssl_head.py new file mode 100644 index 00000000..fc259418 --- /dev/null +++ b/SwinMM/Pretrain/models/ssl_head.py @@ -0,0 +1,99 @@ +# Copyright 2020 - 2022 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn + +from monai.networks.nets.swin_unetr import SwinTransformer as SwinViT +from monai.utils import ensure_tuple_rep + + +class SSLHead(nn.Module): + def __init__(self, args, upsample="vae", dim=768): + super(SSLHead, self).__init__() + patch_size = ensure_tuple_rep(2, args.spatial_dims) + window_size = ensure_tuple_rep(7, args.spatial_dims) + self.swinViT = SwinViT( + in_chans=args.in_channels, + embed_dim=args.feature_size, + window_size=window_size, + patch_size=patch_size, + depths=[2, 2, 2, 2], + num_heads=[3, 6, 12, 24], + mlp_ratio=4.0, + qkv_bias=True, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=args.dropout_path_rate, + norm_layer=torch.nn.LayerNorm, + use_checkpoint=args.use_checkpoint, + spatial_dims=args.spatial_dims, + ) + self.rotation_pre = nn.Identity() + self.rotation_head = nn.Linear(dim, 4) + self.contrastive_pre = nn.Identity() + self.contrastive_head = nn.Linear(dim, 512) + if upsample == "large_kernel_deconv": + self.conv = nn.ConvTranspose3d(dim, args.in_channels, kernel_size=(32, 32, 32), stride=(32, 32, 32)) + elif upsample == "deconv": + self.conv = nn.Sequential( + nn.ConvTranspose3d(dim, dim // 2, kernel_size=(2, 2, 2), stride=(2, 2, 2)), + nn.ConvTranspose3d(dim // 2, dim // 4, kernel_size=(2, 2, 2), stride=(2, 2, 2)), + nn.ConvTranspose3d(dim // 4, dim // 8, kernel_size=(2, 2, 2), stride=(2, 2, 2)), + nn.ConvTranspose3d(dim // 8, dim // 16, kernel_size=(2, 2, 2), stride=(2, 2, 2)), + nn.ConvTranspose3d(dim // 16, args.in_channels, kernel_size=(2, 2, 2), stride=(2, 2, 2)), + ) + elif upsample == "vae": + self.conv = nn.Sequential( + nn.Conv3d(dim, dim // 2, kernel_size=3, stride=1, padding=1), + nn.InstanceNorm3d(dim // 2), + nn.LeakyReLU(), + nn.Upsample(scale_factor=2, mode="trilinear", align_corners=False), + nn.Conv3d(dim // 2, dim // 4, kernel_size=3, stride=1, padding=1), + nn.InstanceNorm3d(dim // 4), + nn.LeakyReLU(), + nn.Upsample(scale_factor=2, mode="trilinear", align_corners=False), + nn.Conv3d(dim // 4, dim // 8, kernel_size=3, stride=1, padding=1), + nn.InstanceNorm3d(dim // 8), + nn.LeakyReLU(), + nn.Upsample(scale_factor=2, mode="trilinear", align_corners=False), + nn.Conv3d(dim // 8, dim // 16, kernel_size=3, stride=1, padding=1), + nn.InstanceNorm3d(dim // 16), + nn.LeakyReLU(), + nn.Upsample(scale_factor=2, mode="trilinear", align_corners=False), + nn.Conv3d(dim // 16, dim // 16, kernel_size=3, stride=1, padding=1), + nn.InstanceNorm3d(dim // 
16), + nn.LeakyReLU(), + nn.Upsample(scale_factor=2, mode="trilinear", align_corners=False), + nn.Conv3d(dim // 16, args.in_channels, kernel_size=1, stride=1), + ) + + def forward(self, x): + x_out = self.swinViT(x.contiguous())[4] + _, c, h, w, d = x_out.shape + x4_reshape = x_out.flatten(start_dim=2, end_dim=4) + x4_reshape = x4_reshape.transpose(1, 2) + x_rot = self.rotation_pre(x4_reshape[:, 0]) + x_rot = self.rotation_head(x_rot) + x_contrastive = self.contrastive_pre(x4_reshape[:, 1]) + x_contrastive = self.contrastive_head(x_contrastive) + x_rec = x_out.flatten(start_dim=2, end_dim=4) + x_rec = x_rec.view(-1, c, h, w, d) + x_rec = self.conv(x_rec) + return x_rot, x_contrastive, x_rec + + def no_weight_decay(self): + """Disable weight_decay on specific weights.""" + nwd = {"swinViT.absolute_pos_embed"} + for n, _ in self.named_parameters(): + if "relative_position_bias_table" in n: + nwd.add(n) + return nwd diff --git a/SwinMM/Pretrain/optimizers/__init__.py b/SwinMM/Pretrain/optimizers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/SwinMM/Pretrain/optimizers/lr_scheduler.py b/SwinMM/Pretrain/optimizers/lr_scheduler.py new file mode 100644 index 00000000..0c352927 --- /dev/null +++ b/SwinMM/Pretrain/optimizers/lr_scheduler.py @@ -0,0 +1,172 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import warnings +from typing import List + +from torch import nn as nn +from torch.optim import Adam, Optimizer +from torch.optim.lr_scheduler import LambdaLR, _LRScheduler + +__all__ = ["LinearLR", "ExponentialLR"] + + +class _LRSchedulerMONAI(_LRScheduler): + """Base class for increasing the learning rate between two boundaries over a number + of iterations""" + + def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1) -> None: + """ + Args: + optimizer: wrapped optimizer. + end_lr: the final learning rate. + num_iter: the number of iterations over which the test occurs. + last_epoch: the index of last epoch. + Returns: + None + """ + self.end_lr = end_lr + self.num_iter = num_iter + super(_LRSchedulerMONAI, self).__init__(optimizer, last_epoch) + + +class LinearLR(_LRSchedulerMONAI): + """Linearly increases the learning rate between two boundaries over a number of + iterations. + """ + + def get_lr(self): + r = self.last_epoch / (self.num_iter - 1) + return [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs] + + +class ExponentialLR(_LRSchedulerMONAI): + """Exponentially increases the learning rate between two boundaries over a number of + iterations. + """ + + def get_lr(self): + r = self.last_epoch / (self.num_iter - 1) + return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs] + + +class WarmupCosineSchedule(LambdaLR): + """Linear warmup and then cosine decay. + Based on https://huggingface.co/ implementation. 
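+    Example (as used in ``Pretrain/main.py``): create the scheduler with
+    ``WarmupCosineSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=args.num_steps)``
+    and call ``scheduler.step()`` once per training iteration.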
+ """ + + def __init__( + self, optimizer: Optimizer, warmup_steps: int, t_total: int, cycles: float = 0.5, last_epoch: int = -1 + ) -> None: + """ + Args: + optimizer: wrapped optimizer. + warmup_steps: number of warmup iterations. + t_total: total number of training iterations. + cycles: cosine cycles parameter. + last_epoch: the index of last epoch. + Returns: + None + """ + self.warmup_steps = warmup_steps + self.t_total = t_total + self.cycles = cycles + super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch) + + def lr_lambda(self, step): + if step < self.warmup_steps: + return float(step) / float(max(1.0, self.warmup_steps)) + progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.cycles) * 2.0 * progress))) + + +class LinearWarmupCosineAnnealingLR(_LRScheduler): + def __init__( + self, + optimizer: Optimizer, + warmup_epochs: int, + max_epochs: int, + warmup_start_lr: float = 0.0, + eta_min: float = 0.0, + last_epoch: int = -1, + ) -> None: + """ + Args: + optimizer (Optimizer): Wrapped optimizer. + warmup_epochs (int): Maximum number of iterations for linear warmup + max_epochs (int): Maximum number of iterations + warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0. + eta_min (float): Minimum learning rate. Default: 0. + last_epoch (int): The index of last epoch. Default: -1. + """ + self.warmup_epochs = warmup_epochs + self.max_epochs = max_epochs + self.warmup_start_lr = warmup_start_lr + self.eta_min = eta_min + + super(LinearWarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch) + + def get_lr(self) -> List[float]: + """ + Compute learning rate using chainable form of the scheduler + """ + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", UserWarning + ) + + if self.last_epoch == 0: + return [self.warmup_start_lr] * len(self.base_lrs) + elif self.last_epoch < self.warmup_epochs: + return [ + group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + elif self.last_epoch == self.warmup_epochs: + return self.base_lrs + elif (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0: + return [ + group["lr"] + + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2 + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + + return [ + (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) + / ( + 1 + + math.cos( + math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs) + ) + ) + * (group["lr"] - self.eta_min) + + self.eta_min + for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self) -> List[float]: + """ + Called when epoch is passed as a param to the `step` function of the scheduler. 
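+        Linear warmup followed by the standard closed-form cosine annealing schedule.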
+ """ + if self.last_epoch < self.warmup_epochs: + return [ + self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) + for base_lr in self.base_lrs + ] + + return [ + self.eta_min + + 0.5 + * (base_lr - self.eta_min) + * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) + for base_lr in self.base_lrs + ] diff --git a/SwinMM/Pretrain/run.sh b/SwinMM/Pretrain/run.sh new file mode 100644 index 00000000..2f1883a3 --- /dev/null +++ b/SwinMM/Pretrain/run.sh @@ -0,0 +1,10 @@ +python -m torch.distributed.launch --nproc_per_node=8 --master_port=11223 main.py \ + --batch_size=2 \ + --num_steps=30000 \ + --lrdecay \ + --eval_num=500 \ + --lr=5e-4 \ + --decay=0.1 \ + --norm_pix_loss \ + --redis_ports 39996 39997 39998 39999 \ + --redis_compression zlib diff --git a/SwinMM/Pretrain/utils/__init__.py b/SwinMM/Pretrain/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/SwinMM/Pretrain/utils/data_utils.py b/SwinMM/Pretrain/utils/data_utils.py new file mode 100644 index 00000000..35fab4d1 --- /dev/null +++ b/SwinMM/Pretrain/utils/data_utils.py @@ -0,0 +1,176 @@ +# Copyright 2020 - 2022 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import timm.data +from utils import dataset_in_memory + +from monai.data import DataLoader, Dataset, DistributedSampler, load_decathlon_datalist +from monai.data.utils import list_data_collate +from monai.transforms import ( + AddChanneld, + Compose, + LoadImaged, + Orientationd, + RandSpatialCropSamplesd, + ScaleIntensityRanged, + SpatialPadd, + ToTensord, +) + + +def get_loader(args): + splits0 = "/dataset00_BTCV.json" + # splits1 = "/dataset01_BrainTumour.json" + splits2 = "/dataset02_Heart.json" + splits3 = "/dataset03_Liver.json" + splits4 = "/dataset04_Hippocampus.json" + # splits5 = "/dataset05_Prostate.json" + splits6 = "/dataset06_Lung.json" + splits7 = "/dataset07_Pancreas.json" + splits8 = "/dataset08_HepaticVessel.json" + splits9 = "/dataset09_Spleen.json" + splits10 = "/dataset10_Colon.json" + splits11 = "/dataset11_TCIAcovid19.json" + splits12 = "/dataset12_WORD.json" + splits13 = "/dataset13_AbdomenCT-1K.json" + splits14 = "/dataset_HNSCC.json" + splits15 = "/dataset_TCIAcolon.json" + splits16 = "/dataset_LIDC.json" + + list_dir = "./jsons" + jsonlist0 = list_dir + splits0 + # jsonlist1 = list_dir + splits1 + jsonlist2 = list_dir + splits2 + jsonlist3 = list_dir + splits3 + jsonlist4 = list_dir + splits4 + # jsonlist5 = list_dir + splits5 + jsonlist6 = list_dir + splits6 + jsonlist7 = list_dir + splits7 + jsonlist8 = list_dir + splits8 + jsonlist9 = list_dir + splits9 + jsonlist10 = list_dir + splits10 + jsonlist11 = list_dir + splits11 + jsonlist12 = list_dir + splits12 + jsonlist13 = list_dir + splits13 + jsonlist14 = list_dir + splits14 + jsonlist15 = list_dir + splits15 + jsonlist16 = list_dir + splits16 + + datadir0 = "./dataset/dataset00_BTCV" + # datadir1 = "./dataset/dataset01_BrainTumour" + datadir2 = 
"./dataset/dataset02_Heart" + datadir3 = "./dataset/dataset03_Liver" + datadir4 = "./dataset/dataset04_Hippocampus" + # datadir5 = "./dataset/dataset05_Prostate" + datadir6 = "./dataset/dataset06_Lung" + datadir7 = "./dataset/dataset07_Pancreas" + datadir8 = "./dataset/dataset08_HepaticVessel" + datadir9 = "./dataset/dataset09_Spleen" + datadir10 = "./dataset/dataset10_Colon" + datadir11 = "./dataset/dataset11_TCIAcovid19" + datadir12 = "./dataset/dataset12_WORD" + datadir13 = "./dataset/dataset13_AbdomenCT-1K" + datadir14 = "./dataset/dataset_HNSCC" + datadir15 = "./dataset/dataset_TCIAcolon" + datadir16 = "./dataset/dataset_LIDC" + + datalist = [] + for json_path, base_dir in zip( + [ + jsonlist0, + jsonlist2, + jsonlist3, + jsonlist4, + jsonlist6, + jsonlist7, + jsonlist8, + jsonlist9, + jsonlist10, + jsonlist11, + jsonlist12, + jsonlist13, + # jsonlist14, + # jsonlist15, + # jsonlist16, + ], + [ + datadir0, + datadir2, + datadir3, + datadir4, + datadir6, + datadir7, + datadir8, + datadir9, + datadir10, + datadir11, + datadir12, + datadir13, + # datadir14, + # datadir15, + # datadir16, + ], + ): + datalist_i = load_decathlon_datalist(json_path, False, "training", base_dir=base_dir) + datalist.extend([{"image": item["image"]} for item in datalist_i]) + + print("Dataset all training: number of data: {}".format(len(datalist))) + + train_transforms = Compose( + [ + LoadImaged(keys=["image"]), + AddChanneld(keys=["image"]), + Orientationd(keys=["image"], axcodes="RAS"), + ScaleIntensityRanged( + keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True + ), + SpatialPadd(keys="image", spatial_size=[args.roi_x, args.roi_y, args.roi_z]), + RandSpatialCropSamplesd( + keys=["image"], + roi_size=[args.roi_x, args.roi_y, args.roi_z], + num_samples=args.sw_batch_size, + random_center=True, + random_size=False, + ), + ToTensord(keys=["image"]), + ] + ) + + if args.use_normal_dataset: + train_ds = Dataset(data=datalist, transform=train_transforms) + else: + train_ds = dataset_in_memory.CachedDataset( + data=datalist, + transform=train_transforms, + dataset_name="pretrain_train", + hosts=[{"host": "localhost", "port": str(port)} for port in args.redis_ports], + cluster_mode=True, + capacity_per_node=200 * 1024 * 1024 * 1024, + writer_buffer_size=0, # Disable write buffer + ) + + if args.distributed: + train_sampler = DistributedSampler(dataset=train_ds, even_divisible=True, shuffle=True) + else: + train_sampler = None + loader_class = DataLoader + if not args.nouse_multi_epochs_loader: + loader_class = timm.data.loader.MultiEpochsDataLoader + train_loader = loader_class( + train_ds, + batch_size=args.batch_size, + num_workers=args.workers, + sampler=train_sampler, + drop_last=True, + collate_fn=list_data_collate, + ) + + return train_loader, None diff --git a/SwinMM/Pretrain/utils/dataset_in_memory.py b/SwinMM/Pretrain/utils/dataset_in_memory.py new file mode 100644 index 00000000..c2d83497 --- /dev/null +++ b/SwinMM/Pretrain/utils/dataset_in_memory.py @@ -0,0 +1 @@ +# ../../WORD/utils/dataset_in_memory.py diff --git a/SwinMM/Pretrain/utils/ops.py b/SwinMM/Pretrain/utils/ops.py new file mode 100644 index 00000000..133a5ac7 --- /dev/null +++ b/SwinMM/Pretrain/utils/ops.py @@ -0,0 +1,30 @@ +from typing import Tuple + +import numpy as np +import torch + + +def mask_rand_patch( + window_sizes: Tuple[int, int, int], input_sizes: Tuple[int, int, int], mask_ratio: float, samples: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Patch-wise random masking.""" + 
if len(window_sizes) != len(input_sizes) or any( + [input_size % window_size != 0 for window_size, input_size in zip(window_sizes, input_sizes)] + ): + raise ValueError(f"{window_sizes} & {input_sizes} is not compatible.") + + mask_shape = [input_size // window_size for input_size, window_size in zip(input_sizes, window_sizes)] + num_patches = np.prod(mask_shape).item() + mask = np.ones(num_patches, dtype=bool) + indices = np.random.choice(num_patches, round(num_patches * mask_ratio), replace=False) + mask[indices] = False + mask = mask.reshape(mask_shape) + wh, ww, wd = window_sizes + mask = np.logical_or(mask[:, None, :, None, :, None], np.zeros([1, wh, 1, ww, 1, wd], dtype=bool)).reshape( + input_sizes + ) + mask = torch.from_numpy(mask).to(samples.device) + + res = samples.detach().clone() + res[:, :, mask] = 0 + return res, mask diff --git a/SwinMM/Pretrain/utils/view_ops.py b/SwinMM/Pretrain/utils/view_ops.py new file mode 100644 index 00000000..6f106bd3 --- /dev/null +++ b/SwinMM/Pretrain/utils/view_ops.py @@ -0,0 +1,18 @@ +"""View operations.""" + +from typing import Tuple + +import numpy as np +import torch +from utils import view_transforms + + +def rot_rand(xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + img_n = xs.size()[0] + x_aug = xs.detach().clone() + x_rot = torch.zeros(img_n, dtype=torch.int64, device=xs.device) + for i in range(img_n): + orientation = np.random.randint(0, 4) + x_aug[i] = view_transforms.rotation_transforms[orientation](xs[i].unsqueeze(0)) + x_rot[i] = orientation + return x_aug, x_rot diff --git a/SwinMM/Pretrain/utils/view_transforms.py b/SwinMM/Pretrain/utils/view_transforms.py new file mode 100644 index 00000000..caa29114 --- /dev/null +++ b/SwinMM/Pretrain/utils/view_transforms.py @@ -0,0 +1 @@ +# ../../WORD/utils/view_transforms.py diff --git a/SwinMM/README.md b/SwinMM/README.md new file mode 100644 index 00000000..374703c3 --- /dev/null +++ b/SwinMM/README.md @@ -0,0 +1,84 @@ +# SwinMM: Masked Multi-view with Swin Transformers for 3D Medical Image Segmentation + +

+ +

+
+## What is SwinMM?
+
+Masked Multi-view with Swin Transformers, dubbed SwinMM, is the first comprehensive multi-view pipeline for self-supervised medical image analysis. SwinMM yields competitive performance, significantly lower training costs, and higher data efficiency compared to recent state-of-the-art models. SwinMM consists of two key components.
+
+### Pretrain
+
+In the pre-training stage, we introduce a masked multi-view encoder that is trained simultaneously on masked multi-view observations with a diverse set of proxy tasks. These tasks include image reconstruction, rotation, contrastive learning, and a mutual learning paradigm that comprehensively leverages hidden multi-view information from 3D medical data by maximizing the consistency between predictions from different views.
+
+
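+Below is a minimal, self-contained sketch of how these proxy-task losses are combined. It mirrors `Pretrain/losses/loss.py` and `Pretrain/main.py` only in spirit: the NT-Xent contrastive term is replaced by a simple cosine stand-in, and all tensors are random placeholders.
+
+```python
+import torch
+import torch.nn.functional as F
+
+# Illustrative shapes: B volumes, two augmented views, a boolean voxel mask.
+B, C, H, W, D, mask_ratio = 2, 1, 64, 64, 64, 0.5
+rot_pred, rot_target = torch.randn(2 * B, 4), torch.randint(0, 4, (2 * B,))
+z1, z2 = torch.randn(B, 512), torch.randn(B, 512)          # contrastive embeddings of two views
+rec_a, rec_b, target = torch.randn(3, B, C, H, W, D)       # reconstructions from two views + original
+mask = torch.rand(H, W, D) < mask_ratio                    # voxels that were masked out
+
+rot_loss = F.cross_entropy(rot_pred, rot_target)                   # rotation prediction (4 classes)
+con_loss = -F.cosine_similarity(z1, z2).mean()                     # stand-in for the contrastive term
+rec_loss = F.mse_loss(rec_a * mask, target * mask) / mask_ratio    # reconstruction on masked voxels
+mutual_loss = F.mse_loss(rec_a * mask, rec_b * mask) / mask_ratio  # cross-view mutual learning
+total_loss = rot_loss + con_loss + rec_loss + mutual_loss
+```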

+ +

+ +### Finetune + +In the fine-tuning stage, a cross-view decoder is developed to aggregate the multi-view information using a novel cross-view attention block. + +
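+At inference time, this multi-view design is exposed through `double_sliding_window_inference` in `WORD/inferers.py`, which runs every sliding window through the model under two different permuted views and returns both view-aligned predictions. A hedged usage sketch follows (argument values are illustrative; `model` and `volume` are assumed to be a fine-tuned SwinMM network and an `[N, C, H, W, D]` tensor):
+
+```python
+from inferers import double_sliding_window_inference  # WORD/inferers.py
+
+out_a, out_b = double_sliding_window_inference(
+    inputs=volume,
+    view=0,                  # first permutation view; the second is chosen internally
+    roi_size=(64, 64, 64),
+    sw_batch_size=4,
+    predictor=model,         # called as model(patches_view_a, patches_view_b, view_list)
+    overlap=0.5,
+)
+prediction = (out_a.softmax(1) + out_b.softmax(1)) / 2   # illustrative fusion of the two views
+```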

+ +

+ +## Pre-trained Models + +We present two checkpoints here: + +- [pretrained_ckpt.pt](https://drive.google.com/file/d/1VFT1Oz5UGjAaXLdWAAdeD0mVeLyCQ0ry/view?usp=sharing) +- [WORD_finetuned_ckpt.pt](https://drive.google.com/file/d/1VFT1Oz5UGjAaXLdWAAdeD0mVeLyCQ0ry/view?usp=sharing) + +Here is the sample testing result on WORD + +

+ +

+
+## Installation
+
+Please check [INSTALL.md](INSTALL.md) for installation instructions.
+
+## Evaluation
+
+Testing can be done with the following script. Please change `pretrained_dir` and `pretrained_model_name` according to the path of the checkpoint you would like to test, and change `data_dir` and `json_list` according to the dataset you use.
+
+```bash
+cd WORD
+python test_parrallel.py --pretrained_dir ./runs/multiview_101616/ \
+    --pretrained_model_name model.pt \
+    --distributed \
+    --data_dir ./dataset/dataset12_WORD/ \
+    --json_list dataset12_WORD.json
+```
+
+## Training
+
+Please check [TRAINING.md](TRAINING.md) for training instructions.
+
+## Acknowledgment
+
+This work is partially supported by the Google Cloud Research Credits program.
+This repo is based on [SwinUNETR](https://github.com/Project-MONAI/research-contributions/tree/main/SwinUNETR), [MONAI](https://monai.io/), and [bagua](https://github.com/BaguaSys/bagua).
+
+## Citation
+
+If you find this repository helpful, please consider citing:
+
+```
+@inproceedings{wang2023SwinMM,
+  title = {SwinMM: Masked Multi-view with Swin Transformers for 3D Medical Image Segmentation},
+  author = {Wang, Yiqing and Li, Zihan and Mei, Jieru and Wei, Zihao and Liu, Li and Wang, Chen and Sang, Shengtian and Yuille, Alan and Xie, Cihang and Zhou, Yuyin},
+  booktitle = {MICCAI},
+  year = {2023}
+}
+
+@article{cardoso2022monai,
+  title={MONAI: An open-source framework for deep learning in healthcare},
+  author={Cardoso, M Jorge and Li, Wenqi and Brown, Richard and Ma, Nic and Kerfoot, Eric and Wang, Yiheng and Murrey, Benjamin and Myronenko, Andriy and Zhao, Can and Yang, Dong and others},
+  journal={arXiv preprint arXiv:2211.02701},
+  year={2022}
+}
+```
diff --git a/SwinMM/TRAINING.md b/SwinMM/TRAINING.md
new file mode 100644
index 00000000..58d7e4ef
--- /dev/null
+++ b/SwinMM/TRAINING.md
@@ -0,0 +1,33 @@
+# Training
+
+## Launch Redis
+
+Launch the in-memory database (only needed once):
+
+```bash
+# It launches redis at ports 39996-39999.
+bash ./scripts/start_redis.sh
+```
+
+**NOTE**
+
+- If the **data or preprocessing** changed, run `pkill redis-server` before further experiments
+- Try `--workers` values from 8 to 32 for best performance
+- The first epoch after launching the server can be slow, but later epochs should be fast
+- Set `--redis_ports` according to your redis setup (see the snippet below)
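+
+For reference, the pre-training data loader simply turns each entry of `--redis_ports` into a local Redis host, so the ports you pass must match the redis-server instances you launched (excerpt from `Pretrain/utils/data_utils.py`):
+
+```python
+train_ds = dataset_in_memory.CachedDataset(
+    data=datalist,
+    transform=train_transforms,
+    dataset_name="pretrain_train",
+    hosts=[{"host": "localhost", "port": str(port)} for port in args.redis_ports],
+    cluster_mode=True,
+    capacity_per_node=200 * 1024 * 1024 * 1024,
+    writer_buffer_size=0,  # Disable write buffer
+)
+```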
+ +## Pre-training + +```bash +cd Pretrain +bash run.sh +``` + +## Finetuning + +- Prepare pretrained models: Copy the pretrained model to the `pretrained_models` directory in BTCV + +```bash +cd WORD +bash run.sh +``` diff --git a/SwinMM/WORD/README.md b/SwinMM/WORD/README.md new file mode 100644 index 00000000..2089dbb1 --- /dev/null +++ b/SwinMM/WORD/README.md @@ -0,0 +1,42 @@ +# Note for SwinMM Finetuning + +## Training + +### FIXME(outdated) + +```bash +python main.py +--feature_size=48 +--batch_size=1 +--logdir="swin_mm_test/" +--roi_x=64 +--roi_y=64 +--roi_z=64 +--optim_lr=1e-4 +--lrschedule="warmup_cosine" +--infer_overlap=0.5 +--save_checkpoint +--data_dir="/dataset/dataset0/" +--distributed +--use_ssl_pretrained +--pretrained_dir="./pretrained_models/" +--pretrained_model_name="model_bestValRMSE.pt" +``` + +## Testing + +### FIXME(outdated) + +```bash +python test.py +--feature_size=48 +--batch_size=1 +--exp_name="swin_mm_test/" +--roi_x=64 +--roi_y=64 +--roi_z=64 +--infer_overlap=0.5 +--data_dir="/dataset/dataset0/" +--pretrained_dir="./runs/multiview_101021/" +--pretrained_model_name="model.pt" +``` diff --git a/SwinMM/WORD/dataset/Download Finetune Jsons Here b/SwinMM/WORD/dataset/Download Finetune Jsons Here new file mode 100644 index 00000000..e69de29b diff --git a/SwinMM/WORD/dataset/__init__.py b/SwinMM/WORD/dataset/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/SwinMM/WORD/inferers.py b/SwinMM/WORD/inferers.py new file mode 100644 index 00000000..c943ca13 --- /dev/null +++ b/SwinMM/WORD/inferers.py @@ -0,0 +1,292 @@ +"""Multiview inferer.""" + +import warnings +from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Union + +import torch +import torch.nn.functional as F +from utils import view_ops, view_transforms + +from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size +from monai.inferers.utils import _get_scan_interval +from monai.transforms import Resize +from monai.utils import ( + BlendMode, + PytorchPadMode, + convert_data_type, + ensure_tuple, + fall_back_tuple, + look_up_option, + optional_import, +) + +tqdm, _ = optional_import("tqdm", name="tqdm") + + +def double_sliding_window_inference( + inputs: torch.Tensor, + view: int, + roi_size: Union[Sequence[int], int], + sw_batch_size: int, + predictor: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]]], + overlap: float = 0.25, + mode: Union[BlendMode, str] = BlendMode.CONSTANT, + sigma_scale: Union[Sequence[float], float] = 0.125, + padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT, + cval: float = 0.0, + sw_device: Union[torch.device, str, None] = None, + device: Union[torch.device, str, None] = None, + progress: bool = False, + roi_weight_map: Union[torch.Tensor, None] = None, + *args: Any, + **kwargs: Any, +) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]: + """ + Sliding window inference on two `inputs` with `predictor`. + + The outputs of `predictor` could be a tensor, a tuple, or a dictionary of tensors. + Each output in the tuple or dict value is allowed to have different resolutions with respect to the input. + e.g., the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes + could be ([128,64,256], [64,32,128]). + In this case, the parameter `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is still + an integer. 
If the predictor's input and output spatial sizes are not equal, we recommend choosing the parameters + so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension). + + When roi_size is larger than the inputs' spatial size, the input image are padded during inference. + To maintain the same spatial sizes, the output image will be cropped to the original input size. + + Args: + inputs: input image to be processed (assuming NCHW[D]) + roi_size: the spatial window size for inferences. + When its components have None or non-positives, the corresponding inputs dimension will be used. + if the components of the `roi_size` are non-positive values, the transform will use the + corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + sw_batch_size: the batch size to run window slices. + predictor: given input tensor ``patch_data`` in shape NCHW[D], + The outputs of the function call ``predictor(patch_data)`` should be a tensor, a tuple, or a dictionary + with Tensor values. Each output in the tuple or dict value should have the same batch_size, i.e. NM'H'W'[D']; + where H'W'[D'] represents the output patch's spatial size, M is the number of output channels, + N is `sw_batch_size`, e.g., the input shape is (7, 1, 128,128,128), + the output could be a tuple of two tensors, with shapes: ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)). + In this case, the parameter `overlap` and `roi_size` need to be carefully chosen + to ensure the scaled output ROI sizes are still integers. + If the `predictor`'s input and output spatial sizes are different, + we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension. + overlap: Amount of overlap between scans. + mode: {``"constant"``, ``"gaussian"``} + How to blend output of overlapping windows. Defaults to ``"constant"``. + + - ``"constant``": gives equal weight to all predictions. + - ``"gaussian``": gives less weight to predictions on edges of windows. + + sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``. + Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``. + When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding + spatial dimensions. + padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``} + Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"`` + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + cval: fill value for 'constant' padding mode. Default: 0 + sw_device: device for the window data. + By default the device (and accordingly the memory) of the `inputs` is used. + Normally `sw_device` should be consistent with the device where `predictor` is defined. + device: device for the stitched output prediction. + By default the device (and accordingly the memory) of the `inputs` is used. If for example + set to device=torch.device('cpu') the gpu memory consumption is less and independent of the + `inputs` and `roi_size`. Output is on the `device`. + progress: whether to print a `tqdm` progress bar. + roi_weight_map: pre-computed (non-negative) weight map for each ROI. + If not given, and ``mode`` is not `constant`, this map will be computed on the fly. + args: optional args to be passed to ``predictor``. + kwargs: optional keyword args to be passed to ``predictor``. 
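+        view: index of the first permutation view applied to each window; the second view is
+            chosen internally as ``(view + 1) % len(view_transforms.permutation_transforms)``.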
+ + Note: + - input must be channel-first and have a batch dim, supports N-D sliding window. + + """ + compute_dtype = inputs.dtype + num_spatial_dims = len(inputs.shape) - 2 + if overlap < 0 or overlap >= 1: + raise ValueError("overlap must be >= 0 and < 1.") + + # determine image spatial size and batch size + # Note: all input images must have the same image size and batch size + batch_size, _, *image_size_ = inputs.shape + + if device is None: + device = inputs.device + if sw_device is None: + sw_device = inputs.device + + roi_size = fall_back_tuple(roi_size, image_size_) + # in case that image size is smaller than roi size + image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims)) + pad_size = [] + for k in range(len(inputs.shape) - 1, 1, -1): + diff = max(roi_size[k - 2] - inputs.shape[k], 0) + half = diff // 2 + pad_size.extend([half, diff - half]) + inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode).value, value=cval) + # inputs2 = F.pad(inputs2, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode).value, value=cval) + + scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap) + + # Store all slices in list + slices = dense_patch_slices(image_size, roi_size, scan_interval) + num_win = len(slices) # number of windows per image + total_slices = num_win * batch_size # total number of windows + + # Create window-level importance map + valid_patch_size = get_valid_patch_size(image_size, roi_size) + if valid_patch_size == roi_size and (roi_weight_map is not None): + importance_map = roi_weight_map + else: + try: + importance_map = compute_importance_map(valid_patch_size, mode=mode, sigma_scale=sigma_scale, device=device) + except BaseException as e: + raise RuntimeError( + "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'." + ) from e + importance_map = convert_data_type(importance_map, torch.Tensor, device, compute_dtype)[0] # type: ignore + # handle non-positive weights + min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3) + importance_map = torch.clamp(importance_map.to(torch.float32), min=min_non_zero).to(compute_dtype) + + # Perform predictions + dict_key, output_image_list_1, output_image_list_2, count_map_list = None, [], [], [] + _initialized_ss = -1 + is_tensor_output = True # whether the predictor's output is a tensor (instead of dict/tuple) + + # for each patch + for slice_g in tqdm(range(0, total_slices, sw_batch_size)) if progress else range(0, total_slices, sw_batch_size): + slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices)) + unravel_slice = [ + [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + list(slices[idx % num_win]) + for idx in slice_range + ] + window_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device) + view_list = [view, (view + 1) % len(view_transforms.permutation_transforms)] + window_data_list = [view_ops.get_permute_transform(0, dst)(window_data) for dst in view_list] + # window_data_2 = torch.cat([inputs2[win_slice] for win_slice in unravel_slice]).to(sw_device) + seg_prob_out_1, seg_prob_out_2 = predictor( + window_data_list[0], window_data_list[1], view_list, *args, **kwargs + ) # batched patch segmentation + seg_prob_out_1, seg_prob_out_2 = view_ops.permute_inverse([seg_prob_out_1, seg_prob_out_2], view_list) + + # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory. 
+ seg_prob_tuple_1: Tuple[torch.Tensor, ...] + seg_prob_tuple_2: Tuple[torch.Tensor, ...] + if isinstance(seg_prob_out_1, torch.Tensor): + seg_prob_tuple_1 = (seg_prob_out_1,) + seg_prob_tuple_2 = (seg_prob_out_2,) + elif isinstance(seg_prob_out_1, Mapping): + if dict_key is None: + dict_key = sorted(seg_prob_out_1.keys()) # track predictor's output keys + seg_prob_tuple_1 = tuple(seg_prob_out_1[k] for k in dict_key) + seg_prob_tuple_2 = tuple(seg_prob_out_2[k] for k in dict_key) + is_tensor_output = False + else: + seg_prob_tuple_1 = ensure_tuple(seg_prob_out_1) + seg_prob_tuple_2 = ensure_tuple(seg_prob_out_2) + is_tensor_output = False + + # for each output in multi-output list + for ss in range(len(seg_prob_tuple_1)): + seg_prob_1 = seg_prob_tuple_1[ss].to(device) # BxCxMxNxP or BxCxMxN + seg_prob_2 = seg_prob_tuple_2[ss].to(device) + + # compute zoom scale: out_roi_size/in_roi_size + zoom_scale = [] + for axis, (img_s_i, out_w_i, in_w_i) in enumerate( + zip(image_size, seg_prob_1.shape[2:], window_data.shape[2:]) + ): + _scale = out_w_i / float(in_w_i) + if not (img_s_i * _scale).is_integer(): + warnings.warn( + f"For spatial axis: {axis}, output[{ss}] will have non-integer shape. Spatial " + f"zoom_scale between output[{ss}] and input is {_scale}. Please pad inputs." + ) + zoom_scale.append(_scale) + + if _initialized_ss < ss: # init. the ss-th buffer at the first iteration + # construct multi-resolution outputs + output_classes = seg_prob_1.shape[1] + output_shape = [batch_size, output_classes] + [ + int(image_size_d * zoom_scale_d) for image_size_d, zoom_scale_d in zip(image_size, zoom_scale) + ] + # allocate memory to store the full output and the count for overlapping parts + output_image_list_1.append(torch.zeros(output_shape, dtype=compute_dtype, device=device)) + output_image_list_2.append(torch.zeros(output_shape, dtype=compute_dtype, device=device)) + count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device)) + _initialized_ss += 1 + + # resizing the importance_map + resizer = Resize(spatial_size=seg_prob_1.shape[2:], mode="nearest", anti_aliasing=False) + + # store the result in the proper location of the full output. Apply weights from importance map. + for idx, original_idx in zip(slice_range, unravel_slice): + # zoom roi + original_idx_zoom = list(original_idx) # 4D for 2D image, 5D for 3D image + for axis in range(2, len(original_idx_zoom)): + zoomed_start = original_idx[axis].start * zoom_scale[axis - 2] + zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2] + if not zoomed_start.is_integer() or (not zoomed_end.is_integer()): + warnings.warn( + f"For axis-{axis-2} of output[{ss}], the output roi range is not int. " + f"Input roi range is ({original_idx[axis].start}, {original_idx[axis].stop}). " + f"Spatial zoom_scale between output[{ss}] and input is {zoom_scale[axis - 2]}. " + f"Corresponding output roi range is ({zoomed_start}, {zoomed_end}).\n" + f"Please change overlap ({overlap}) or roi_size ({roi_size[axis-2]}) for axis-{axis-2}. " + "Tips: if overlap*roi_size*zoom_scale is an integer, it usually works." 
+ ) + original_idx_zoom[axis] = slice(int(zoomed_start), int(zoomed_end), None) + importance_map_zoom = resizer(importance_map.unsqueeze(0))[0].to(compute_dtype) + # store results and weights + output_image_list_1[ss][original_idx_zoom] += importance_map_zoom * seg_prob_1[idx - slice_g] + output_image_list_2[ss][original_idx_zoom] += importance_map_zoom * seg_prob_2[idx - slice_g] + count_map_list[ss][original_idx_zoom] += ( + importance_map_zoom.unsqueeze(0).unsqueeze(0).expand(count_map_list[ss][original_idx_zoom].shape) + ) + + # account for any overlapping sections + for ss in range(len(output_image_list_1)): + count_map_pop = count_map_list.pop(0) + output_image_list_1[ss] = (output_image_list_1[ss] / count_map_pop).to(compute_dtype) + output_image_list_2[ss] = (output_image_list_2[ss] / count_map_pop).to(compute_dtype) + + # remove padding if image_size smaller than roi_size + for ss in range(len(output_image_list_1)): + output_i_1, output_i_2 = output_image_list_1[ss], output_image_list_2[ss] + if torch.isnan(output_i_1).any() or torch.isinf(output_i_1).any(): + warnings.warn("Sliding window inference results contain NaN or Inf.") + if torch.isnan(output_i_2).any() or torch.isinf(output_i_2).any(): + warnings.warn("Sliding window inference results contain NaN or Inf.") + + zoom_scale = [ + seg_prob_map_shape_d / roi_size_d + for seg_prob_map_shape_d, roi_size_d in zip(output_i_1.shape[2:], roi_size) + ] + + final_slicing: List[slice] = [] + for sp in range(num_spatial_dims): + slice_dim = slice(pad_size[sp * 2], image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2]) + slice_dim = slice( + int(round(slice_dim.start * zoom_scale[num_spatial_dims - sp - 1])), + int(round(slice_dim.stop * zoom_scale[num_spatial_dims - sp - 1])), + ) + final_slicing.insert(0, slice_dim) + while len(final_slicing) < len(output_i_1.shape): + final_slicing.insert(0, slice(None)) + output_image_list_1[ss] = output_i_1[final_slicing] + output_image_list_2[ss] = output_i_2[final_slicing] + + if dict_key is not None: # if output of predictor is a dict + final_output_1 = dict(zip(dict_key, output_image_list_1)) + final_output_2 = dict(zip(dict_key, output_image_list_2)) + else: + final_output_1 = tuple(output_image_list_1) # type: ignore + final_output_2 = tuple(output_image_list_2) # type: ignore + final_output_1 = final_output_1[0] if is_tensor_output else final_output_1 # type: ignore + final_output_2 = final_output_2[0] if is_tensor_output else final_output_2 # type: ignore + return final_output_1, final_output_2 diff --git a/SwinMM/WORD/main.py b/SwinMM/WORD/main.py new file mode 100644 index 00000000..6d5872b4 --- /dev/null +++ b/SwinMM/WORD/main.py @@ -0,0 +1,292 @@ +# Copyright 2020 - 2022 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
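+
+# Fine-tuning entry point for SwinMM on WORD: builds the two-view SwinUNETR,
+# combines a Dice-CE segmentation loss per view with a KL mutual loss between
+# the two views (see trainer.run_training), and supports optional distributed
+# and unsupervised training phases.
+#
+# Example launch (sketch; assumes 8 GPUs and an SSL checkpoint at
+# ./pretrained_models/model_bestValRMSE.pt):
+#   python main.py --distributed --batch_size=2 --use_ssl_pretrained \
+#       --save_checkpoint --logdir=multiview_101616/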
+ +import argparse +import logging +import os +from functools import partial + +import numpy as np +import timm.optim.optim_factory as optim_factory +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn.parallel +import torch.utils.data.distributed +from inferers import double_sliding_window_inference +from models import SwinUNETR +from optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR +from timm.utils import setup_default_logging +from torch.nn import KLDivLoss +from trainer import run_training +from utils.data_utils import get_loader +from utils.dataset_in_memory import hijack_bagua_serialization + +from monai.losses import DiceCELoss +from monai.metrics import DiceMetric +from monai.transforms import AsDiscrete +from monai.utils.enums import MetricReduction + +parser = argparse.ArgumentParser(description="Swin UNETR segmentation pipeline") +parser.add_argument("--checkpoint", default=None, help="start training from saved checkpoint") +parser.add_argument("--logdir", default="multiview_101616/", type=str, help="directory to save the tensorboard logs") +parser.add_argument( + "--pretrained_dir", default="./pretrained_models/", type=str, help="pretrained checkpoint directory" +) +parser.add_argument("--data_dir", default="./dataset/dataset12_WORD/", type=str, help="dataset directory") +parser.add_argument("--json_list", default="dataset12_WORD.json", type=str, help="dataset json file") +parser.add_argument("--pretrained_model_name", default="model_bestValRMSE.pt", type=str, help="pretrained model name") +parser.add_argument("--save_checkpoint", action="store_true", help="save checkpoint during training") +parser.add_argument("--max_epochs", default=1500, type=int, help="max number of training epochs") +parser.add_argument("--batch_size", default=1, type=int, help="number of batch size") +parser.add_argument("--sw_batch_size", default=16, type=int, help="number of sliding window batch size") +parser.add_argument("--optim_lr", default=1e-4, type=float, help="optimization learning rate") +parser.add_argument("--optim_name", default="adamw", type=str, help="optimization algorithm") +parser.add_argument("--reg_weight", default=1e-5, type=float, help="regularization weight") +parser.add_argument("--layer_decay", default=1.0, type=float, help="layer-wise learning rate decay") +parser.add_argument("--momentum", default=0.99, type=float, help="momentum") +parser.add_argument("--noamp", action="store_true", help="do NOT use amp for training") +parser.add_argument("--val_every", default=100, type=int, help="validation frequency") +parser.add_argument("--val_start", default=1000, type=int, help="val start from epoch") +parser.add_argument("--unsuper_every", default=1, type=int, help="unsupervised training frequency") +parser.add_argument("--unsuper_start", default=100, type=int, help="unsupervised training frequency") +parser.add_argument("--unsupervised", action="store_true", help="start unsupervised training") +parser.add_argument("--distributed", action="store_true", help="start distributed training") +parser.add_argument("--world_size", default=1, type=int, help="number of nodes for distributed training") +parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training") +parser.add_argument("--dist-url", default="tcp://127.0.0.1:23456", type=str, help="distributed url") +parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend") +parser.add_argument("--workers", default=4, 
type=int, help="number of workers") +parser.add_argument("--feature_size", default=48, type=int, help="feature size") +parser.add_argument("--in_channels", default=1, type=int, help="number of input channels") +parser.add_argument("--out_channels", default=17, type=int, help="number of output channels") +parser.add_argument("--use_normal_dataset", action="store_true", help="use monai Dataset class") +parser.add_argument("--use_normal_dataset_val", action="store_true", help="use monai Dataset class for val") +parser.add_argument( + "--nouse_multi_epochs_loader", + action="store_true", + help="not use the multi-epochs-loader to save time at the beginning of every epoch", +) +parser.add_argument("--a_min", default=-175.0, type=float, help="a_min in ScaleIntensityRanged") +parser.add_argument("--a_max", default=250.0, type=float, help="a_max in ScaleIntensityRanged") +parser.add_argument("--b_min", default=0.0, type=float, help="b_min in ScaleIntensityRanged") +parser.add_argument("--b_max", default=1.0, type=float, help="b_max in ScaleIntensityRanged") +parser.add_argument("--space_x", default=1.5, type=float, help="spacing in x direction") +parser.add_argument("--space_y", default=1.5, type=float, help="spacing in y direction") +parser.add_argument("--space_z", default=2.0, type=float, help="spacing in z direction") +parser.add_argument("--roi_x", default=64, type=int, help="roi size in x direction") +parser.add_argument("--roi_y", default=64, type=int, help="roi size in y direction") +parser.add_argument("--roi_z", default=64, type=int, help="roi size in z direction") +parser.add_argument("--dropout_rate", default=0.0, type=float, help="dropout rate") +parser.add_argument("--dropout_path_rate", default=0.0, type=float, help="drop path rate") +parser.add_argument("--RandFlipd_prob", default=0.2, type=float, help="RandFlipd aug probability") +parser.add_argument("--RandRotate90d_prob", default=0.2, type=float, help="RandRotate90d aug probability") +parser.add_argument("--RandScaleIntensityd_prob", default=0.1, type=float, help="RandScaleIntensityd aug probability") +parser.add_argument("--RandShiftIntensityd_prob", default=0.1, type=float, help="RandShiftIntensityd aug probability") +parser.add_argument("--infer_overlap", default=0.5, type=float, help="sliding window inference overlap") +parser.add_argument("--lrschedule", default="warmup_cosine", type=str, help="type of learning rate scheduler") +parser.add_argument("--warmup_epochs", default=50, type=int, help="number of warmup epochs") +parser.add_argument("--resume_ckpt", action="store_true", help="resume training from pretrained checkpoint") +parser.add_argument("--smooth_dr", default=1e-6, type=float, help="constant added to dice denominator to avoid nan") +parser.add_argument("--smooth_nr", default=0.0, type=float, help="constant added to dice numerator to avoid zero") +parser.add_argument("--use_checkpoint", action="store_true", help="use gradient checkpointing to save memory") +parser.add_argument("--use_ssl_pretrained", action="store_true", help="use self-supervised pretrained weights") +parser.add_argument("--spatial_dims", default=3, type=int, help="spatial dimension of input data") +parser.add_argument("--squared_dice", action="store_true", help="use squared Dice") +parser.add_argument("--norm_name", default="batch", help="multi gpu use") +parser.add_argument( + "--cross_attention_in_origin_view", action="store_true", help="Whether compute cross attention in original view" +) +parser.add_argument("--redis_ports", nargs="+", 
type=int, help="redis ports") +parser.add_argument("--redis_compression", type=str, default=None, help="compression method for redis.") + + +def main(): + args = parser.parse_args() + args.amp = not args.noamp + args.logdir = "./runs/" + args.logdir + if args.distributed: + args.ngpus_per_node = torch.cuda.device_count() + print("Found total gpus", args.ngpus_per_node) + args.world_size = args.ngpus_per_node * args.world_size + mp.spawn(main_worker, nprocs=args.ngpus_per_node, args=(args,)) + else: + main_worker(gpu=0, args=args) + + +def main_worker(gpu, args): + if args.redis_compression is not None: + hijack_bagua_serialization(args.redis_compression) + + if args.distributed: + torch.multiprocessing.set_start_method("fork", force=True) + np.set_printoptions(formatter={"float": "{: 0.3f}".format}, suppress=True) + args.gpu = gpu + if args.distributed: + args.rank = args.rank * args.ngpus_per_node + gpu + dist.init_process_group( + backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank + ) + torch.cuda.set_device(args.gpu) + torch.backends.cudnn.benchmark = True + args.test_mode = False + loader = get_loader(args) + print(args.rank, " gpu", args.gpu) + if args.rank == 0: + setup_default_logging() + logging.info(f"Batch size is: {args.batch_size}, epochs: {args.max_epochs}") + inf_size = [args.roi_x, args.roi_y, args.roi_z] + + pretrained_dir = args.pretrained_dir + model = SwinUNETR( + img_size=(args.roi_x, args.roi_y, args.roi_z), + in_channels=args.in_channels, + out_channels=args.out_channels, + feature_size=args.feature_size, + fusion_depths=(1, 1, 1, 1, 1, 1), + drop_rate=0.0, + attn_drop_rate=0.0, + dropout_path_rate=args.dropout_path_rate, + use_checkpoint=args.use_checkpoint, + cross_attention_in_origin_view=args.cross_attention_in_origin_view, + ) + + if args.resume_ckpt: + model_dict = torch.load(os.path.join(pretrained_dir, args.pretrained_model_name), map_location="cpu")[ + "state_dict" + ] + model.load_state_dict(model_dict) + logging.info("Use pretrained weights") + + if args.use_ssl_pretrained: + try: + model_dict = torch.load(os.path.join(pretrained_dir, args.pretrained_model_name), map_location="cpu") + state_dict = model_dict["state_dict"] + # fix potential differences in state dict keys from pre-training to + # fine-tuning + if "module." in list(state_dict.keys())[0]: + logging.info("Tag 'module.' found in state dict - fixing!") + for key in list(state_dict.keys()): + state_dict[key.replace("module.", "")] = state_dict.pop(key) + if "swin_vit" in list(state_dict.keys())[0]: + logging.info("Tag 'swin_vit' found in state dict - fixing!") + for key in list(state_dict.keys()): + state_dict[key.replace("swin_vit", "swinViT")] = state_dict.pop(key) + # We now load model weights, setting param `strict` to False, i.e.: + # this load the encoder weights (Swin-ViT, SSL pre-trained), but leaves + # the decoder weights untouched (CNN UNet decoder). 
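+            # strict=False makes the load tolerant in both directions: checkpoint
+            # keys with no counterpart in the fine-tuning model (e.g. pre-training
+            # heads) are ignored, and decoder parameters missing from the
+            # checkpoint keep their random initialization.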
+ model.load_state_dict(state_dict, strict=False) + logging.info("Using pretrained self-supervised Swin UNETR backbone weights !") + except ValueError: + raise ValueError("Self-supervised pre-trained weights not available for" + str(args.model_name)) + + if args.squared_dice: + dice_loss = DiceCELoss( + to_onehot_y=True, softmax=True, squared_pred=True, smooth_nr=args.smooth_nr, smooth_dr=args.smooth_dr + ) + else: + dice_loss = DiceCELoss(to_onehot_y=True, softmax=True) + mutual_loss = KLDivLoss(reduction="mean") # CosineSimilarity(dim = 1) + post_label = AsDiscrete(to_onehot=True, n_classes=args.out_channels) + post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=args.out_channels) + dice_acc = DiceMetric(include_background=True, reduction=MetricReduction.MEAN, get_not_nans=True) + model_inferer = partial( + double_sliding_window_inference, + roi_size=inf_size, + sw_batch_size=args.sw_batch_size, + predictor=model, + overlap=args.infer_overlap, + ) + + pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + logging.info(f"Total parameters count: {pytorch_total_params}") + + best_acc = 0 + start_epoch = 0 + + if args.checkpoint is not None: + checkpoint = torch.load(args.checkpoint, map_location="cpu") + from collections import OrderedDict + + new_state_dict = OrderedDict() + for k, v in checkpoint["state_dict"].items(): + new_state_dict[k.replace("backbone.", "")] = v + model.load_state_dict(new_state_dict, strict=False) + if "epoch" in checkpoint: + start_epoch = checkpoint["epoch"] + 1 + if "best_acc" in checkpoint: + best_acc = checkpoint["best_acc"] + logging.info("=> loaded checkpoint '{}' (epoch {}) (bestacc {})".format(args.checkpoint, start_epoch, best_acc)) + + model.cuda(args.gpu) + model_without_ddp = model + if args.distributed: + torch.cuda.set_device(args.gpu) + if args.norm_name == "batch": + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + model.cuda(args.gpu) + model_without_ddp = model + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[args.gpu], output_device=args.gpu, broadcast_buffers=False, find_unused_parameters=True + ) + + # build optimizer with layer-wise lr decay (lrd) + param_groups = optim_factory.param_groups_layer_decay( + model_without_ddp, + weight_decay=args.reg_weight, + no_weight_decay_list=model_without_ddp.no_weight_decay(), + layer_decay=args.layer_decay, + verbose=False, + ) + if args.optim_name == "adam": + optimizer = torch.optim.Adam(param_groups, lr=args.optim_lr) + elif args.optim_name == "adamw": + optimizer = torch.optim.AdamW(param_groups, lr=args.optim_lr) + elif args.optim_name == "sgd": + optimizer = torch.optim.SGD(param_groups, lr=args.optim_lr, momentum=args.momentum, nesterov=True) + else: + raise ValueError("Unsupported Optimization Procedure: " + str(args.optim_name)) + + if args.lrschedule == "warmup_cosine": + scheduler = LinearWarmupCosineAnnealingLR( + optimizer, warmup_epochs=args.warmup_epochs, max_epochs=args.max_epochs + ) + elif args.lrschedule == "cosine_anneal": + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epochs) + if args.checkpoint is not None: + scheduler.step(epoch=start_epoch) + else: + scheduler = None + + unsupervised_loader = None + if args.unsupervised: + unsupervised_loader = loader[2] + accuracy = run_training( + model=model, + train_loader=loader[0], + val_loader=loader[1], + unsupervised_loader=unsupervised_loader, + optimizer=optimizer, + self_crit=dice_loss, + mutual_crit=mutual_loss, + 
acc_func=dice_acc, + args=args, + model_inferer=model_inferer, + scheduler=scheduler, + start_epoch=start_epoch, + post_label=post_label, + post_pred=post_pred, + ) + return accuracy + + +if __name__ == "__main__": + main() diff --git a/SwinMM/WORD/models/__init__.py b/SwinMM/WORD/models/__init__.py new file mode 100644 index 00000000..a639a968 --- /dev/null +++ b/SwinMM/WORD/models/__init__.py @@ -0,0 +1 @@ +from .swin_unetr import * diff --git a/SwinMM/WORD/models/cross_attention.py b/SwinMM/WORD/models/cross_attention.py new file mode 100644 index 00000000..61ca832f --- /dev/null +++ b/SwinMM/WORD/models/cross_attention.py @@ -0,0 +1,161 @@ +"""TransFusion from TransFusion: Multi-view Divergent Fusion for Medical Image Segmentation with Transformers.""" +import copy +import math +from typing import Sequence, Union + +import numpy as np +import torch +import torch.nn as nn +from utils.view_ops import get_permute_transform, permute_inverse + + +class Attention(nn.Module): + def __init__(self, num_heads=8, hidden_size=768, atte_dropout_rate=0.0): + super(Attention, self).__init__() + # self.vis = vis + self.num_attention_heads = num_heads + self.attention_head_size = int(hidden_size / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(hidden_size, self.all_head_size) + self.key = nn.Linear(hidden_size, self.all_head_size) + self.value = nn.Linear(hidden_size, self.all_head_size) + + self.out = nn.Linear(hidden_size, hidden_size) + self.attn_dropout = nn.Dropout(atte_dropout_rate) + self.proj_dropout = nn.Dropout(atte_dropout_rate) + + self.softmax = nn.Softmax(dim=-1) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, x_1, x_2): + mixed_query_layer_1 = self.query(x_1) + mixed_key_layer_1 = self.key(x_1) + mixed_value_layer_1 = self.value(x_1) + query_layer_1 = self.transpose_for_scores(mixed_query_layer_1) + key_layer_1 = self.transpose_for_scores(mixed_key_layer_1) + value_layer_1 = self.transpose_for_scores(mixed_value_layer_1) + mixed_query_layer_2 = self.query(x_2) + mixed_key_layer_2 = self.key(x_2) + mixed_value_layer_2 = self.value(x_2) + query_layer_2 = self.transpose_for_scores(mixed_query_layer_2) + key_layer_2 = self.transpose_for_scores(mixed_key_layer_2) + value_layer_2 = self.transpose_for_scores(mixed_value_layer_2) + + attention_scores_1 = torch.matmul(query_layer_1, key_layer_2.transpose(-1, -2)) + attention_scores_1 = attention_scores_1 / math.sqrt(self.attention_head_size) + attention_probs_1 = self.softmax(attention_scores_1) + # weights_st = attention_probs_st if self.vis else None + attention_probs_1 = self.attn_dropout(attention_probs_1) + context_layer_1 = torch.matmul(attention_probs_1, value_layer_2) + context_layer_1 = context_layer_1.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape_1 = context_layer_1.size()[:-2] + (self.all_head_size,) + context_layer_1 = context_layer_1.view(*new_context_layer_shape_1) + attention_output_1 = self.out(context_layer_1) + attention_output_1 = self.proj_dropout(attention_output_1) + + attention_scores_2 = torch.matmul(query_layer_2, key_layer_1.transpose(-1, -2)) + attention_scores_2 = attention_scores_2 / math.sqrt(self.attention_head_size) + attention_probs_2 = self.softmax(attention_scores_2) + # weights_st = attention_probs_st if self.vis else None + attention_probs_2 = 
self.attn_dropout(attention_probs_2) + context_layer_2 = torch.matmul(attention_probs_2, value_layer_1) + context_layer_2 = context_layer_2.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape_2 = context_layer_2.size()[:-2] + (self.all_head_size,) + context_layer_2 = context_layer_2.view(*new_context_layer_shape_2) + attention_output_2 = self.out(context_layer_2) + attention_output_2 = self.proj_dropout(attention_output_2) + + return attention_output_1, attention_output_2 + + +class Block(nn.Module): + def __init__(self, hidden_size=768, mlp_dim=1536, dropout_rate=0.5, num_heads=8, atte_dropout_rate=0.0): + super(Block, self).__init__() + + del mlp_dim + del dropout_rate + + self.hidden_size = hidden_size + self.attention_norm = nn.LayerNorm(hidden_size, eps=1e-6) + self.attn = Attention(num_heads=num_heads, hidden_size=hidden_size, atte_dropout_rate=atte_dropout_rate) + + def forward(self, x_1, x_2): + x_1 = self.attention_norm(x_1) + x_2 = self.attention_norm(x_2) + x_1, x_2 = self.attn(x_1, x_2) + return x_1, x_2 + + +class TransFusion(nn.Module): + def __init__( + self, + hidden_size: int = 768, + num_layers: int = 6, + mlp_dim: int = 1536, + dropout_rate: float = 0.5, + num_heads: int = 8, + atte_dropout_rate: float = 0.0, + roi_size: Union[Sequence[int], int] = (64, 64, 64), + scale: int = 16, + cross_attention_in_origin_view: bool = False, + ): + super().__init__() + if isinstance(roi_size, int): + roi_size = [roi_size for _ in range(3)] + self.cross_attention_in_origin_view = cross_attention_in_origin_view + patch_size = (1, 1, 1) + n_patches = ( + (roi_size[0] // patch_size[0] // scale) + * (roi_size[1] // patch_size[1] // scale) + * (roi_size[2] // patch_size[2] // scale) + ) + self.layer = nn.ModuleList() + self.encoder_norm = nn.LayerNorm(hidden_size, eps=1e-6) + self.patch_embeddings = nn.Conv3d( + in_channels=hidden_size, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size + ) + self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, hidden_size)) + self.dropout = nn.Dropout(dropout_rate) + for _ in range(num_layers): + layer = Block( + hidden_size=hidden_size, + mlp_dim=mlp_dim, + dropout_rate=dropout_rate, + num_heads=num_heads, + atte_dropout_rate=atte_dropout_rate, + ) + self.layer.append(copy.deepcopy(layer)) + + def forward(self, x_1, x_2, view_list): + if self.cross_attention_in_origin_view: + x_1, x_2 = permute_inverse([x_1, x_2], view_list) + else: + # Align x_2 to x_1. 
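+            # x_2 arrives in view_list[1]; permute it into x_1's view (view_list[0])
+            # so cross attention runs on spatially aligned tokens. The forward
+            # permutation at the end of this method moves x_2 back to its own view.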
+ x_2 = get_permute_transform(*view_list[::-1])(x_2) + x_1 = self.patch_embeddings(x_1) + x_2 = self.patch_embeddings(x_2) + x_1 = x_1.flatten(2).transpose(-1, -2) + x_2 = x_2.flatten(2).transpose(-1, -2) + x_1 = x_1 + self.position_embeddings + x_2 = x_2 + self.position_embeddings + x_1 = self.dropout(x_1) + x_2 = self.dropout(x_2) + for layer_block in self.layer: + x_1, x_2 = layer_block(x_1, x_2) + x_1 = self.encoder_norm(x_1) + x_2 = self.encoder_norm(x_2) + B, n_patch, hidden = x_1.size() # reshape from (B, n_patch, hidden) to (B, h, w, hidden) + l, h, w = int(np.cbrt(n_patch)), int(np.cbrt(n_patch)), int(np.cbrt(n_patch)) + x_1 = x_1.permute(0, 2, 1).contiguous().view(B, hidden, l, h, w) + x_2 = x_2.permute(0, 2, 1).contiguous().view(B, hidden, l, h, w) + if self.cross_attention_in_origin_view: + x_1, x_2 = permute_inverse([x_1, x_2], view_list) + else: + x_2 = get_permute_transform(*view_list)(x_2) + + return x_1, x_2 diff --git a/SwinMM/WORD/models/swin_unetr.py b/SwinMM/WORD/models/swin_unetr.py new file mode 100644 index 00000000..ea7f96fa --- /dev/null +++ b/SwinMM/WORD/models/swin_unetr.py @@ -0,0 +1,128 @@ +"""SwinUNETR with cross attention.""" +from typing import MutableMapping, Sequence, Tuple, Union + +import torch +from models import cross_attention + +from monai.networks import blocks +from monai.networks.nets import swin_unetr + +__all__ = ["SwinUNETR"] + +FeaturesDictType = MutableMapping[str, torch.Tensor] + + +class SwinUNETR(swin_unetr.SwinUNETR): + """SwinUNETR with cross attention.""" + + def __init__( + self, + img_size: Union[Sequence[int], int], + *args, + num_heads: Sequence[int] = (3, 6, 12, 24), + feature_size: int = 24, + norm_name: Union[Tuple, str] = "instance", + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + spatial_dims: int = 3, + fusion_depths: Sequence[int] = (2, 2, 2, 2, 2, 2), + cross_attention_in_origin_view: bool = False, + **kwargs, + ) -> None: + """ + Args: + fusion_depths: TODO(yiqing). + cross_attention_in_origin_view: A bool indicates whether compute cross attention in origin view. + If not, compute cross attention in the view of the first input. 
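+
+        Note:
+            Only the last entry of ``fusion_depths`` is consumed by this class: it
+            sets the number of TransFusion layers applied to the bottleneck
+            features in ``cross_atte6``.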
+ + """ + super().__init__( + img_size, + *args, + num_heads=num_heads, + feature_size=feature_size, + norm_name=norm_name, + spatial_dims=spatial_dims, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + **kwargs, + ) + + self.encoder5 = blocks.UnetrBasicBlock( + spatial_dims=spatial_dims, + in_channels=8 * feature_size, + out_channels=8 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.cross_atte6 = cross_attention.TransFusion( + hidden_size=feature_size * 16, + num_layers=fusion_depths[5], + mlp_dim=feature_size * 32, + num_heads=num_heads[3], + dropout_rate=drop_rate, + atte_dropout_rate=attn_drop_rate, + roi_size=img_size, + scale=32, + cross_attention_in_origin_view=cross_attention_in_origin_view, + ) + + def forward_view_encoder(self, x): + """Encode features.""" + x_hiddens = self.swinViT(x, self.normalize) + x_enc0 = self.encoder1(x) + x_enc1 = self.encoder2(x_hiddens[0]) + x_enc2 = self.encoder3(x_hiddens[1]) + x_enc3 = self.encoder4(x_hiddens[2]) + x_enc4 = self.encoder5(x_hiddens[3]) # xa_hidden[3] + x_dec4 = self.encoder10(x_hiddens[4]) + return {"enc0": x_enc0, "enc1": x_enc1, "enc2": x_enc2, "enc3": x_enc3, "enc4": x_enc4, "dec4": x_dec4} + + def forward_view_decoder(self, x_encoded: FeaturesDictType) -> torch.Tensor: + """Decode features.""" + x_dec3 = self.decoder5(x_encoded["dec4"], x_encoded["enc4"]) + x_dec2 = self.decoder4(x_dec3, x_encoded["enc3"]) + x_dec1 = self.decoder3(x_dec2, x_encoded["enc2"]) + x_dec0 = self.decoder2(x_dec1, x_encoded["enc1"]) + x_out = self.decoder1(x_dec0, x_encoded["enc0"]) + x_logits = self.out(x_out) + return x_logits + + def forward_view_cross_attention( + self, xa_encoded: FeaturesDictType, xb_encoded: FeaturesDictType, views: Sequence[int] + ) -> Tuple[FeaturesDictType, FeaturesDictType]: + """Inplace cross attention between views.""" + xa_encoded["dec4"], xb_encoded["dec4"] = self.cross_atte6(xa_encoded["dec4"], xb_encoded["dec4"], views) + return xa_encoded, xb_encoded + + def forward(self, xa: torch.Tensor, xb: torch.Tensor, views: Sequence[int]) -> Sequence[torch.Tensor]: + """Two views forward.""" + xa_encoded = self.forward_view_encoder(xa) + xb_encoded = self.forward_view_encoder(xb) + + xa_encoded, xb_encoded = self.forward_view_cross_attention(xa_encoded, xb_encoded, views) + return [self.forward_view_decoder(val) for val in [xa_encoded, xb_encoded]] + + def no_weight_decay(self): + """Disable weight_decay on specific weights.""" + nwd = {"swinViT.absolute_pos_embed"} + for n, _ in self.named_parameters(): + if "relative_position_bias_table" in n: + nwd.add(n) + return nwd + + def group_matcher(self, coarse=False): + """Layer counting helper, used by timm.""" + return dict( + stem=r"^swinViT\.absolute_pos_embed|patch_embed", # stem and embed + blocks=r"^swinViT\.layers(\d+)\.0" + if coarse + else [ + (r"^swinViT\.layers(\d+)\.0.downsample", (0,)), + (r"^swinViT\.layers(\d+)\.0\.\w+\.(\d+)", None), + (r"^swinViT\.norm", (99999,)), + ], + ) diff --git a/SwinMM/WORD/optimizers/__init__.py b/SwinMM/WORD/optimizers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/SwinMM/WORD/optimizers/lr_scheduler.py b/SwinMM/WORD/optimizers/lr_scheduler.py new file mode 100644 index 00000000..0c352927 --- /dev/null +++ b/SwinMM/WORD/optimizers/lr_scheduler.py @@ -0,0 +1,172 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import warnings +from typing import List + +from torch import nn as nn +from torch.optim import Adam, Optimizer +from torch.optim.lr_scheduler import LambdaLR, _LRScheduler + +__all__ = ["LinearLR", "ExponentialLR"] + + +class _LRSchedulerMONAI(_LRScheduler): + """Base class for increasing the learning rate between two boundaries over a number + of iterations""" + + def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1) -> None: + """ + Args: + optimizer: wrapped optimizer. + end_lr: the final learning rate. + num_iter: the number of iterations over which the test occurs. + last_epoch: the index of last epoch. + Returns: + None + """ + self.end_lr = end_lr + self.num_iter = num_iter + super(_LRSchedulerMONAI, self).__init__(optimizer, last_epoch) + + +class LinearLR(_LRSchedulerMONAI): + """Linearly increases the learning rate between two boundaries over a number of + iterations. + """ + + def get_lr(self): + r = self.last_epoch / (self.num_iter - 1) + return [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs] + + +class ExponentialLR(_LRSchedulerMONAI): + """Exponentially increases the learning rate between two boundaries over a number of + iterations. + """ + + def get_lr(self): + r = self.last_epoch / (self.num_iter - 1) + return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs] + + +class WarmupCosineSchedule(LambdaLR): + """Linear warmup and then cosine decay. + Based on https://huggingface.co/ implementation. + """ + + def __init__( + self, optimizer: Optimizer, warmup_steps: int, t_total: int, cycles: float = 0.5, last_epoch: int = -1 + ) -> None: + """ + Args: + optimizer: wrapped optimizer. + warmup_steps: number of warmup iterations. + t_total: total number of training iterations. + cycles: cosine cycles parameter. + last_epoch: the index of last epoch. + Returns: + None + """ + self.warmup_steps = warmup_steps + self.t_total = t_total + self.cycles = cycles + super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch) + + def lr_lambda(self, step): + if step < self.warmup_steps: + return float(step) / float(max(1.0, self.warmup_steps)) + progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.cycles) * 2.0 * progress))) + + +class LinearWarmupCosineAnnealingLR(_LRScheduler): + def __init__( + self, + optimizer: Optimizer, + warmup_epochs: int, + max_epochs: int, + warmup_start_lr: float = 0.0, + eta_min: float = 0.0, + last_epoch: int = -1, + ) -> None: + """ + Args: + optimizer (Optimizer): Wrapped optimizer. + warmup_epochs (int): Maximum number of iterations for linear warmup + max_epochs (int): Maximum number of iterations + warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0. + eta_min (float): Minimum learning rate. Default: 0. + last_epoch (int): The index of last epoch. Default: -1. 
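+
+        Example (sketch; ``optimizer`` and ``train_one_epoch`` are assumed to be
+        defined elsewhere, values mirror the defaults used in main.py):
+            >>> scheduler = LinearWarmupCosineAnnealingLR(
+            ...     optimizer, warmup_epochs=50, max_epochs=1500
+            ... )
+            >>> for epoch in range(1500):
+            ...     train_one_epoch()
+            ...     scheduler.step()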
+ """ + self.warmup_epochs = warmup_epochs + self.max_epochs = max_epochs + self.warmup_start_lr = warmup_start_lr + self.eta_min = eta_min + + super(LinearWarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch) + + def get_lr(self) -> List[float]: + """ + Compute learning rate using chainable form of the scheduler + """ + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", UserWarning + ) + + if self.last_epoch == 0: + return [self.warmup_start_lr] * len(self.base_lrs) + elif self.last_epoch < self.warmup_epochs: + return [ + group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + elif self.last_epoch == self.warmup_epochs: + return self.base_lrs + elif (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0: + return [ + group["lr"] + + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2 + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + + return [ + (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) + / ( + 1 + + math.cos( + math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs) + ) + ) + * (group["lr"] - self.eta_min) + + self.eta_min + for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self) -> List[float]: + """ + Called when epoch is passed as a param to the `step` function of the scheduler. + """ + if self.last_epoch < self.warmup_epochs: + return [ + self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) + for base_lr in self.base_lrs + ] + + return [ + self.eta_min + + 0.5 + * (base_lr - self.eta_min) + * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) + for base_lr in self.base_lrs + ] diff --git a/SwinMM/WORD/outputs/__init__.py b/SwinMM/WORD/outputs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/SwinMM/WORD/pretrained_models/__init__.py b/SwinMM/WORD/pretrained_models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/SwinMM/WORD/requirements.txt b/SwinMM/WORD/requirements.txt new file mode 100644 index 00000000..cdfb84e3 --- /dev/null +++ b/SwinMM/WORD/requirements.txt @@ -0,0 +1 @@ +timm>=0.6 diff --git a/SwinMM/WORD/run.sh b/SwinMM/WORD/run.sh new file mode 100644 index 00000000..95847627 --- /dev/null +++ b/SwinMM/WORD/run.sh @@ -0,0 +1,9 @@ +python -m torch.distributed.launch --nproc_per_node=8 --master_port=11223 main.py --batch_size=2 \ + --num_steps=30000 \ + --lrdecay \ + --eval_num=500 \ + --lr=5e-4 \ + --decay=0.1 \ + --norm_pix_loss \ + --redis_ports 39996 39997 39998 39999 \ + --redis_compression zlib diff --git a/SwinMM/WORD/runs/__init__.py b/SwinMM/WORD/runs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/SwinMM/WORD/test.py b/SwinMM/WORD/test.py new file mode 100644 index 00000000..98199d25 --- /dev/null +++ b/SwinMM/WORD/test.py @@ -0,0 +1,160 @@ +# Copyright 2020 - 2022 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import nibabel as nib +import numpy as np +import torch +from inferers import double_sliding_window_inference +from models import SwinUNETR +from utils.data_utils import get_loader +from utils.misc import resample_3d + +from monai.metrics import compute_average_surface_distance, compute_hausdorff_distance, compute_meandice +from monai.networks.utils import one_hot +from monai.transforms import Spacing + +parser = argparse.ArgumentParser(description="Swin UNETR segmentation pipeline") +parser.add_argument( + "--pretrained_dir", default="./runs/multiview_101616/", type=str, help="pretrained checkpoint directory" +) +parser.add_argument("--data_dir", default="./dataset/dataset12_WORD/", type=str, help="dataset directory") +parser.add_argument("--exp_name", default="multiview_101616/", type=str, help="experiment name") +parser.add_argument("--json_list", default="dataset12_WORD.json", type=str, help="dataset json file") +parser.add_argument("--pretrained_model_name", default="model.pt", type=str, help="pretrained model name") +parser.add_argument("--feature_size", default=48, type=int, help="feature size") +parser.add_argument("--infer_overlap", default=0.7, type=float, help="sliding window inference overlap") +parser.add_argument("--in_channels", default=1, type=int, help="number of input channels") +parser.add_argument("--out_channels", default=17, type=int, help="number of output channels") +parser.add_argument("--a_min", default=-175.0, type=float, help="a_min in ScaleIntensityRanged") +parser.add_argument("--a_max", default=250.0, type=float, help="a_max in ScaleIntensityRanged") +parser.add_argument("--b_min", default=0.0, type=float, help="b_min in ScaleIntensityRanged") +parser.add_argument("--b_max", default=1.0, type=float, help="b_max in ScaleIntensityRanged") +parser.add_argument("--space_x", default=1.5, type=float, help="spacing in x direction") +parser.add_argument("--space_y", default=1.5, type=float, help="spacing in y direction") +parser.add_argument("--space_z", default=2.0, type=float, help="spacing in z direction") +parser.add_argument("--roi_x", default=64, type=int, help="roi size in x direction") +parser.add_argument("--roi_y", default=64, type=int, help="roi size in y direction") +parser.add_argument("--roi_z", default=64, type=int, help="roi size in z direction") +parser.add_argument("--dropout_rate", default=0.0, type=float, help="dropout rate") +parser.add_argument("--distributed", action="store_true", help="start distributed training") +parser.add_argument("--workers", default=8, type=int, help="number of workers") +parser.add_argument("--RandFlipd_prob", default=0.2, type=float, help="RandFlipd aug probability") +parser.add_argument("--RandRotate90d_prob", default=0.2, type=float, help="RandRotate90d aug probability") +parser.add_argument("--RandScaleIntensityd_prob", default=0.1, type=float, help="RandScaleIntensityd aug probability") +parser.add_argument("--RandShiftIntensityd_prob", default=0.1, type=float, help="RandShiftIntensityd aug probability") +parser.add_argument("--spatial_dims", default=3, type=int, help="spatial 
dimension of input data") +parser.add_argument("--use_checkpoint", action="store_true", help="use gradient checkpointing to save memory") + +spacing = Spacing(pixdim=(1, 1, 1), mode="nearest") +hd_per = 95 +view = ["Cor1", "Sag2", "Sag1", "Axi2", "Axi1", "Cor2", "Fuse"] + + +def main(): + args = parser.parse_args() + args.test_mode = True + output_directory = "./outputs/" + args.exp_name + if not os.path.exists(output_directory): + os.makedirs(output_directory) + val_loader = get_loader(args) + pretrained_dir = args.pretrained_dir + model_name = args.pretrained_model_name + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + pretrained_pth = os.path.join(pretrained_dir, model_name) + model = SwinUNETR( + img_size=(args.roi_x, args.roi_y, args.roi_z), + in_channels=args.in_channels, + out_channels=args.out_channels, + feature_size=args.feature_size, + fusion_depths=(1, 1, 1, 1, 1, 1), + drop_rate=0.0, + attn_drop_rate=0.0, + use_checkpoint=args.use_checkpoint, + ) + + model_dict = torch.load(pretrained_pth)["state_dict"] + model.load_state_dict(model_dict) + model.eval() + model.to(device) + dice_out = np.zeros((len(val_loader), len(view), args.out_channels - 1)) + hd_out = np.zeros((len(val_loader), len(view), args.out_channels - 1)) + asd_out = np.zeros((len(val_loader), len(view), args.out_channels - 1)) + + with torch.no_grad(): + for id, batch in enumerate(val_loader): + val_inputs, val_labels = (batch["image"].to(device), batch["label"].to(device)) + original_affine = batch["label_meta_dict"]["affine"][0].numpy() + _, _, h, w, d = val_labels.shape + target_shape = (h, w, d) + val_labels = val_labels.cpu().numpy()[0, :, :, :, :] + img_name = batch["image_meta_dict"]["filename_or_obj"][0].split("/")[-1] + print("Inference on case {}".format(img_name)) + torch.cuda.empty_cache() + output_list = [] + val_fuse = 0 + + val_labels = spacing(val_labels, original_affine)[0] + val_labels = np.expand_dims(val_labels, axis=0) + val_labels = one_hot(torch.from_numpy(val_labels), num_classes=args.out_channels, dim=1) + + for i in range(3): + val_outputs_1, val_outputs_2 = double_sliding_window_inference( + val_inputs, + i, + (args.roi_x, args.roi_y, args.roi_z), + 16, + model, + overlap=args.infer_overlap, + mode="gaussian", + ) + + val_outputs_1 = torch.softmax(val_outputs_1, 1).cpu().numpy()[0] + val_outputs_2 = torch.softmax(val_outputs_2, 1).cpu().numpy()[0] + val_fuse = val_fuse + val_outputs_1 + val_outputs_2 + output_list.append(val_outputs_1) + output_list.append(val_outputs_2) + output_list.append(val_fuse) + + for i, output in enumerate(output_list): + output = np.argmax(output, axis=0, keepdims=False) + output = resample_3d(output, target_shape) + target_ornt = nib.orientations.axcodes2ornt(tuple(nib.aff2axcodes(original_affine))) + out_ornt = [[0, 1], [1, 1], [2, 1]] + ornt_transf = nib.orientations.ornt_transform(out_ornt, target_ornt) + output = nib.orientations.apply_orientation(output, ornt_transf) + nib.save( + nib.Nifti1Image(output[::-1, ::-1, :].astype(np.uint8), affine=original_affine), + os.path.join(output_directory, view[i] + "_" + img_name), + ) + output = np.expand_dims(spacing(np.expand_dims(output, axis=(0)), original_affine)[0], axis=0) + output = one_hot(torch.from_numpy(output), num_classes=args.out_channels, dim=1) + print(output.shape, val_labels.shape) + dice_ = compute_meandice(output, val_labels, include_background=False).numpy()[0] + hd_ = compute_hausdorff_distance(output, val_labels, percentile=hd_per).numpy()[0] + asd_ = 
compute_average_surface_distance(output, val_labels).numpy()[0] + print("{} View Mean Dice: {}".format(view[i], np.mean(dice_))) + print("{} View Mean HD: {}".format(view[i], np.mean(hd_))) + print("{} View Mean ASD: {}".format(view[i], np.mean(asd_))) + dice_out[id, i, :] = dice_ + hd_out[id, i, :] = hd_ + asd_out[id, i, :] = asd_ + + for i in range(len(view)): + print("Overall {} View Mean Dice: {}".format(view[i], np.mean(dice_out[:, i, :], axis=0))) + print("Overall {} View Mean HD: {}".format(view[i], np.mean(hd_out[:, i, :], axis=0))) + print("Overall {} View Mean ASD: {}".format(view[i], np.mean(asd_out[:, i, :], axis=0))) + + +if __name__ == "__main__": + main() diff --git a/SwinMM/WORD/test_parallel.py b/SwinMM/WORD/test_parallel.py new file mode 100644 index 00000000..a1e8a358 --- /dev/null +++ b/SwinMM/WORD/test_parallel.py @@ -0,0 +1,232 @@ +import argparse +import logging +import os + +import nibabel as nib +import numpy as np +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from inferers import double_sliding_window_inference +from models import SwinUNETR +from timm.utils import setup_default_logging +from utils.data_utils import get_loader +from utils.misc import dice, distributed_all_gather, resample_3d + +from monai.metrics import compute_average_surface_distance, compute_hausdorff_distance, compute_meandice +from monai.networks.utils import one_hot +from monai.transforms import Spacing + +parser = argparse.ArgumentParser(description="Swin UNETR segmentation pipeline") +parser.add_argument( + "--pretrained_dir", default="./runs/multiview_101616/", type=str, help="pretrained checkpoint directory" +) +parser.add_argument("--data_dir", default="./dataset/dataset12_WORD/", type=str, help="dataset directory") +parser.add_argument("--exp_name", default="multiview_101616/", type=str, help="experiment name") +parser.add_argument("--json_list", default="dataset12_WORD.json", type=str, help="dataset json file") +parser.add_argument("--pretrained_model_name", default="model.pt", type=str, help="pretrained model name") +parser.add_argument("--feature_size", default=48, type=int, help="feature size") +parser.add_argument("--infer_overlap", default=0.7, type=float, help="sliding window inference overlap") +parser.add_argument("--in_channels", default=1, type=int, help="number of input channels") +parser.add_argument("--out_channels", default=17, type=int, help="number of output channels") +parser.add_argument("--a_min", default=-175.0, type=float, help="a_min in ScaleIntensityRanged") +parser.add_argument("--a_max", default=250.0, type=float, help="a_max in ScaleIntensityRanged") +parser.add_argument("--b_min", default=0.0, type=float, help="b_min in ScaleIntensityRanged") +parser.add_argument("--b_max", default=1.0, type=float, help="b_max in ScaleIntensityRanged") +parser.add_argument("--space_x", default=1.5, type=float, help="spacing in x direction") +parser.add_argument("--space_y", default=1.5, type=float, help="spacing in y direction") +parser.add_argument("--space_z", default=2.0, type=float, help="spacing in z direction") +parser.add_argument("--roi_x", default=64, type=int, help="roi size in x direction") +parser.add_argument("--roi_y", default=64, type=int, help="roi size in y direction") +parser.add_argument("--roi_z", default=64, type=int, help="roi size in z direction") +parser.add_argument("--dropout_rate", default=0.0, type=float, help="dropout rate") +parser.add_argument("--distributed", action="store_true", help="start distributed 
training") +parser.add_argument("--world_size", default=1, type=int, help="number of nodes for distributed training") +parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training") +parser.add_argument("--dist-url", default="tcp://127.0.0.1:23456", type=str, help="distributed url") +parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend") +parser.add_argument("--workers", default=8, type=int, help="number of workers") +parser.add_argument("--use_normal_dataset", action="store_true", help="use monai Dataset class") +parser.add_argument( + "--nouse_multi_epochs_loader", + action="store_true", + help="not use the multi-epochs-loader to save time at the beginning of every epoch", +) +parser.add_argument("--RandFlipd_prob", default=0.2, type=float, help="RandFlipd aug probability") +parser.add_argument("--RandRotate90d_prob", default=0.2, type=float, help="RandRotate90d aug probability") +parser.add_argument("--RandScaleIntensityd_prob", default=0.1, type=float, help="RandScaleIntensityd aug probability") +parser.add_argument("--RandShiftIntensityd_prob", default=0.1, type=float, help="RandShiftIntensityd aug probability") +parser.add_argument("--spatial_dims", default=3, type=int, help="spatial dimension of input data") +parser.add_argument("--use_checkpoint", action="store_true", help="use gradient checkpointing to save memory") +parser.add_argument( + "--cross_attention_in_origin_view", action="store_true", help="Whether compute cross attention in original view" +) + +spacing = Spacing(pixdim=(1, 1, 1), mode="nearest") +hd_per = 95 +view = ["Cor1", "Sag2", "Sag1", "Axi2", "Axi1", "Cor2", "Fuse"] + + +def main(): + os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" + args = parser.parse_args() + output_directory = "./outputs/" + args.exp_name + if not os.path.exists(output_directory): + os.makedirs(output_directory) + if args.distributed: + args.ngpus_per_node = torch.cuda.device_count() + print("Found total gpus", args.ngpus_per_node) + args.world_size = args.ngpus_per_node * args.world_size + mp.spawn(main_worker, nprocs=args.ngpus_per_node, args=(args,)) + else: + main_worker(gpu=0, args=args) + + +def main_worker(gpu, args): + output_directory = "./outputs/" + args.exp_name + if args.distributed: + torch.multiprocessing.set_start_method("fork", force=True) + # np.set_printoptions(formatter={"float": "{: 0.3f}".format}, suppress=True) + args.gpu = gpu + if args.distributed: + args.rank = args.rank * args.ngpus_per_node + gpu + dist.init_process_group( + backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank + ) + torch.cuda.set_device(args.gpu) + torch.backends.cudnn.benchmark = True + args.test_mode = True + val_loader = get_loader(args) + print(args.rank, " gpu", args.gpu) + if args.rank == 0: + setup_default_logging() + # logging.info(f"Batch size is: {args.batch_size}, epochs: {args.max_epochs}") + + pretrained_dir = args.pretrained_dir + model_name = args.pretrained_model_name + pretrained_pth = os.path.join(pretrained_dir, model_name) + model = SwinUNETR( + img_size=(args.roi_x, args.roi_y, args.roi_z), + in_channels=args.in_channels, + out_channels=args.out_channels, + feature_size=args.feature_size, + fusion_depths=(1, 1, 1, 1, 1, 1), + drop_rate=0.0, + attn_drop_rate=0.0, + use_checkpoint=args.use_checkpoint, + cross_attention_in_origin_view=args.cross_attention_in_origin_view, + ) + model.load_state_dict(torch.load(pretrained_pth, map_location="cpu")["state_dict"]) + 
model.cuda(args.gpu) + model.eval() + model_without_ddp = model + if args.distributed: + torch.cuda.set_device(args.gpu) + model.cuda(args.gpu) + model_without_ddp = model + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[args.gpu], output_device=args.gpu, broadcast_buffers=False, find_unused_parameters=True + ) + + dice_all = [] # np.zeros((len(val_loader), len(view), args.out_channels - 1)) + hd_all = [] # np.zeros((len(val_loader), len(view), args.out_channels - 1)) + asd_all = [] # np.zeros((len(val_loader), len(view), args.out_channels - 1)) + + with torch.no_grad(): + for id, batch in enumerate(val_loader): + val_inputs, val_labels = (batch["image"].cuda(args.gpu), batch["label"].cpu()) + original_affine = batch["label_meta_dict"]["affine"][0].numpy() + _, _, h, w, d = val_labels.shape + target_shape = (h, w, d) + val_labels = val_labels.numpy()[0, :, :, :, :] + img_name = batch["image_meta_dict"]["filename_or_obj"][0].split("/")[-1] + print("Inference on case {}".format(img_name)) + torch.cuda.empty_cache() + output_list = [] + val_fuse = 0 + + val_labels = spacing(val_labels, original_affine)[0] + val_labels = np.expand_dims(val_labels, axis=0) + val_labels = one_hot(torch.from_numpy(val_labels), num_classes=args.out_channels, dim=1) + + for i in range(3): + # j = i + 3 + val_outputs_1, val_outputs_2 = double_sliding_window_inference( + val_inputs, + i, + (args.roi_x, args.roi_y, args.roi_z), + 16, + model, + overlap=args.infer_overlap, + mode="gaussian", + ) + + val_outputs_1 = torch.softmax(val_outputs_1, 1).cpu().numpy()[0] + val_outputs_2 = torch.softmax(val_outputs_2, 1).cpu().numpy()[0] + val_fuse = val_fuse + val_outputs_1 + val_outputs_2 + output_list.append(val_outputs_1) + output_list.append(val_outputs_2) + output_list.append(val_fuse) + print("Inference finished on case {}".format(img_name)) + + dice = np.zeros((len(view), args.out_channels - 1)) + hd = np.zeros((len(view), args.out_channels - 1)) + asd = np.zeros((len(view), args.out_channels - 1)) + for i, output in enumerate(output_list): + output = np.argmax(output, axis=0, keepdims=False) + output = resample_3d(output, target_shape) + target_ornt = nib.orientations.axcodes2ornt(tuple(nib.aff2axcodes(original_affine))) + out_ornt = [[0, 1], [1, 1], [2, 1]] + ornt_transf = nib.orientations.ornt_transform(out_ornt, target_ornt) + output = nib.orientations.apply_orientation(output, ornt_transf) + nib.save( + nib.Nifti1Image(output.astype(np.uint8), affine=original_affine), + os.path.join(output_directory, view[i] + "_" + img_name), + ) + output = np.expand_dims(spacing(np.expand_dims(output, axis=(0)), original_affine)[0], axis=0) + output = one_hot(torch.from_numpy(output), num_classes=args.out_channels, dim=1) + # print(output.shape, val_labels.shape) + dice_ = compute_meandice(output, val_labels, include_background=False).numpy()[0] + hd_ = compute_hausdorff_distance(output, val_labels, percentile=hd_per).numpy()[0] + asd_ = compute_average_surface_distance(output, val_labels).numpy()[0] + print("{} {} View Mean Dice: {}".format(img_name, view[i], np.mean(dice_))) + print("{} {} View Mean HD: {}".format(img_name, view[i], np.mean(hd_))) + print("{} {} View Mean ASD: {}".format(img_name, view[i], np.mean(asd_))) + dice[i, :] = dice_ + hd[i, :] = hd_ + asd[i, :] = asd_ + dice_all.append(dice) + hd_all.append(hd) + asd_all.append(asd) + + dice_all = torch.tensor(np.stack(dice_all, axis=0)).cuda(args.gpu) + hd_all = torch.tensor(np.stack(hd_all, axis=0)).cuda(args.gpu) + asd_all = 
torch.tensor(np.stack(asd_all, axis=0)).cuda(args.gpu) + dice_list = distributed_all_gather([dice_all], out_numpy=False, is_valid=True) + hd_list = distributed_all_gather([hd_all], out_numpy=False, is_valid=True) + asd_list = distributed_all_gather([asd_all], out_numpy=False, is_valid=True) + dice_list = torch.flatten(torch.stack(dice_list[0], axis=0), start_dim=0, end_dim=1).cpu().numpy() + hd_list = torch.flatten(torch.stack(hd_list[0], axis=0), start_dim=0, end_dim=1).cpu().numpy() + asd_list = torch.flatten(torch.stack(asd_list[0], axis=0), start_dim=0, end_dim=1).cpu().numpy() + + if args.rank == 0: + for i in range(len(view)): + print(dice_list.shape) + print("Overall {} View Mean Dice: {}".format(view[i], np.mean(dice_list[:, i, :], axis=0))) + print("Overall {} View Mean HD: {}".format(view[i], np.mean(hd_list[:, i, :], axis=0))) + print("Overall {} View Mean ASD: {}".format(view[i], np.mean(asd_list[:, i, :], axis=0))) + np.savetxt( + os.path.join(output_directory, view[i] + "Dice.txt"), + np.mean(dice_list[:, i, :], axis=0), + delimiter="\t", + ) + np.savetxt( + os.path.join(output_directory, view[i] + "HD.txt"), np.mean(hd_list[:, i, :], axis=0), delimiter="\t" + ) + np.savetxt( + os.path.join(output_directory, view[i] + "ASD.txt"), np.mean(asd_list[:, i, :], axis=0), delimiter="\t" + ) + + +if __name__ == "__main__": + main() diff --git a/SwinMM/WORD/trainer.py b/SwinMM/WORD/trainer.py new file mode 100644 index 00000000..5ea97b51 --- /dev/null +++ b/SwinMM/WORD/trainer.py @@ -0,0 +1,296 @@ +# Copyright 2020 - 2022 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
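+
+# Training loop for SwinMM fine-tuning: each batch is rendered as two randomly
+# permuted views, every view is supervised with the Dice-CE loss (self_crit)
+# while a KL mutual loss (mutual_crit) pulls the two softmax predictions
+# together; in the unsupervised phase the argmax of the averaged prediction is
+# used as a pseudo label.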
+ +import os +import time + +import numpy as np +import torch +import torch.distributed +import torch.nn.functional as F +import torch.nn.parallel +import torch.utils.data.distributed +from torch.cuda.amp import GradScaler, autocast +from torch.utils.tensorboard import SummaryWriter +from utils import view_ops +from utils.misc import AverageMeter, distributed_all_gather + +from monai.data import decollate_batch + + +def train_epoch(model, loader, optimizer, scaler, epoch, self_crit, mutual_crit, args): + model.train() + start_time = time.time() + run_loss = AverageMeter() + run_self_loss = AverageMeter() + run_mutual_loss = AverageMeter() + + for idx, batch_data in enumerate(loader): + for param in model.parameters(): + param.grad = None + if isinstance(batch_data, list): + data, target = batch_data + else: + data, target = batch_data["image"], batch_data["label"] + data = data.cuda(args.rank) + if not args.unsupervised: + target = target.cuda(args.rank) + data_list, view_list = view_ops.permute_rand(data) + + loss = 0 + self_loss_list, mutual_loss_list = [], [] + with autocast(enabled=args.amp): + output1, output2 = model(data_list[0], data_list[1], view_list) + out_list = [output1, output2] + out_list = view_ops.permute_inverse(out_list, view_list) + if args.unsupervised: + target = torch.argmax( + (torch.softmax(out_list[0], dim=1) + torch.softmax(out_list[1], dim=1)) / 2, dim=1, keepdim=True + ).cuda(args.rank) + for i in range(len(out_list)): + self_loss = self_crit(out_list[i], target) + mutual_loss = 0 + for j in range(len(out_list)): # KL divergence + if i != j: + mutual_end = mutual_crit(F.log_softmax(out_list[i], dim=1), F.softmax(out_list[j], dim=1)) + mutual_loss += mutual_end + loss = loss + (self_loss + mutual_loss / (len(out_list) - 1)) / len(out_list) + self_loss_list.append(self_loss.item()) + mutual_loss_list.append(mutual_loss.item()) + self_loss = torch.mean(torch.tensor(self_loss_list)).cuda(args.rank) + mutual_loss = torch.mean(torch.tensor(mutual_loss_list)).cuda(args.rank) + + if args.amp: + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + else: + loss.backward() + optimizer.step() + if args.distributed: + is_valid = True + loss_list = distributed_all_gather([loss], out_numpy=True, is_valid=is_valid) + run_loss.update( + np.mean(np.mean(np.stack(loss_list, axis=0), axis=0), axis=0), n=args.batch_size * args.world_size + ) + self_loss_list = distributed_all_gather([self_loss], out_numpy=True, is_valid=is_valid) + run_self_loss.update( + np.mean(np.mean(np.stack(self_loss_list, axis=0), axis=0), axis=0), n=args.batch_size * args.world_size + ) + mutual_loss_list = distributed_all_gather([mutual_loss], out_numpy=True, is_valid=is_valid) + run_mutual_loss.update( + np.mean(np.mean(np.stack(mutual_loss_list, axis=0), axis=0), axis=0), + n=args.batch_size * args.world_size, + ) + else: + run_loss.update(loss.item(), n=args.batch_size) + run_self_loss.update(self_loss.item(), n=args.batch_size) + run_mutual_loss.update(mutual_loss.item(), n=args.batch_size) + if args.rank == 0: + print( + "Epoch {}/{} {}/{}".format(epoch, args.max_epochs, idx, len(loader)), + "loss: {:.4f}".format(run_loss.avg), + "self_loss: {:.4f}".format(run_self_loss.avg), + "mutual_loss: {:.4f}".format(run_mutual_loss.avg), + "time {:.2f}s".format(time.time() - start_time), + ) + start_time = time.time() + for param in model.parameters(): + param.grad = None + return run_loss.avg, run_self_loss.avg, run_mutual_loss.avg + + +def val_epoch(model, loader, epoch, acc_func, args, 
model_inferer=None, post_label=None, post_pred=None): + model.eval() + run_acc = AverageMeter() + start_time = time.time() + with torch.no_grad(): + for idx, batch_data in enumerate(loader): + if isinstance(batch_data, list): + data, target = batch_data + else: + data, target = batch_data["image"], batch_data["label"] + data = data.cuda(args.rank) + torch.cuda.empty_cache() + with autocast(enabled=args.amp): + i = np.random.randint(0, 3) + if model_inferer is not None: + output1, output2 = model_inferer(data, i) + else: + output1, output2 = model(data, i) + output1, output2, target = output1.cpu(), output2.cpu(), target.cpu() + val_labels_list = decollate_batch(target) + val_labels_convert = [post_label(val_label_tensor) for val_label_tensor in val_labels_list] + out = (output1 + output2) / 2 + val_outputs_list = decollate_batch(out) + val_output_convert = [post_pred(val_pred_tensor) for val_pred_tensor in val_outputs_list] + acc_func.reset() + acc_func(y_pred=val_output_convert, y=val_labels_convert) + acc, not_nans = acc_func.aggregate() + acc, not_nans = acc.cuda(args.rank), not_nans.cuda(args.rank) + + if args.distributed: + is_valid = True + acc_list, not_nans_list = distributed_all_gather([acc, not_nans], out_numpy=True, is_valid=is_valid) + for al, nl in zip(acc_list, not_nans_list): + run_acc.update(al, n=nl) + + else: + run_acc.update(acc.cpu().numpy(), n=not_nans.cpu().numpy()) + + if args.rank == 0: + avg_acc = np.mean(run_acc.avg) + print( + "Val {}/{} {}/{}".format(epoch, args.max_epochs, idx, len(loader)), + "acc", + avg_acc, + "time {:.2f}s".format(time.time() - start_time), + ) + start_time = time.time() + return run_acc.avg + + +def save_checkpoint(model, epoch, args, filename="model.pt", best_acc=0, optimizer=None, scheduler=None): + state_dict = model.state_dict() if not args.distributed else model.module.state_dict() + save_dict = {"epoch": epoch, "best_acc": best_acc, "state_dict": state_dict} + if optimizer is not None: + save_dict["optimizer"] = optimizer.state_dict() + if scheduler is not None: + save_dict["scheduler"] = scheduler.state_dict() + filename = os.path.join(args.logdir, filename) + torch.save(save_dict, filename) + print("Saving checkpoint", filename) + + +def run_training( + model, + train_loader, + val_loader, + unsupervised_loader, + optimizer, + self_crit, + mutual_crit, + acc_func, + args, + model_inferer=None, + scheduler=None, + start_epoch=0, + post_label=None, + post_pred=None, +): + writer = None + if args.logdir is not None and args.rank == 0: + writer = SummaryWriter(log_dir=args.logdir) + if args.rank == 0: + print("Writing Tensorboard logs to ", args.logdir) + scaler = None + if args.amp: + scaler = GradScaler() + val_acc_max = 0.0 + for epoch in range(start_epoch, args.max_epochs): + if args.distributed: + train_loader.sampler.set_epoch(epoch) + torch.distributed.barrier() + print(args.rank, time.ctime(), "Epoch:", epoch) + epoch_time = time.time() + train_loss, self_loss, mutual_loss = train_epoch( + model, + train_loader, + optimizer, + scaler=scaler, + epoch=epoch, + self_crit=self_crit, + mutual_crit=mutual_crit, + args=args, + ) + if args.rank == 0: + print( + "Final training {}/{}".format(epoch, args.max_epochs - 1), + "loss: {:.4f}".format(train_loss), + "self loss: {:.4f}".format(self_loss), + "mutual loss: {:.4f}".format(mutual_loss), + "time {:.2f}s".format(time.time() - epoch_time), + ) + if args.rank == 0 and writer is not None: + writer.add_scalar("train_loss", train_loss, epoch) + writer.add_scalar("self_loss", self_loss, 
epoch)
+            writer.add_scalar("mutual_loss", mutual_loss, epoch)
+
+        if args.unsupervised and (epoch + 1) % args.unsuper_every == 0:
+            if args.distributed:
+                unsupervised_loader.sampler.set_epoch(epoch)
+                torch.distributed.barrier()
+            print(args.rank, time.ctime(), "Epoch:", epoch)
+            epoch_time = time.time()
+            train_loss, self_loss, mutual_loss = train_epoch(
+                model,
+                unsupervised_loader,
+                optimizer,
+                scaler=scaler,
+                epoch=epoch,
+                self_crit=self_crit,
+                mutual_crit=mutual_crit,
+                args=args,
+            )
+            if args.rank == 0:
+                print(
+                    "Final unsupervised training {}/{}".format(epoch, args.max_epochs - 1),
+                    "loss: {:.4f}".format(train_loss),
+                    "mutual loss: {:.4f}".format(mutual_loss),
+                    "time {:.2f}s".format(time.time() - epoch_time),
+                )
+            if args.rank == 0 and writer is not None:
+                writer.add_scalar("train_unsupervised_loss", train_loss, epoch)
+                writer.add_scalar("unsupervised_self_loss", self_loss, epoch)
+                writer.add_scalar("unsupervised_mutual_loss", mutual_loss, epoch)
+
+        if epoch >= args.val_start and (epoch + 1) % args.val_every == 0:
+            if args.distributed:
+                torch.distributed.barrier()
+            epoch_time = time.time()
+            val_avg_acc = val_epoch(
+                model,
+                val_loader,
+                epoch=epoch,
+                acc_func=acc_func,
+                model_inferer=model_inferer,
+                args=args,
+                post_label=post_label,
+                post_pred=post_pred,
+            )
+
+            val_avg_acc = np.mean(val_avg_acc)
+
+            if args.rank == 0:
+                print(
+                    "Final validation {}/{}".format(epoch, args.max_epochs - 1),
+                    "acc",
+                    val_avg_acc,
+                    "time {:.2f}s".format(time.time() - epoch_time),
+                )
+                if writer is not None:
+                    writer.add_scalar("val_acc", val_avg_acc, epoch)
+                if val_avg_acc > val_acc_max:
+                    print("new best ({:.6f} --> {:.6f}). ".format(val_acc_max, val_avg_acc))
+                    val_acc_max = val_avg_acc
+                    if args.rank == 0 and args.logdir is not None and args.save_checkpoint:
+                        save_checkpoint(
+                            model, epoch, args, best_acc=val_acc_max, optimizer=optimizer, scheduler=scheduler
+                        )
+            if args.rank == 0 and args.logdir is not None and args.save_checkpoint:
+                save_checkpoint(model, epoch, args, best_acc=val_acc_max, filename="model_final.pt")
+
+        if scheduler is not None:
+            scheduler.step()
+
+    print("Training Finished! Best Accuracy:", val_acc_max)
+
+    return val_acc_max
diff --git a/SwinMM/WORD/utils/__init__.py b/SwinMM/WORD/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/SwinMM/WORD/utils/data_utils.py b/SwinMM/WORD/utils/data_utils.py
new file mode 100644
index 00000000..2788cce5
--- /dev/null
+++ b/SwinMM/WORD/utils/data_utils.py
@@ -0,0 +1,176 @@
+# Copyright 2020 - 2022 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import os +from typing import Any, Callable, MutableMapping, Sequence + +import timm.data +import torch +import torch.utils.data +from utils import dataset_in_memory + +from monai import data, transforms +from monai.data import load_decathlon_datalist +from monai.data.utils import list_data_collate + + +def get_dataset_kwargs(dataset_name: str, stage: str, use_normal_dataset: bool, args) -> MutableMapping[str, Any]: + dataset_kwargs = {} + if not use_normal_dataset: + dataset_kwargs = dict( + dataset_name=f"{stage}_{dataset_name}", + hosts=[{"host": "localhost", "port": str(port)} for port in args.redis_ports], + cluster_mode=True, + capacity_per_node=200 * 1024 * 1024 * 1024, + writer_buffer_size=0, # Disable write buffer + ) + return dataset_kwargs + + +def create_dataloader( + data_files: Sequence[Any], + is_testing: bool, + transform: Callable, + num_workers: int, + is_distributed: bool, + with_cache: bool, + use_multi_epochs_loader: bool, + batch_size: int, + dataset_kwargs: MutableMapping[str, Any], +) -> torch.utils.data.DataLoader: + if not with_cache: + dataset = data.Dataset(data=data_files, transform=transform, **dataset_kwargs) + else: + dataset = dataset_in_memory.CachedDataset(data=data_files, transform=transform, **dataset_kwargs) + sampler = torch.utils.data.DistributedSampler(dataset, shuffle=not is_testing) if is_distributed else None + + loader_class = data.DataLoader + if use_multi_epochs_loader: + loader_class = timm.data.loader.MultiEpochsDataLoader + loader = loader_class( + dataset, + batch_size=batch_size, + shuffle=False if is_distributed or is_testing else True, + num_workers=num_workers, + sampler=sampler, + pin_memory=True, + persistent_workers=True, + # NOTE(meijieru): otherwise `too many open` + collate_fn=list_data_collate, + ) + return loader + + +def get_loader(args): + data_dir = args.data_dir + datalist_json = os.path.join(data_dir, args.json_list) + train_transform = transforms.Compose( + [ + transforms.LoadImaged(keys=["image", "label"]), + transforms.AddChanneld(keys=["image", "label"]), + transforms.Orientationd(keys=["image", "label"], axcodes="RAS"), + transforms.Spacingd( + keys=["image", "label"], pixdim=(args.space_x, args.space_y, args.space_z), mode=("bilinear", "nearest") + ), + transforms.ScaleIntensityRanged( + keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True + ), + # transforms.CropForegroundd(keys=["image", "label"], source_key="image"), + transforms.RandCropByPosNegLabeld( + keys=["image", "label"], + label_key="label", + spatial_size=(args.roi_x, args.roi_y, args.roi_z), + pos=1, + neg=1, + num_samples=4, + image_key="image", + image_threshold=0, + ), + transforms.RandFlipd(keys=["image", "label"], prob=args.RandFlipd_prob, spatial_axis=0), + transforms.RandFlipd(keys=["image", "label"], prob=args.RandFlipd_prob, spatial_axis=1), + transforms.RandFlipd(keys=["image", "label"], prob=args.RandFlipd_prob, spatial_axis=2), + transforms.RandRotate90d(keys=["image", "label"], prob=args.RandRotate90d_prob, max_k=3), + transforms.RandScaleIntensityd(keys="image", factors=0.1, prob=args.RandScaleIntensityd_prob), + transforms.RandShiftIntensityd(keys="image", offsets=0.1, prob=args.RandShiftIntensityd_prob), + transforms.ToTensord(keys=["image", "label"]), + ] + ) + val_transform = transforms.Compose( + [ + transforms.LoadImaged(keys=["image", "label"]), + transforms.AddChanneld(keys=["image", "label"]), + transforms.Orientationd(keys=["image", "label"], axcodes="RAS"), + transforms.Spacingd( + 
keys=["image", "label"], pixdim=(args.space_x, args.space_y, args.space_z), mode=("bilinear", "nearest") + ), + transforms.ScaleIntensityRanged( + keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True + ), + # transforms.CropForegroundd(keys=["image", "label"], source_key="image"), + transforms.ToTensord(keys=["image", "label"]), + ] + ) + + test_transform = transforms.Compose( + [ + transforms.LoadImaged(keys=["image", "label"]), + transforms.AddChanneld(keys=["image", "label"]), + transforms.Orientationd(keys=["image", "label"], axcodes="RAS"), + transforms.Spacingd(keys="image", pixdim=(args.space_x, args.space_y, args.space_z), mode="bilinear"), + transforms.ScaleIntensityRanged( + keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True + ), + transforms.ToTensord(keys=["image", "label"]), + ] + ) + + def _get_subset_loader(subset_name: str, transform: Callable, is_testing: bool, use_normal_dataset: bool): + datalist = load_decathlon_datalist(datalist_json, True, subset_name, base_dir=data_dir) + dataset_kwargs = get_dataset_kwargs(subset_name, "finetune", use_normal_dataset, args) + loader = create_dataloader( + datalist, + is_testing, + transform, + args.workers, + args.distributed, + not use_normal_dataset, + not args.nouse_multi_epochs_loader, + args.batch_size, + dataset_kwargs, + ) + return loader + + if args.test_mode: + # Never cache as only go through once. + test_files = load_decathlon_datalist(datalist_json, True, "testing", base_dir=data_dir) + test_ds = data.Dataset(data=test_files, transform=test_transform) + test_sampler = torch.utils.data.DistributedSampler(test_ds, shuffle=False) if args.distributed else None + test_loader = data.DataLoader( + test_ds, + batch_size=1, + shuffle=False, + num_workers=args.workers, + sampler=test_sampler, + pin_memory=True, + persistent_workers=True, + ) + loader = test_loader + else: + train_loader = _get_subset_loader("training", train_transform, False, args.use_normal_dataset) + val_loader = _get_subset_loader("validation", val_transform, True, args.use_normal_dataset_val) + + if args.unsupervised: + unsupervised_loader = _get_subset_loader("unsupervised", train_transform, False, args.use_normal_dataset) + loader = [train_loader, val_loader, unsupervised_loader] + else: + loader = [train_loader, val_loader] + + return loader diff --git a/SwinMM/WORD/utils/dataset_in_memory.py b/SwinMM/WORD/utils/dataset_in_memory.py new file mode 100644 index 00000000..7ee751ab --- /dev/null +++ b/SwinMM/WORD/utils/dataset_in_memory.py @@ -0,0 +1,107 @@ +"""Cache the data using redis. + +TODO(meijieru): zeromp may be better. 
+""" + +from typing import Callable, MutableMapping, Optional, Sequence, Union + +import bagua.torch_api.contrib.cache_loader as bagua_cache_loader +import torch +import torch.utils.data.dataset as torch_dataset + +import monai.data as monai_data +import monai.transforms as monai_transforms + +_ALL_DATASET_NAMES = set() +_SERIALIZATION_HIJACKED = False + + +def hijack_bagua_serialization(method: str): + """Replace bagua serialization.""" + global _SERIALIZATION_HIJACKED + if _SERIALIZATION_HIJACKED: + raise RuntimeError("Already hijacked.") + + import pickle + + if method == "lz4": + import lz4 + + compress, decompress = lz4.frame.compress, lz4.frame.decompress + elif method == "lzma": + import pylzma as lzma + + compress, decompress = lzma.compress, lzma.decompress + elif method == "zlib": + import zlib + + compress, decompress = zlib.compress, zlib.decompress + else: + raise ValueError(f"Unknown compress method: {method}") + + bagua_cache_loader.serialize = lambda val: compress(pickle.dumps(val)) + bagua_cache_loader.deserialize = lambda val: pickle.loads(decompress(val)) + _SERIALIZATION_HIJACKED = True + + +def is_deterministic_transform(transform) -> bool: + return not ( + isinstance(transform, monai_transforms.Randomizable) or not isinstance(transform, monai_transforms.Transform) + ) + + +class CachedDataset(torch_dataset.Dataset): + def __init__( + self, + data: Sequence, + transform: Optional[Union[Sequence[Callable], Callable]] = None, + as_contiguous: bool = True, + backend: str = "redis", + hosts: Optional[Sequence[MutableMapping[str, str]]] = None, + dataset_name: str = "", + writer_buffer_size: int = 20, + **kwargs, + ) -> None: + super().__init__() + + if hosts is None: + raise ValueError("We don't init bagua, have to manually launch redis") + + # NOTE(meijieru): check if the dataset name is unique, to avoid + # potential confliction. 
+ if not dataset_name or dataset_name in _ALL_DATASET_NAMES: + raise ValueError("Must have an unique name for each dataset.") + _ALL_DATASET_NAMES.add(dataset_name) + + self._dataset = monai_data.Dataset(data=data) + self._cache_loader = bagua_cache_loader.CacheLoader( + backend, dataset_name, writer_buffer_size, hosts=hosts, **kwargs + ) + self.transform = transform + self.as_contiguous = as_contiguous + + def __len__(self): + return len(self._dataset) + + def _apply_non_deterministic_transform(self, item): + for trans in self.transform.transforms: # type:ignore + if not is_deterministic_transform(trans): + item = monai_transforms.apply_transform(trans, item) + return item + + def _apply_deterministic_transform(self, item): + for trans in self.transform.transforms: # type:ignore + # execute all the deterministic transforms + if not is_deterministic_transform(trans): + break + item = monai_transforms.apply_transform(trans, item) + if self.as_contiguous: + item = monai_transforms.convert_to_contiguous(item, memory_format=torch.contiguous_format) + return item + + def _load_item(self, index: int): + return self._apply_deterministic_transform(self._dataset[index]) + + def __getitem__(self, item): + cached_item = self._cache_loader.get(item, self._load_item) + return self._apply_non_deterministic_transform(cached_item) diff --git a/SwinMM/WORD/utils/misc.py b/SwinMM/WORD/utils/misc.py new file mode 100644 index 00000000..6cb88985 --- /dev/null +++ b/SwinMM/WORD/utils/misc.py @@ -0,0 +1,78 @@ +# Copyright 2020 - 2022 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import numpy as np
+import scipy.ndimage as ndimage
+import torch
+
+
+def resample_3d(img, target_size):
+    imx, imy, imz = img.shape
+    tx, ty, tz = target_size
+    zoom_ratio = (float(tx) / float(imx), float(ty) / float(imy), float(tz) / float(imz))
+    img_resampled = ndimage.zoom(img, zoom_ratio, order=0, prefilter=False)
+    return img_resampled
+
+
+class AverageMeter(object):
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = np.where(self.count > 0, self.sum / self.count, self.sum)
+
+
+def distributed_all_gather(
+    tensor_list, valid_batch_size=None, out_numpy=False, world_size=None, no_barrier=False, is_valid=None
+):
+    if world_size is None:
+        world_size = torch.distributed.get_world_size()
+    if valid_batch_size is not None:
+        valid_batch_size = min(valid_batch_size, world_size)
+    elif is_valid is not None:
+        is_valid = torch.tensor(bool(is_valid), dtype=torch.bool, device=tensor_list[0].device)
+    if not no_barrier:
+        torch.distributed.barrier()
+    tensor_list_out = []
+    with torch.no_grad():
+        if is_valid is not None:
+            is_valid_list = [torch.zeros_like(is_valid) for _ in range(world_size)]
+            torch.distributed.all_gather(is_valid_list, is_valid)
+            is_valid = [x.item() for x in is_valid_list]
+        for tensor in tensor_list:
+            gather_list = [torch.zeros_like(tensor) for _ in range(world_size)]
+            torch.distributed.all_gather(gather_list, tensor)
+            if valid_batch_size is not None:
+                gather_list = gather_list[:valid_batch_size]
+            elif is_valid is not None:
+                gather_list = [g for g, v in zip(gather_list, is_valid_list) if v]
+            if out_numpy:
+                gather_list = [t.cpu().numpy() for t in gather_list]
+            tensor_list_out.append(gather_list)
+    return tensor_list_out
+
+
+def dice(x, y):
+    intersect = np.sum(np.sum(np.sum(x * y)))
+    y_sum = np.sum(np.sum(np.sum(y)))
+    if y_sum == 0:
+        return 0.0
+    x_sum = np.sum(np.sum(np.sum(x)))
+    return 2 * intersect / (x_sum + y_sum)
diff --git a/SwinMM/WORD/utils/test_view_transforms.py b/SwinMM/WORD/utils/test_view_transforms.py
new file mode 100644
index 00000000..a8cbeb44
--- /dev/null
+++ b/SwinMM/WORD/utils/test_view_transforms.py
@@ -0,0 +1,39 @@
+"""Unit test for view transforms."""
+
+import itertools
+import unittest
+
+import numpy as np
+import torch
+from utils import view_transforms
+
+
+class ViewTransformsTest(unittest.TestCase):
+    def test_len(self):
+        self.assertEqual(len(view_transforms.all_forward_transforms), len(view_transforms.all_backward_transforms))
+
+    def test_inverse_transforms(self):
+        x = np.random.uniform(size=(2, 6, 3, 4, 5))
+        x_torch = torch.from_numpy(x)
+        for group_name, transforms in view_transforms.all_forward_transforms.items():
+            inverse_transforms = view_transforms.all_backward_transforms[group_name]
+            self.assertEqual(len(transforms), len(inverse_transforms))
+            for key in transforms:
+                x_recon = inverse_transforms[key](transforms[key](x_torch)).numpy()
+                np.testing.assert_allclose(x, x_recon)
+
+    def test_get_transforms_func(self):
+        x = np.random.uniform(size=(2, 6, 3, 4, 5))
+        x_torch = torch.from_numpy(x)
+
+        for order in [view_transforms.DEFAULT_ORDER, view_transforms.DEFAULT_ORDER[::-1]]:
+            views_all = itertools.product(*[view_transforms.all_forward_transforms[gn].keys() for gn in order])
+            for views in views_all:
+                func, inv_func = [
+                    view_transforms.get_transforms_func(views, order, inverse) for inverse in [False, True]
+                ]
+                np.testing.assert_allclose(x, inv_func(func(x_torch)).numpy())
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/SwinMM/WORD/utils/view_ops.py b/SwinMM/WORD/utils/view_ops.py
new file mode 100644
index 00000000..8852e7e6
--- /dev/null
+++ b/SwinMM/WORD/utils/view_ops.py
@@ -0,0 +1,34 @@
+"""View operations."""
+
+from typing import Sequence, Tuple
+
+import numpy as np
+import torch
+from utils import view_transforms
+
+PermuteType = view_transforms.PermuteType
+TransformFuncType = view_transforms.TransformFuncType
+
+
+def get_permute_transform(view_src: PermuteType, view_dst: PermuteType) -> TransformFuncType:
+    """Gets transform function from view src to view dst."""
+
+    def transform(x: torch.Tensor) -> torch.Tensor:
+        x_view_0 = view_transforms.permutation_inverse_transforms[view_src](x)
+        return view_transforms.permutation_transforms[view_dst](x_view_0).contiguous()
+
+    return transform
+
+
+def permute_inverse(xs: Sequence[torch.Tensor], views: Sequence[PermuteType]) -> Sequence[torch.Tensor]:
+    """Transforms data back to the original view."""
+    return [get_permute_transform(view, 0)(x) for x, view in zip(xs, views)]
+
+
+def permute_rand(x: torch.Tensor, num_samples: int = 2) -> Tuple[Sequence[torch.Tensor], Sequence[PermuteType]]:
+    """Samples different transforms of data."""
+    num_permutes = len(view_transforms.permutation_transforms)
+    if num_samples > num_permutes:
+        raise ValueError("Duplicate samples.")
+    view_dsts = np.random.permutation(num_permutes)[:num_samples].tolist()
+    return [get_permute_transform(0, view)(x) for view in view_dsts], view_dsts
diff --git a/SwinMM/WORD/utils/view_transforms.py b/SwinMM/WORD/utils/view_transforms.py
new file mode 100644
index 00000000..752ab5bd
--- /dev/null
+++ b/SwinMM/WORD/utils/view_transforms.py
@@ -0,0 +1,69 @@
+"""View operations.
+
+Input format: [B, C, X, Y, Z, ...]
+
+NOTE(meijieru): 0 is reserved for the identity transform.
+"""
+
+import enum
+from typing import Callable, Sequence, Union
+
+import torch
+
+RotateType = int
+PermuteType = int
+TransformFuncType = Callable[[torch.Tensor], torch.Tensor]
+# A composition of multiple view transforms.
+TransformsType = Sequence[Union[PermuteType, RotateType]] + + +class GroupName(enum.Enum): + ROTATE = 1 + PERMUTE = 2 + + +DEFAULT_ORDER = (GroupName.ROTATE, GroupName.PERMUTE) + +rotation_transforms = { + 0: lambda x: x, + 1: lambda x: x.rot90(1, (3, 4)), + 2: lambda x: x.rot90(2, (3, 4)), + 3: lambda x: x.rot90(3, (3, 4)), +} +rotation_inverse_transforms = { + 0: lambda x: x, + 1: lambda x: x.rot90(3, (3, 4)), + 2: lambda x: x.rot90(2, (3, 4)), + 3: lambda x: x.rot90(1, (3, 4)), +} +permutation_transforms = {0: lambda x: x, 1: lambda x: x.permute(0, 1, 3, 2, 4), 2: lambda x: x.permute(0, 1, 4, 3, 2)} +permutation_inverse_transforms = { + 0: lambda x: x, + 1: lambda x: x.permute(0, 1, 3, 2, 4), + 2: lambda x: x.permute(0, 1, 4, 3, 2), +} + +all_forward_transforms = {GroupName.ROTATE: rotation_transforms, GroupName.PERMUTE: permutation_transforms} +all_backward_transforms = { + GroupName.ROTATE: rotation_inverse_transforms, + GroupName.PERMUTE: permutation_inverse_transforms, +} + + +def get_transforms_func( + views: TransformsType, orders: Sequence[GroupName] = DEFAULT_ORDER, inverse: bool = False +) -> TransformFuncType: + """Gets sequential transform functions.""" + if len(views) != len(orders): + raise ValueError() + + all_transforms = all_forward_transforms if not inverse else all_backward_transforms + funcs = [all_transforms[group_name][view] for view, group_name in zip(views, orders)] + funcs = funcs if not inverse else funcs[::-1] + + def aux(val): + for func in funcs: + val = func(val) + return val + + return aux diff --git a/SwinMM/figures/ACDC.png b/SwinMM/figures/ACDC.png new file mode 100644 index 00000000..af81301b Binary files /dev/null and b/SwinMM/figures/ACDC.png differ diff --git a/SwinMM/figures/Result.png b/SwinMM/figures/Result.png new file mode 100644 index 00000000..278fe405 Binary files /dev/null and b/SwinMM/figures/Result.png differ diff --git a/SwinMM/figures/SwinMMArch.png b/SwinMM/figures/SwinMMArch.png new file mode 100644 index 00000000..08487b6f Binary files /dev/null and b/SwinMM/figures/SwinMMArch.png differ diff --git a/SwinMM/figures/finetune.png b/SwinMM/figures/finetune.png new file mode 100644 index 00000000..1420a112 Binary files /dev/null and b/SwinMM/figures/finetune.png differ diff --git a/SwinMM/figures/pretrain.png b/SwinMM/figures/pretrain.png new file mode 100644 index 00000000..b2c71254 Binary files /dev/null and b/SwinMM/figures/pretrain.png differ diff --git a/SwinMM/requirements.txt b/SwinMM/requirements.txt new file mode 100644 index 00000000..047387a3 --- /dev/null +++ b/SwinMM/requirements.txt @@ -0,0 +1,33 @@ +apex==0.9.10dev +batchgenerators==0.25 +caffe2==0.8.1 +colorama==0.4.6 +Flask==2.3.2 +gevent==23.7.0 +gorilla==0.4.0 +hypothesis==6.81.2 +lz4==4.3.2 +monai==1.2.0 +nibabel==5.1.0 +numba==0.57.1 +numpy==1.25.1 +pssh==2.3.1 +ptvsd==4.3.2 +pydantic==2.0.3 +pylzma==0.5.0 +pytest==7.4.0 +redis==4.6.0 +Requests==2.31.0 +scipy==1.11.1 +setuptools==65.6.3 +setuptools_rust==1.6.0 +tensorboardX==2.6.1 +tensorflow_datasets==4.9.2 +timm==0.9.2 +torchvision==0.15.2 +tqdm==4.65.0 +transformers==4.30.2 +typing_extensions==4.7.1 +urllib3==1.26.15 +xmlrunner==1.7.7 +xxhash==3.2.0 diff --git a/SwinMM/scripts/setup_env.sh b/SwinMM/scripts/setup_env.sh new file mode 100644 index 00000000..fb4892a4 --- /dev/null +++ b/SwinMM/scripts/setup_env.sh @@ -0,0 +1,5 @@ +# Activate conda corresponding env first +conda install -c conda-forge redis redis-py +pip install timm +pip install bagua-cuda111 +pip install setuptools==59.5.0 diff --git 
a/SwinMM/scripts/start_redis.sh b/SwinMM/scripts/start_redis.sh new file mode 100644 index 00000000..9f2bd012 --- /dev/null +++ b/SwinMM/scripts/start_redis.sh @@ -0,0 +1,13 @@ +ports="39999 39998 39997 39996" +for port in ${ports}; do + echo "run redis at localhost:${port}" + redis-server \ + --daemonize yes \ + --port ${port} \ + --maxclients 100000 \ + --maxmemory 0 \ + --maxmemory-policy noeviction \ + --appendonly no \ + --save "" \ + --protected-mode no +done