From b265038e616e1b0601d85a21ff0a8047108a67de Mon Sep 17 00:00:00 2001 From: Benjamin Morris Date: Thu, 25 Apr 2024 12:04:18 -0700 Subject: [PATCH] update configs --- configs/brian_config.yaml | 370 ++++++++++++++++++ configs/data/im2im/segmentation_plugin.yaml | 11 +- .../experiment/im2im/segmentation_plugin.yaml | 6 + configs/model/im2im/segmentation_plugin.yaml | 15 +- 4 files changed, 390 insertions(+), 12 deletions(-) create mode 100644 configs/brian_config.yaml diff --git a/configs/brian_config.yaml b/configs/brian_config.yaml new file mode 100644 index 00000000..9b03adaa --- /dev/null +++ b/configs/brian_config.yaml @@ -0,0 +1,370 @@ +experiment_name: YOUR_EXP_NAME +run_name: YOUR_RUN_NAME +task_name: train +tags: +- dev +train: true +test: true +ckpt_path: null +seed: 12345 +data: + _target_: cyto_dl.datamodules.dataframe.DataframeDatamodule + path: //allen/aics/assay-dev/users/Benji/CurrentProjects/im2im_dev/cyto-dl/data/golgi_seg/plugin + cache_dir: /storage/benji.the.kid/cache/plugin_test + num_workers: 4 + batch_size: 1 + pin_memory: true + split_column: null + columns: + - ${source_col} + - ${target_col1} + - ${target_col2} + - ${merge_mask_col} + - ${exclude_mask_col} + - ${base_image_col} + transforms: + train: + _target_: monai.transforms.Compose + transforms: + - _target_: cyto_dl.datamodules.dataframe.utils.RemoveNaNKeysd + - _target_: monai.transforms.LoadImaged + keys: ${source_col} + reader: + - _target_: cyto_dl.image.io.MonaiBioReader + dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'} + C: ${input_channel} + - _target_: monai.transforms.LoadImaged + keys: ${target_col1} + reader: + - _target_: cyto_dl.image.io.MonaiBioReader + dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'} + C: 0 + - _target_: monai.transforms.LoadImaged + keys: ${target_col2} + allow_missing_keys: true + reader: + - _target_: cyto_dl.image.io.MonaiBioReader + dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'} + C: 1 + + - _target_: monai.transforms.ThresholdIntensityd + allow_missing_keys: true + keys: + - ${target_col1} + - ${target_col2} + threshold: 0.1 + above: false + cval: 1 + - _target_: cyto_dl.image.io.PolygonLoaderd + keys: + - ${merge_mask_col} + shape_reference_key: ${target_col1} + missing_key_mode: ignore + - _target_: cyto_dl.image.io.PolygonLoaderd + keys: + - ${exclude_mask_col} + shape_reference_key: ${target_col1} + missing_key_mode: create + - _target_: monai.transforms.NormalizeIntensityd + keys: ${source_col} + channel_wise: true + - _target_: cyto_dl.image.transforms.merge.Merged + mask_key: ${merge_mask_col} + image_keys: + - ${target_col1} + - ${target_col2} + base_image_key: ${base_image_col} + output_name: target + - _target_: monai.transforms.ToTensord + keys: + - ${source_col} + - target + - ${exclude_mask_col} + dtype: float16 + - _target_: cyto_dl.image.transforms.RandomMultiScaleCropd + keys: + - ${source_col} + - target + - ${exclude_mask_col} + patch_shape: ${data._aux.patch_shape} + patch_per_image: 8 + scales_dict: ${kv_to_dict:${data._aux._scales_dict}} + + test: + _target_: monai.transforms.Compose + transforms: + - _target_: cyto_dl.datamodules.dataframe.utils.RemoveNaNKeysd + - _target_: monai.transforms.LoadImaged + keys: ${source_col} + reader: + - _target_: cyto_dl.image.io.MonaiBioReader + dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'} + C: ${input_channel} + - _target_: monai.transforms.LoadImaged + keys: ${target_col1} + reader: + - _target_: cyto_dl.image.io.MonaiBioReader + dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'} + C: 0 + - _target_: monai.transforms.LoadImaged + keys: ${target_col2} + allow_missing_keys: true + reader: + - _target_: cyto_dl.image.io.MonaiBioReader + dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'} + C: 0 + - _target_: cyto_dl.image.io.PolygonLoaderd + keys: + - ${merge_mask_col} + shape_reference_key: ${target_col1} + missing_key_mode: ignore + - _target_: cyto_dl.image.io.PolygonLoaderd + keys: + - ${exclude_mask_col} + shape_reference_key: ${target_col1} + missing_key_mode: create + - _target_: monai.transforms.NormalizeIntensityd + keys: ${source_col} + channel_wise: true + - _target_: cyto_dl.image.transforms.merge.Merged + mask_key: ${merge_mask_col} + image_keys: + - ${target_col1} + - ${target_col2} + base_image_key: ${base_image_col} + output_name: target + - _target_: monai.transforms.ToTensord + keys: + - ${source_col} + - target + - ${exclude_mask_col} + dtype: float16 + predict: + _target_: monai.transforms.Compose + transforms: + - _target_: monai.transforms.LoadImaged + keys: ${source_col} + reader: + - _target_: cyto_dl.image.io.MonaiBioReader + dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'} + C: ${input_channel} + - _target_: monai.transforms.NormalizeIntensityd + keys: ${source_col} + channel_wise: true + - _target_: monai.transforms.ToTensord + keys: + - ${source_col} + valid: + _target_: monai.transforms.Compose + transforms: + - _target_: cyto_dl.datamodules.dataframe.utils.RemoveNaNKeysd + - _target_: monai.transforms.LoadImaged + keys: ${source_col} + reader: + - _target_: cyto_dl.image.io.MonaiBioReader + dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'} + C: ${input_channel} + - _target_: monai.transforms.LoadImaged + keys: ${target_col1} + reader: + - _target_: cyto_dl.image.io.MonaiBioReader + dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'} + C: 0 + - _target_: monai.transforms.LoadImaged + keys: ${target_col2} + allow_missing_keys: true + reader: + - _target_: cyto_dl.image.io.MonaiBioReader + dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'} + C: 1 + - _target_: monai.transforms.ThresholdIntensityd + allow_missing_keys: true + keys: + - ${target_col1} + - ${target_col2} + threshold: 0.1 + above: false + cval: 1 + - _target_: cyto_dl.image.io.PolygonLoaderd + keys: + - ${merge_mask_col} + shape_reference_key: ${target_col1} + missing_key_mode: ignore + - _target_: cyto_dl.image.io.PolygonLoaderd + keys: + - ${exclude_mask_col} + shape_reference_key: ${target_col1} + missing_key_mode: create + - _target_: monai.transforms.NormalizeIntensityd + keys: ${source_col} + channel_wise: true + - _target_: cyto_dl.image.transforms.merge.Merged + mask_key: ${merge_mask_col} + image_keys: + - ${target_col1} + - ${target_col2} + base_image_key: ${base_image_col} + output_name: target + - _target_: monai.transforms.ToTensord + keys: + - ${source_col} + - target + - ${exclude_mask_col} + dtype: float16 + - _target_: cyto_dl.image.transforms.RandomMultiScaleCropd + keys: + - ${source_col} + - target + - ${exclude_mask_col} + patch_shape: ${data._aux.patch_shape} + patch_per_image: 1 + scales_dict: ${kv_to_dict:${data._aux._scales_dict}} + _aux: + _scales_dict: + - - target + - - 1 + - - ${source_col} + - - 1 + - - ${exclude_mask_col} + - - 1 + patch_shape: + - 16 + - 128 + - 128 +model: + _target_: cyto_dl.models.im2im.MultiTaskIm2Im + save_images_every_n_epochs: 1 + save_dir: ${paths.output_dir} + x_key: ${source_col} + backbone: + _target_: monai.networks.nets.DynUNet + spatial_dims: ${spatial_dims} + in_channels: ${raw_im_channels} + out_channels: 1 + strides: ${model._aux.strides} + kernel_size: ${model._aux.kernel_size} + upsample_kernel_size: ${model._aux.upsample_kernel_size} + dropout: 0.0 + res_block: true + task_heads: + target: + _target_: cyto_dl.nn.head.MaskHead + mask_key: ${exclude_mask_col} + loss: + _target_: monai.losses.MaskedDiceLoss + sigmoid: true + postprocess: + input: + _target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel + rescale_dtype: numpy.uint8 + prediction: + _target_: cyto_dl.models.im2im.utils.postprocessing.AutoThreshold + method: threshold_otsu + + save_input: true + optimizer: + generator: + _partial_: true + _target_: torch.optim.Adam + lr: 0.0001 + weight_decay: 0.0001 + lr_scheduler: + generator: + _partial_: true + _target_: torch.optim.lr_scheduler.ExponentialLR + gamma: 0.995 + inference_args: + sw_batch_size: 1 + roi_size: ${data._aux.patch_shape} + overlap: 0.0 + mode: gaussian + _aux: + strides: + - 1 + - 2 + - 2 + kernel_size: + - 3 + - 3 + - 3 + upsample_kernel_size: + - 2 + - 2 + filters: + - 16 + - 32 + - 64 +callbacks: + model_checkpoint: + _target_: lightning.pytorch.callbacks.ModelCheckpoint + dirpath: ${paths.output_dir}/checkpoints + filename: epoch_{epoch:03d} + monitor: val/loss + verbose: false + save_last: true + save_top_k: 1 + mode: min + auto_insert_metric_name: false + save_weights_only: false + every_n_train_steps: null + train_time_interval: null + every_n_epochs: 1 + save_on_train_epoch_end: null + early_stopping: + _target_: lightning.pytorch.callbacks.EarlyStopping + monitor: val/loss + min_delta: 0.0 + patience: 100 + verbose: false + mode: min + strict: true + check_finite: true + stopping_threshold: null + divergence_threshold: null + check_on_train_epoch_end: null + model_summary: + _target_: lightning.pytorch.callbacks.RichModelSummary + max_depth: -1 + rich_progress_bar: + _target_: lightning.pytorch.callbacks.RichProgressBar +logger: + csv: + _target_: lightning.pytorch.loggers.csv_logs.CSVLogger + save_dir: ${paths.output_dir} + name: csv/ + prefix: '' +trainer: + _target_: lightning.Trainer + default_root_dir: ${paths.output_dir} + min_epochs: 1 + max_epochs: 100 + accelerator: gpu + devices: 1 + precision: 16 + check_val_every_n_epoch: 1 + deterministic: false + detect_anomaly: false + max_time: null +paths: + root_dir: ${oc.env:PROJECT_ROOT, './'} + data_dir: ${paths.root_dir}/data/ + log_dir: ${paths.root_dir}/logs/ + output_dir: //allen/aics/assay-dev/users/Benji/office_hours/cyto-dl/logs + work_dir: //allen/aics/assay-dev/users/Benji/office_hours/cyto-dl/logs +extras: + ignore_warnings: true + enforce_tags: false + print_config: false + precision: + _target_: torch.set_float32_matmul_precision + precision: medium +persist_cache: True +source_col: raw +target_col1: seg1 +target_col2: seg2 +merge_mask_col: merge_mask +exclude_mask_col: exclude_mask +base_image_col: base_image +spatial_dims: 3 +input_channel: 3 +raw_im_channels: 1 diff --git a/configs/data/im2im/segmentation_plugin.yaml b/configs/data/im2im/segmentation_plugin.yaml index 334bf7a8..f296e469 100644 --- a/configs/data/im2im/segmentation_plugin.yaml +++ b/configs/data/im2im/segmentation_plugin.yaml @@ -3,7 +3,7 @@ _target_: cyto_dl.datamodules.dataframe.DataframeDatamodule path: cache_dir: -num_workers: 0 +num_workers: 4 batch_size: 1 pin_memory: True split_column: @@ -40,7 +40,7 @@ transforms: reader: - _target_: cyto_dl.image.io.MonaiBioReader dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'} - C: 1 + C: 0 - _target_: monai.transforms.ThresholdIntensityd allow_missing_keys: True keys: @@ -88,7 +88,7 @@ transforms: - target - ${exclude_mask_col} patch_shape: ${data._aux.patch_shape} - patch_per_image: 1 + patch_per_image: ${data._aux.patch_per_image} scales_dict: ${kv_to_dict:${data._aux._scales_dict}} test: @@ -190,7 +190,7 @@ transforms: reader: - _target_: cyto_dl.image.io.MonaiBioReader dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'} - C: 1 + C: 0 - _target_: monai.transforms.ThresholdIntensityd allow_missing_keys: True @@ -241,10 +241,11 @@ transforms: - target - ${exclude_mask_col} patch_shape: ${data._aux.patch_shape} - patch_per_image: 1 + patch_per_image: ${data._aux.patch_per_image} scales_dict: ${kv_to_dict:${data._aux._scales_dict}} _aux: + patch_per_image: 1 _scales_dict: - - target - [1] diff --git a/configs/experiment/im2im/segmentation_plugin.yaml b/configs/experiment/im2im/segmentation_plugin.yaml index ac8a9a2a..b767581f 100644 --- a/configs/experiment/im2im/segmentation_plugin.yaml +++ b/configs/experiment/im2im/segmentation_plugin.yaml @@ -52,3 +52,9 @@ data: paths: output_dir: MUST_OVERRIDE work_dir: ${paths.output_dir} # it's unclear to me if this is necessary or used + + +model: + _aux: + filters: MUST_OVERRIDE + overlap: 0 diff --git a/configs/model/im2im/segmentation_plugin.yaml b/configs/model/im2im/segmentation_plugin.yaml index a58143bc..8557fd01 100644 --- a/configs/model/im2im/segmentation_plugin.yaml +++ b/configs/model/im2im/segmentation_plugin.yaml @@ -10,11 +10,12 @@ backbone: spatial_dims: ${spatial_dims} in_channels: ${raw_im_channels} out_channels: 1 - strides: ${model._aux.strides} - kernel_size: ${model._aux.kernel_size} - upsample_kernel_size: ${model._aux.upsample_kernel_size} + strides: [1, 2, 2] + kernel_size: [3, 3, 3] + upsample_kernel_size: [2, 2] dropout: 0.0 res_block: True + filters: ${model._aux.filters} task_heads: target: @@ -30,6 +31,7 @@ task_heads: prediction: _target_: cyto_dl.models.im2im.utils.postprocessing.AutoThreshold method: "threshold_otsu" + save_input: True optimizer: @@ -48,10 +50,9 @@ lr_scheduler: inference_args: sw_batch_size: 1 roi_size: ${data._aux.patch_shape} - overlap: 0.0 + overlap: ${model._aux.overlap} mode: "gaussian" _aux: - strides: - kernel_size: - upsample_kernel_size: + filters: + overlap: