From 66a9d8d505ba2fbd0c7497a2ad71b03c45822cc4 Mon Sep 17 00:00:00 2001 From: BIGWangYuDong Date: Thu, 26 Oct 2023 00:38:48 +0800 Subject: [PATCH 1/5] [Feature] Support TIENet --- .../retinanet_r50_fpn_1x_urpc-coco.py | 91 ++++ .../base_editor/tienet_enhance_model.py | 38 ++ .../tienet_retinanet_r50_fpn_1x_urpc-coco.py | 35 ++ lqit/common/utils/__init__.py | 8 +- lqit/common/utils/lark_manager.py | 27 +- lqit/detection/models/detectors/__init__.py | 3 +- .../detectors/detector_with_enhance_model.py | 387 +++++++++++++++++ lqit/detection/utils/__init__.py | 3 + lqit/detection/utils/merge_det_results.py | 108 +++++ lqit/edit/models/editors/__init__.py | 1 + lqit/edit/models/editors/tienet/__init__.py | 4 + lqit/edit/models/editors/tienet/tienet.py | 191 ++++++++ .../models/editors/tienet/tienet_generator.py | 108 +++++ lqit/edit/models/losses/__init__.py | 4 +- lqit/edit/models/losses/structure_fft_loss.py | 409 ++++++++++++++++++ tools/test.py | 3 + tools/train.py | 7 +- 17 files changed, 1411 insertions(+), 16 deletions(-) create mode 100644 configs/detection/tienet/base_detector/retinanet_r50_fpn_1x_urpc-coco.py create mode 100644 configs/detection/tienet/base_editor/tienet_enhance_model.py create mode 100644 configs/detection/tienet/tienet_retinanet_r50_fpn_1x_urpc-coco.py create mode 100644 lqit/detection/models/detectors/detector_with_enhance_model.py create mode 100644 lqit/detection/utils/__init__.py create mode 100644 lqit/detection/utils/merge_det_results.py create mode 100644 lqit/edit/models/editors/tienet/__init__.py create mode 100644 lqit/edit/models/editors/tienet/tienet.py create mode 100644 lqit/edit/models/editors/tienet/tienet_generator.py create mode 100644 lqit/edit/models/losses/structure_fft_loss.py diff --git a/configs/detection/tienet/base_detector/retinanet_r50_fpn_1x_urpc-coco.py b/configs/detection/tienet/base_detector/retinanet_r50_fpn_1x_urpc-coco.py new file mode 100644 index 0000000..bfc825d --- /dev/null +++ b/configs/detection/tienet/base_detector/retinanet_r50_fpn_1x_urpc-coco.py @@ -0,0 +1,91 @@ +_base_ = [ + '../../_base_/datasets/urpc2020/urpc2020-validation_coco_detection.py', + '../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py' +] + +# model settings +model = dict( + type='RetinaNet', + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32), + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_input', + num_outs=5), + bbox_head=dict( + type='RetinaHead', + num_classes=4, + in_channels=256, + stacked_convs=4, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + octave_base_scale=4, + scales_per_octave=3, + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + # model training and testing settings + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.4, + 
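+            # min_pos_iou=0: each GT box is still matched to its best-overlapping anchor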
min_pos_iou=0, + ignore_iof_thr=-1), + sampler=dict( + type='PseudoSampler'), # Focal loss should use PseudoSampler + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100)) + +# optimizer +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001), + clip_grad=dict(max_norm=35, norm_type=2)) # loss may NaN without clip_grad + +param_scheduler = [ + dict( + type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, + end=1000), + dict( + type='MultiStepLR', + begin=0, + end=12, + by_epoch=True, + milestones=[8, 11], + gamma=0.1) +] diff --git a/configs/detection/tienet/base_editor/tienet_enhance_model.py b/configs/detection/tienet/base_editor/tienet_enhance_model.py new file mode 100644 index 0000000..70cdd39 --- /dev/null +++ b/configs/detection/tienet/base_editor/tienet_enhance_model.py @@ -0,0 +1,38 @@ +enhance_model = dict( + _scope_='lqit', + type='BaseEditModel', + destruct_gt=True, + data_preprocessor=dict( + type='EditDataPreprocessor', + mean=[0.0, 0.0, 0.0], + std=[255.0, 255.0, 255.0], + bgr_to_rgb=False, + gt_name='img'), + generator=dict( + type='TIENetGenerator', + model=dict( + type='TIENetEnhanceModel', + in_channels=3, + feat_channels=64, + out_channels=3, + num_blocks=3, + expand_ratio=0.5, + kernel_size=[1, 3, 5], + output_weight=[1.0, 1.0], + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='SiLU'), + use_depthwise=True), + spacial_pred='structure', + structure_pred='structure', + spacial_loss=dict(type='SpatialLoss', loss_weight=1.0), + tv_loss=dict(type='MaskedTVLoss', loss_mode='mse', loss_weight=10.0), + structure_loss=dict( + type='StructureFFTLoss', + radius=4, + pass_type='high', + channel_mean=True, + loss_type='mse', + guid_filter=dict( + type='GuidedFilter2d', radius=32, eps=1e-4, fast_s=2), + loss_weight=0.1))) diff --git a/configs/detection/tienet/tienet_retinanet_r50_fpn_1x_urpc-coco.py b/configs/detection/tienet/tienet_retinanet_r50_fpn_1x_urpc-coco.py new file mode 100644 index 0000000..ee2f3ad --- /dev/null +++ b/configs/detection/tienet/tienet_retinanet_r50_fpn_1x_urpc-coco.py @@ -0,0 +1,35 @@ +# default scope is mmdet +_base_ = [ + './base_editor/tienet_enhance_model.py', + './base_detector/retinanet_r50_fpn_1x_urpc-coco.py' +] + +model = dict( + _delete_=True, + type='lqit.DetectorWithEnhanceModel', + detector={{_base_.model}}, + enhance_model={{_base_.enhance_model}}, + train_mode='enhance', + pred_mode='enhance', + detach_enhance_img=False) + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001), + clip_grad=dict(max_norm=35, norm_type=2)) + +# dataset settings +train_pipeline = [ + dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='Resize', scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', prob=0.5), + dict(type='lqit.SetInputImageAsGT'), + dict(type='lqit.PackInputs') +] +train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) + +model_wrapper_cfg = dict( + type='lqit.SelfEnhanceModelDDP', + broadcast_buffers=False, + find_unused_parameters=False) diff --git a/lqit/common/utils/__init__.py b/lqit/common/utils/__init__.py index d30586f..64866b4 100644 --- a/lqit/common/utils/__init__.py +++ b/lqit/common/utils/__init__.py @@ -1,8 +1,10 @@ from .lark_manager import (MonitorManager, MonitorTracker, - 
context_monitor_manager, get_user_name,
-                           initialize_monitor_manager, send_alert_message)
+                           context_monitor_manager, get_error_message,
+                           get_user_name, initialize_monitor_manager,
+                           send_alert_message)
 
 __all__ = [
     'send_alert_message', 'get_user_name', 'initialize_monitor_manager',
-    'context_monitor_manager', 'MonitorTracker', 'MonitorManager'
+    'context_monitor_manager', 'MonitorTracker', 'MonitorManager',
+    'get_error_message'
 ]
diff --git a/lqit/common/utils/lark_manager.py b/lqit/common/utils/lark_manager.py
index bc57499..d9b2241 100644
--- a/lqit/common/utils/lark_manager.py
+++ b/lqit/common/utils/lark_manager.py
@@ -203,16 +203,7 @@ def monitor_exception(self) -> None:
         assert self.url is not None, \
             'Please run `MonitorManager.start_monitor` first.'
 
-        filtered_trace = traceback.format_exc().split('\n')[-15:]
-        format_trace = ''
-        for line in filtered_trace:
-            format_trace += '\n' + line
-
-        # try to add error message into logger else directly print message
-        try:
-            print_log(format_trace, logger='current')
-        except Exception:
-            print(format_trace)
+        format_trace = get_error_message()
         title = 'Task Error Report'
         content = f"{self.user_name}'s {self.task_type} task\n" \
                   f'Config file: {self.cfg_file}\n' \
@@ -346,3 +337,19 @@ def context_monitor_manager(monitor_manager: Optional[MonitorManager] = None):
             monitor_manager.stop_monitor()
     else:
         yield
+
+
+def get_error_message() -> str:
+    """Format the current exception traceback and log or print it."""
+
+    filtered_trace = traceback.format_exc().split('\n')[-15:]
+    format_trace = ''
+    for line in filtered_trace:
+        format_trace += '\n' + line
+
+    # try to add error message into logger else directly print message
+    try:
+        print_log(format_trace, logger='current')
+    except Exception:
+        print(format_trace)
+    return format_trace
diff --git a/lqit/detection/models/detectors/__init__.py b/lqit/detection/models/detectors/__init__.py
index 21c50c9..9e25cb0 100644
--- a/lqit/detection/models/detectors/__init__.py
+++ b/lqit/detection/models/detectors/__init__.py
@@ -1,3 +1,4 @@
+from .detector_with_enhance_model import DetectorWithEnhanceModel
 from .edffnet import EDFFNet
 from .multi_input_wrapper import MultiInputDetectorWrapper
 from .single_stage_enhance_head import SingleStageDetector
@@ -5,5 +6,5 @@
 
 __all__ = [
     'TwoStageWithEnhanceHead', 'MultiInputDetectorWrapper',
-    'SingleStageDetector', 'EDFFNet'
+    'SingleStageDetector', 'EDFFNet', 'DetectorWithEnhanceModel'
 ]
diff --git a/lqit/detection/models/detectors/detector_with_enhance_model.py b/lqit/detection/models/detectors/detector_with_enhance_model.py
new file mode 100644
index 0000000..892668f
--- /dev/null
+++ b/lqit/detection/models/detectors/detector_with_enhance_model.py
@@ -0,0 +1,387 @@
+import copy
+from typing import Any, Dict, Optional, Tuple, Union
+
+from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
+from mmengine.model import BaseModel
+from mmengine.model.wrappers import MMDistributedDataParallel as MMDDP
+from mmengine.utils import is_list_of
+from torch import Tensor
+
+from lqit.common.structures import SampleList
+from lqit.detection.utils import merge_det_results
+from lqit.edit.models.post_processor import add_pixel_pred_to_datasample
+from lqit.registry import MODEL_WRAPPERS, MODELS
+
+ForwardResults = Union[Dict[str, Tensor], SampleList, Tuple[Tensor], Tensor]
+
+
+@MODELS.register_module()
+class DetectorWithEnhanceModel(BaseModel):
+    """Detector with enhance model.
+ + The `DetectorWithEnhanceModel` usually combines a detector and an enhance + model. It has three train mode: `raw`, `enhance` and + `both`. The `raw` mode only train the detector with raw image. The + `enhance` mode only train the detector with enhance image. The `both` mode + train the detector with both raw and enhance image. + + Args: + detector (dict or ConfigDict): Config for detector. + enhance_model (dict or ConfigDict, optional): Config for enhance model. + loss_weight (list): Detection loss weight for raw and enhanced image. + Only used when `train_mode` is `both`. + vis_enhance (bool): Whether visualize enhance image during inference. + Defaults to False. + train_mode (str): Train mode of detector, support `raw`, `enhance` and + `both`. Defaults to `enhance`. + pred_mode (str): Predict mode of detector, support `raw`, `enhance`, + and `both`. Defaults to `enhance`. + detach_enhance_img (bool): Whether stop the gradient of enhance image. + Defaults to False. + merge_cfg (dict or ConfigDict, optional): The config to control the + merge process of raw and enhance image. Defaults to None. + init_cfg (dict or ConfigDict, optional): The config to control the + initialization. Defaults to None. + """ + + def __init__(self, + detector: ConfigType, + enhance_model: OptConfigType = None, + loss_weight: list = [0.5, 0.5], + vis_enhance: Optional[bool] = False, + train_mode: str = 'enhance', + pred_mode: str = 'enhance', + detach_enhance_img: bool = False, + merge_cfg: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__(init_cfg=init_cfg) + # if train_mode is `both`, the loss_weight should be a list of two + # elements, which means the loss weight of raw and enhance image. + if enhance_model is not None and train_mode == 'both': + assert isinstance(loss_weight, list) and len(loss_weight) == 2 + + assert pred_mode in ['raw', 'enhance', 'both'] + assert train_mode in ['raw', 'enhance', 'both'] + self.train_mode = train_mode + self.pred_mode = pred_mode + + # build detector + self.detector = MODELS.build(detector) + # build enhance model + if enhance_model is not None: + self.enhance_model = MODELS.build(enhance_model) + else: + self.enhance_model = None + + self.detach_enhance_img = detach_enhance_img + if vis_enhance: + assert self.with_enhance_model + self.vis_enhance = vis_enhance + + self.merge_cfg = merge_cfg + if train_mode == 'both': + # if train_mode is `both`, should have enhance_model. + assert self.with_enhance_model + assert merge_cfg is not None + # The loss_weight should be a list of two elements, which means + # the loss weight of raw and enhance image. + assert isinstance(loss_weight, list) and len(loss_weight) == 2 + self.prefix_name = ['raw', 'enhance'] + self.loss_weight = loss_weight + elif train_mode == 'enhance': + # if train_mode is `enhance`, should have enhance_model. + assert self.with_enhance_model + self.prefix_name = ['enhance'] + self.loss_weight = [1.0] + else: + self.prefix_name = ['raw'] + self.loss_weight = [1.0] + + @property + def with_enhance_model(self) -> bool: + """bool: whether the detector has a Enhance Model""" + return (hasattr(self, 'enhance_model') + and self.enhance_model is not None) + + def forward(self, + data: dict, + mode: str = 'tensor', + **kwargs) -> ForwardResults: + assert isinstance(data, dict) + if mode == 'loss': + return self.loss(data) + elif mode == 'predict': + return self.predict(data) + elif mode == 'tensor': + return self._forward(data) + else: + raise RuntimeError(f'Invalid mode "{mode}". 
' + 'Only supports loss, predict and tensor mode') + + def _run_forward(self, data: dict, + mode: str) -> Union[Dict[str, Tensor], list]: + """Unpacks data for :meth:`forward` + + Args: + data (dict or tuple or list): Data sampled from dataset. + mode (str): Mode of forward. + + Returns: + dict or list: Results of training or testing mode. + """ + assert isinstance(data, dict), \ + 'The output of DataPreprocessor should be a dict, ' \ + 'which only deal with `cast_data`. The data_preprocessor ' \ + 'should process in forward.' + results = self(data, mode=mode) + + return results + + def _preprocess_data(self, data: Union[dict, list, tuple]) -> tuple: + """Preprocess data to a tuple of (batch_inputs, batch_data_samples).""" + if isinstance(data, dict): + batch_inputs = data['inputs'] + batch_data_samples = data['data_samples'] + elif isinstance(data, (list, tuple)): + batch_inputs = data[0] + batch_data_samples = data[1] + else: + raise TypeError('Output of `data_preprocessor` should be ' + 'list, tuple or dict, but got ' + f'{type(data)}') + return batch_inputs, batch_data_samples + + def _forward(self, data: dict) -> tuple: + """Network forward process. Usually includes backbone, neck and head + forward without any post-processing. + + Args: + data (dict): Data sampled from dataloader, usually contains + following keys. + + - inputs (list[Tensor]): A list of input image. + - data_samples (:obj:`DataSample`): A list of DataSample. + + Returns: + tuple: A tuple of features from ``rpn_head`` and ``roi_head`` + forward. + """ + results = () + + raw_det_data = self.detector.data_preprocessor(data, True) + raw_batch_inputs, batch_det_data_samples = self._preprocess_data( + raw_det_data) + if self.pred_mode in ['enhance', 'both']: + enhance_raw_data = self.enhance_model.data_preprocessor(data, True) + + # the enhance model should have `loss_and_predict` mode + if isinstance(enhance_raw_data, dict): + enhance_results = self.enhance_model( + **enhance_raw_data, mode='predict') + elif isinstance(enhance_raw_data, (list, tuple)): + enhance_results = self.enhance_model( + *enhance_raw_data, mode='predict') + else: + raise TypeError('Output of `data_preprocessor` should be ' + 'list, tuple or dict, but got ' + f'{type(enhance_raw_data)}') + enhance_img_list = [ + result.pred_pixel.pred_img for result in enhance_results + ] + # get enhance_batch_inputs of detector + enhance_data = {'inputs': enhance_img_list} + enhance_det_data = self.detector.data_preprocessor( + enhance_data, True) + enhance_batch_inputs, _ = self._preprocess_data(enhance_det_data) + results = results + (enhance_batch_inputs, ) + + if self.pred_mode == 'raw': + raw_results = self.detector( + raw_batch_inputs, batch_det_data_samples, mode='tensor') + results = results + (raw_results, ) + elif self.pred_mode == 'enhance': + enhance_results = self.detector( + enhance_batch_inputs, batch_det_data_samples, mode='tensor') + results = results + (enhance_results, ) + else: + raw_results = self.detector( + raw_batch_inputs, batch_det_data_samples, mode='tensor') + results = results + (raw_results, ) + enhance_results = self.detector( + enhance_batch_inputs, batch_det_data_samples, mode='tensor') + results = results + (enhance_results, ) + return results + + def loss(self, data: dict) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + data (dict): Data sampled from dataloader, usually contains + following keys. + + - inputs (list[Tensor]): A list of input image. 
+ - data_samples (:obj:`DataSample`): A list of DataSample. + + Returns: + dict: A dictionary of loss components + """ + losses = dict() + + # get batch_inputs and batch_data_samples of detector + raw_det_data = self.detector.data_preprocessor(data, True) + raw_batch_inputs, batch_det_data_samples = self._preprocess_data( + raw_det_data) + + if self.train_mode in ['enhance', 'both']: + # get batch_inputs and batch_data_samples of enhance_model + enhance_raw_data = self.enhance_model.data_preprocessor(data, True) + + # the enhance model should have `loss_and_predict` mode + if isinstance(enhance_raw_data, dict): + results = self.enhance_model( + **enhance_raw_data, mode='loss_and_predict') + elif isinstance(enhance_raw_data, (list, tuple)): + results = self.enhance_model( + *enhance_raw_data, mode='loss_and_predict') + else: + raise TypeError('Output of `data_preprocessor` should be ' + 'list, tuple or dict, but got ' + f'{type(enhance_raw_data)}') + # results should have `enhance_loss` and `enhance_results` + enhance_loss, enhance_results = results + losses.update(enhance_loss) + + enhance_img_list = [ + result.pred_pixel.pred_img for result in enhance_results + ] + # get enhance_batch_inputs of detector + enhance_data = {'inputs': enhance_img_list} + enhance_det_data = self.detector.data_preprocessor( + enhance_data, True) + + enhance_batch_inputs, _ = self._preprocess_data(enhance_det_data) + if self.detach_enhance_img: + # if self.detach_enhance_img is True, stop the gradient of + # enhance image. + enhance_batch_inputs = enhance_batch_inputs.detach() + + if self.train_mode == 'both': + batch_inputs_list = [raw_batch_inputs, enhance_batch_inputs] + elif self.train_mode == 'raw': + batch_inputs_list = [raw_batch_inputs] + else: + batch_inputs_list = [enhance_batch_inputs] + + for i, batch_inputs in enumerate(batch_inputs_list): + temp_losses = self.detector( + batch_inputs, batch_det_data_samples, mode='loss') + + for name, value in temp_losses.items(): + if 'loss' in name: + if isinstance(value, Tensor): + value = value * self.loss_weight[i] + elif is_list_of(value, Tensor): + value = [_v * self.loss_weight[i] for _v in value] + losses[f'{self.prefix_name[i]}_{name}'] = value + return losses + + def predict(self, data: dict) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + data (dict): Data sampled from dataloader, usually contains + following keys. + + - inputs (list[Tensor]): A list of input image. + - data_samples (:obj:`DataSample`): A list of DataSample. + + Returns: + list[:obj:`DataSample`]: Return the detection results of the + input images. The returns value is DataSample, + which usually contain 'pred_instances'. And the + ``pred_instances`` usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). 
+ """ + # get batch_inputs and batch_data_samples of detector + raw_det_data = self.detector.data_preprocessor(data, True) + raw_batch_inputs, batch_data_samples = self._preprocess_data( + raw_det_data) + + if self.pred_mode in ['enhance', 'both']: + enhance_raw_data = self.enhance_model.data_preprocessor(data, True) + # get enhance image + if isinstance(enhance_raw_data, dict): + results = self.enhance_model( + **enhance_raw_data, mode='predict') + elif isinstance(enhance_raw_data, (list, tuple)): + results = self.enhance_model(*enhance_raw_data, mode='predict') + else: + raise TypeError('Output of `data_preprocessor` should be ' + 'list, tuple or dict, but got ' + f'{type(enhance_raw_data)}') + enhance_img_list = [ + result.pred_pixel.pred_img for result in results + ] + # add into batch_data_samples + batch_data_samples = add_pixel_pred_to_datasample( + data_samples=batch_data_samples, pixel_list=enhance_img_list) + enhance_data = {'inputs': enhance_img_list} + enhance_det_data = self.detector.data_preprocessor( + enhance_data, True) + enhance_batch_inputs, _ = self._preprocess_data(enhance_det_data) + + if self.pred_mode == 'raw': + batch_data_samples = self.detector( + raw_batch_inputs, batch_data_samples, mode='predict') + elif self.pred_mode == 'enhance': + batch_data_samples = self.detector( + enhance_batch_inputs, batch_data_samples, mode='predict') + else: + raw_batch_data_samples = copy.deepcopy(batch_data_samples) + raw_batch_data_samples = self.detector( + raw_batch_inputs, raw_batch_data_samples, mode='predict') + + enhance_batch_data_samples = copy.deepcopy(batch_data_samples) + enhance_batch_data_samples = self.detector( + enhance_batch_inputs, + enhance_batch_data_samples, + mode='predict') + + batch_data_samples = [] + for raw_data_sample, enhance_data_sample in zip( + raw_batch_data_samples, enhance_batch_data_samples): + batch_data_samples.append( + [raw_data_sample, enhance_data_sample]) + batch_data_samples = merge_det_results(batch_data_samples, + self.merge_cfg) + return batch_data_samples + + +@MODEL_WRAPPERS.register_module() +class SelfEnhanceModelDDP(MMDDP): + + def _run_forward(self, data: Union[dict, tuple, list], mode: str) -> Any: + """Unpacks data for :meth:`forward` + + Args: + data (dict or tuple or list): Data sampled from dataset. + mode (str): Mode of forward. + + Returns: + dict or list: Results of training or testing mode. + """ + assert isinstance(data, dict), \ + 'The output of DataPreprocessor should be a dict, ' \ + 'which only deal with `cast_data`. The data_preprocessor ' \ + 'should process in forward.' 
+        results = self(data, mode=mode)
+
+        return results
diff --git a/lqit/detection/utils/__init__.py b/lqit/detection/utils/__init__.py
new file mode 100644
index 0000000..4dcb802
--- /dev/null
+++ b/lqit/detection/utils/__init__.py
@@ -0,0 +1,3 @@
+from .merge_det_results import merge_det_results
+
+__all__ = ['merge_det_results']
diff --git a/lqit/detection/utils/merge_det_results.py b/lqit/detection/utils/merge_det_results.py
new file mode 100644
index 0000000..c917e20
--- /dev/null
+++ b/lqit/detection/utils/merge_det_results.py
@@ -0,0 +1,108 @@
+# Modified from https://github.com/open-mmlab/mmdetection/blob/main/mmdet/models/test_time_augs/det_tta.py # noqa
+from typing import List, Tuple
+
+import torch
+from mmcv.ops import batched_nms
+from mmdet.structures.bbox import bbox_flip
+from mmengine.structures import InstanceData
+from torch import Tensor
+
+from lqit.common.structures import DataSample
+
+
+def merge_det_results(data_samples_list: List[List[DataSample]],
+                      merge_cfg: dict) -> List[DataSample]:
+    """Merge batch predictions of enhanced data.
+
+    Args:
+        data_samples_list (List[List[DataSample]]): List of predictions
+            of all enhanced data. The outer list indicates images, and the
+            inner list corresponds to the different views of one image.
+            Each element of the inner list is a ``DataSample``.
+        merge_cfg (dict): Config of merge method.
+
+    Returns:
+        List[DataSample]: Merged batch prediction.
+    """
+    merged_data_samples = []
+    for data_samples in data_samples_list:
+        merged_data_samples.append(_merge_single_sample(data_samples, merge_cfg))
+    return merged_data_samples
+
+
+def _merge_single_sample(data_samples: List[DataSample],
+                         merge_cfg: dict) -> DataSample:
+    """Merge predictions which come from the different views of one image
+    into one prediction.
+
+    Args:
+        data_samples (List[DataSample]): List of predictions
+            of enhanced data which come from one image.
+        merge_cfg (dict): Config of merge method.
+
+    Returns:
+        DataSample: Merged prediction.
+    """
+    aug_bboxes = []
+    aug_scores = []
+    aug_labels = []
+    img_metas = []
+    # TODO: support instance segmentation TTA
+    assert data_samples[0].pred_instances.get('masks', None) is None, \
+        'TTA for instance segmentation is not supported yet.'
+    for data_sample in data_samples:
+        aug_bboxes.append(data_sample.pred_instances.bboxes)
+        aug_scores.append(data_sample.pred_instances.scores)
+        aug_labels.append(data_sample.pred_instances.labels)
+        img_metas.append(data_sample.metainfo)
+
+    merged_bboxes, merged_scores = merge_aug_bboxes(aug_bboxes, aug_scores,
+                                                    img_metas)
+    merged_labels = torch.cat(aug_labels, dim=0)
+
+    if merged_bboxes.numel() == 0:
+        return data_samples[0]
+
+    det_bboxes, keep_idxs = batched_nms(merged_bboxes, merged_scores,
+                                        merged_labels, merge_cfg.nms)
+
+    det_bboxes = det_bboxes[:merge_cfg.max_per_img]
+    det_labels = merged_labels[keep_idxs][:merge_cfg.max_per_img]
+
+    results = InstanceData()
+    _det_bboxes = det_bboxes.clone()
+    results.bboxes = _det_bboxes[:, :-1]
+    results.scores = _det_bboxes[:, -1]
+    results.labels = det_labels
+    det_results = data_samples[0]
+    det_results.pred_instances = results
+    return det_results
+
+
+def merge_aug_bboxes(aug_bboxes: List[Tensor], aug_scores: List[Tensor],
+                     img_metas: List[str]) -> Tuple[Tensor, Tensor]:
+    """Merge augmented detection bboxes and scores.
+ + Args: + aug_bboxes (list[Tensor]): shape (n, 4*#class) + aug_scores (list[Tensor] or None): shape (n, #class) + Returns: + tuple[Tensor]: ``bboxes`` with shape (n,4), where + 4 represent (tl_x, tl_y, br_x, br_y) + and ``scores`` with shape (n,). + """ + recovered_bboxes = [] + for bboxes, img_info in zip(aug_bboxes, img_metas): + ori_shape = img_info['ori_shape'] + flip = img_info['flip'] + flip_direction = img_info['flip_direction'] + if flip: + bboxes = bbox_flip( + bboxes=bboxes, img_shape=ori_shape, direction=flip_direction) + recovered_bboxes.append(bboxes) + bboxes = torch.cat(recovered_bboxes, dim=0) + if aug_scores is None: + return bboxes + else: + scores = torch.cat(aug_scores, dim=0) + return bboxes, scores diff --git a/lqit/edit/models/editors/__init__.py b/lqit/edit/models/editors/__init__.py index dd58005..95dff08 100644 --- a/lqit/edit/models/editors/__init__.py +++ b/lqit/edit/models/editors/__init__.py @@ -1,2 +1,3 @@ +from .tienet import * # noqa: F401,F403 from .unet import * # noqa: F401,F403 from .zero_dce import * # noqa: F401,F403 diff --git a/lqit/edit/models/editors/tienet/__init__.py b/lqit/edit/models/editors/tienet/__init__.py new file mode 100644 index 0000000..7b954bd --- /dev/null +++ b/lqit/edit/models/editors/tienet/__init__.py @@ -0,0 +1,4 @@ +from .tienet import TIENetEnhanceModel +from .tienet_generator import TIENetGenerator + +__all__ = ['TIENetEnhanceModel', 'TIENetGenerator'] diff --git a/lqit/edit/models/editors/tienet/tienet.py b/lqit/edit/models/editors/tienet/tienet.py new file mode 100644 index 0000000..4d1f910 --- /dev/null +++ b/lqit/edit/models/editors/tienet/tienet.py @@ -0,0 +1,191 @@ +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule +from mmengine.model import BaseModule +from torch import Tensor + +from lqit.registry import MODELS +from lqit.utils import ConfigType, OptConfigType + + +@MODELS.register_module() +class TIENetEnhanceModel(BaseModule): + """The Enhance Model of TIENet.""" + + def __init__(self, + in_channels=3, + feat_channels=64, + out_channels=3, + num_blocks=4, + expand_ratio=1.0, + kernel_size=[1, 3, 5, 7], + conv_cfg: OptConfigType = None, + norm_cfg: ConfigType = dict(type='BN'), + act_cfg: ConfigType = dict(type='SiLU'), + use_depthwise: bool = True, + output_weight: list = [0.8, 0.2], + init_cfg=[ + dict(type='Normal', layer='Conv2d', mean=0, std=0.02), + dict( + type='Normal', + layer='BatchNorm2d', + mean=1.0, + std=0.02, + bias=0), + ]): + super().__init__(init_cfg=init_cfg) + assert len(kernel_size) == num_blocks + self.in_channels = in_channels + self.stem = ConvModule( + in_channels, + feat_channels, + 3, + padding=1, + stride=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + layers = [] + for i in range(num_blocks): + if i == num_blocks - 1: + _out_channels = out_channels + else: + _out_channels = feat_channels + layer = SelfEnhanceLayer( + in_channels=feat_channels, + out_channels=_out_channels, + expand_ratio=expand_ratio, + use_depthwise=use_depthwise, + kernel_size=kernel_size[i], + norm_cfg=norm_cfg, + act_cfg=act_cfg, + conv_cfg=conv_cfg) + layers.append(layer) + self.layers = nn.Sequential(*layers) + + assert len(output_weight) == 2 + self.raw_weight = output_weight[0] + self.structure_weight = output_weight[1] + + def forward(self, x: Tensor) -> Tensor: + x_stem = self.stem(x) + + out = self.layers(x_stem) # enhanced img structure + out_img = self.raw_weight * out + \ + self.structure_weight * x # enhance img + cat_tensor = torch.cat([out_img, out], 
dim=1) + return cat_tensor + + +class SelfEnhanceLayer(BaseModule): + + def __init__(self, + in_channels: int, + out_channels: int, + expand_ratio: float = 0.5, + use_depthwise: bool = False, + kernel_size: int = 3, + norm_cfg: ConfigType = dict(type='BN'), + act_cfg: ConfigType = dict(type='SiLU'), + conv_cfg: OptConfigType = None, + init_cfg=[ + dict(type='Normal', layer='Conv2d', mean=0, std=0.02), + dict( + type='Normal', + layer='BatchNorm2d', + mean=1.0, + std=0.02, + bias=0), + ]): + conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule + super().__init__(init_cfg=init_cfg) + mid_channels = int(out_channels * expand_ratio) + + self.main_conv = ConvModule( + in_channels, + mid_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.short_conv = conv( + in_channels, + mid_channels, + kernel_size, + stride=1, + padding=kernel_size // 2, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.final_conv = ConvModule( + 2 * mid_channels, + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.block = SelfEnhanceBasicBlock( + in_channels=mid_channels, + out_channels=mid_channels, + expansion=1.0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, x: Tensor) -> Tensor: + x_short = self.short_conv(x) + + x_main = self.main_conv(x) + x_main = self.block(x_main) + + x_final = torch.cat((x_main, x_short), dim=1) + + out = self.final_conv(x_final) + return out + + +class SelfEnhanceBasicBlock(BaseModule): + + def __init__(self, + in_channels: int, + out_channels: int, + expansion: float = 1.0, + conv_cfg: OptConfigType = None, + norm_cfg: ConfigType = dict(type='BN'), + act_cfg: ConfigType = dict(type='SiLU'), + init_cfg=[ + dict(type='Normal', layer='Conv2d', mean=0, std=0.02), + dict( + type='Normal', + layer='BatchNorm2d', + mean=1.0, + std=0.02, + bias=0), + ]): + super().__init__(init_cfg=init_cfg) + mid_channels = int(out_channels * expansion) + self.conv1 = ConvModule( + in_channels, + mid_channels, + 3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.conv2 = ConvModule( + in_channels, + out_channels, + 3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, x: Tensor) -> Tensor: + identity = x + out = self.conv1(x) + out = self.conv2(out) + out = out + identity + + return out diff --git a/lqit/edit/models/editors/tienet/tienet_generator.py b/lqit/edit/models/editors/tienet/tienet_generator.py new file mode 100644 index 0000000..81f5971 --- /dev/null +++ b/lqit/edit/models/editors/tienet/tienet_generator.py @@ -0,0 +1,108 @@ +from typing import List + +from lqit.edit.models.base_models import BaseGenerator +from lqit.edit.structures import BatchPixelData +from lqit.registry import MODELS +from lqit.utils import ConfigType, OptConfigType, OptMultiConfig + + +@MODELS.register_module() +class TIENetGenerator(BaseGenerator): + """The generator of TIENet.""" + + def __init__(self, + model: ConfigType, + spacial_loss: ConfigType = dict( + type='SpatialLoss', loss_weight=1.0), + tv_loss: ConfigType = dict( + type='MaskedTVLoss', loss_mode='mse', loss_weight=10.0), + structure_loss: OptConfigType = None, + spacial_pred: str = 'structure', + structure_pred: str = 'structure', + perceptual_loss: OptConfigType = None, + init_cfg: OptMultiConfig = None, + **kwargs) -> None: + super().__init__( + model=model, perceptual_loss=perceptual_loss, init_cfg=init_cfg) + + assert structure_pred 
in ['output', 'structure'] + self.structure_pred = structure_pred + + assert spacial_pred in ['output', 'structure'] + self.spacial_pred = spacial_pred + # build losses + if spacial_loss is not None: + self.spacial_loss = MODELS.build(spacial_loss) + else: + self.spacial_loss = None + + if tv_loss is not None: + self.tv_loss = MODELS.build(tv_loss) + else: + self.tv_loss = None + + if structure_loss is not None: + self.structure_loss = MODELS.build(structure_loss) + else: + self.structure_loss = None + + def loss(self, loss_input: BatchPixelData, + batch_img_metas: List[dict]) -> dict: + """Calculate the loss based on the outputs of generator.""" + losses = dict() + + batch_outputs = loss_input.output + batch_inputs = loss_input.input + + in_channels = self.model.in_channels + batch_enhance_img = batch_outputs[:, :in_channels, :, :] + batch_enhance_structure = batch_outputs[:, in_channels:, :, :] + if self.tv_loss is not None: + tv_loss = self.tv_loss(batch_enhance_structure) + losses['tv_loss'] = tv_loss + + if self.spacial_loss is not None: + if self.spacial_pred == 'output': + spacial_loss = self.spacial_loss(batch_enhance_img, + batch_inputs) + else: + spacial_loss = self.spacial_loss(batch_enhance_structure, + batch_inputs) + losses['spacial_loss'] = spacial_loss + + if self.structure_loss is not None: + if self.structure_pred == 'output': + de_batch_outputs = loss_input.de_output + else: + de_batch_outputs = loss_input.de_structure + de_batch_inputs = loss_input.de_input + structure_loss = self.structure_loss(de_batch_outputs, + de_batch_inputs, + batch_img_metas) + losses['structure_loss'] = structure_loss + + if self.perceptual_loss is not None: + de_batch_outputs = loss_input.de_output + de_batch_inputs = loss_input.de_input + if de_batch_outputs.shape[1] > 3: + de_batch_outputs = de_batch_outputs[:, :in_channels, :, :] + # norm to 0-1 + de_batch_outputs = de_batch_outputs / 255 + de_batch_inputs = de_batch_inputs / 255 + loss_percep, loss_style = self.perceptual_loss( + de_batch_outputs, de_batch_inputs) + if loss_percep is not None: + losses['perceptual_loss'] = loss_percep + if loss_style is not None: + losses['style_loss'] = loss_style + + return losses + + def post_precess(self, outputs): + assert outputs.dim() in [3, 4] + in_channels = self.model.in_channels + if outputs.dim() == 4: + enhance_img = outputs[:, :in_channels, :, :] + else: + enhance_img = outputs[:in_channels, :, :] + return enhance_img diff --git a/lqit/edit/models/losses/__init__.py b/lqit/edit/models/losses/__init__.py index 7a94ac7..5d983a1 100644 --- a/lqit/edit/models/losses/__init__.py +++ b/lqit/edit/models/losses/__init__.py @@ -3,10 +3,12 @@ from .pixelwise_loss import (CharbonnierLoss, ColorLoss, ExposureLoss, L1Loss, MaskedTVLoss, MSELoss, SpatialLoss) from .ssim_loss import SSIMLoss +from .structure_fft_loss import StructureFFTLoss from .utils import mask_reduce_loss, reduce_loss __all__ = [ 'CharbonnierLoss', 'L1Loss', 'MaskedTVLoss', 'MSELoss', 'SpatialLoss', 'PerceptualLoss', 'PerceptualVGG', 'TransferalPerceptualLoss', 'SSIMLoss', - 'ExposureLoss', 'ColorLoss', 'mask_reduce_loss', 'reduce_loss' + 'ExposureLoss', 'ColorLoss', 'mask_reduce_loss', 'reduce_loss', + 'StructureFFTLoss' ] diff --git a/lqit/edit/models/losses/structure_fft_loss.py b/lqit/edit/models/losses/structure_fft_loss.py new file mode 100644 index 0000000..2f59798 --- /dev/null +++ b/lqit/edit/models/losses/structure_fft_loss.py @@ -0,0 +1,409 @@ +from typing import Optional + +import torch +import torch.nn as nn +import 
torch.nn.functional as F +from torch import Tensor + +from lqit.registry import MODELS +from lqit.utils import OptConfigType + + +@MODELS.register_module() +class StructureFFTLoss(nn.Module): + """`FFT-based Structure Loss. + + `_ + + Args: + radius (int): Radius of the mask. + pass_type (str): FFT pass type, can be 'high' or 'low'. + Defaults to 'high'. + shape (str): Shape of the mask, can be 'cycle' or 'square'. + Defaults to 'cycle'. + channel_mean (bool): Whether to calculate channel mean before + calculating loss. Defaults to False. + loss_type (str): Type of loss, can be 'l1' or 'mse'. + Defaults to 'mse'. + guid_filter (dict or ConfigDict, optional): Config of guided filter, + which is used to smooth the high pass image. Defaults to None. + loss_weight (float): Weight of the loss. Defaults to 1.0. + """ + + def __init__(self, + radius: int = 64, + pass_type: str = 'high', + shape: str = 'cycle', + channel_mean: bool = False, + loss_type: str = 'mse', + guid_filter: OptConfigType = None, + loss_weight=1.0) -> None: + super().__init__() + assert pass_type in ['high', 'low'] + self.pass_type = pass_type + assert shape in ['cycle', 'square'] + if shape == 'cycle': + self.center_mask = self._cycle_mask(radius) + else: + self.center_mask = self._square_mask(radius) + + assert loss_type in ['l1', 'mse'] + self.loss_type = loss_type + self.channel_mean = channel_mean + self.radius = radius + self.loss_weight = loss_weight + + if guid_filter is not None: + self.guid_filter = MODELS.build(guid_filter) + else: + self.guid_filter = None + + def _cycle_mask(self, radius: int) -> Tensor: + """Generate a cycle mask.""" + x = torch.arange(0, 2 * radius)[None, :] + y = torch.arange(2 * radius - 1, -1, -1)[:, None] + cycle_mask = ((x - radius)**2 + (y - radius)**2) <= (radius - 1)**2 + return cycle_mask + + def _square_mask(self, radius: int) -> Tensor: + """Generate a square mask.""" + square_mask = torch.ones((radius * 2, radius * 2), dtype=torch.bool) + return square_mask + + def _get_mask(self, img: Tensor) -> Tensor: + """Get the mask of the image.""" + device = img.device + center_mask = self.center_mask.to(device) + hw_img = img[0, ...] + mask = torch.zeros_like(hw_img, dtype=torch.bool) + height, width = mask.shape[0], mask.shape[1] + x_c, y_c = width // 2, height // 2 + + mask[y_c - self.radius:y_c + self.radius, + x_c - self.radius:x_c + self.radius] = center_mask + if self.pass_type == 'high': + mask = ~mask + return mask + + def forward(self, pred: Tensor, target: Tensor, batch_img_metas: list, + **kwargs) -> Tensor: + """Forward function. + + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + batch_img_metas (list[dict]): List of image information. + + Returns: + Tensor: Calculated loss. 
+ """ + assert pred.shape == target.shape + if self.channel_mean: + pred = torch.mean(pred, dim=1, keepdim=True) + target = torch.mean(target, dim=1, keepdim=True) + + losses = [] + + for _pred, _target, img_meta in zip(pred, target, batch_img_metas): + assert len(_pred.shape) == len(_target.shape) == 3 + h, w = img_meta['img_shape'] + no_padding_pred = _pred[:, :h, :w] + no_padding_target = _target[:, :h, :w] + mask = self._get_mask(no_padding_pred) + + high_pass_target = self.get_pass_img(no_padding_target, mask) + + if self.guid_filter is not None: + high_pass_target = self.guid_filter( + high_pass_target[None, ...], no_padding_target[None, + ...])[0] + high_pass_target = high_pass_target.clip_(min=1e-7, max=255) + + high_pass_pred = self.get_pass_img(no_padding_pred, mask) + norm_high_pass_pred = high_pass_pred / 255 + norm_high_pass_target = high_pass_target / 255 + + # Check if here got NaN of inf + nan_number1 = torch.where(torch.isnan(norm_high_pass_pred) == 1)[0] + nan_number2 = torch.where( + torch.isnan(norm_high_pass_target) == 1)[0] + if len(nan_number1) > 0 or len(nan_number2) > 0: + norm_high_pass_pred[torch.isnan(norm_high_pass_pred)] = 1.0 + norm_high_pass_target[torch.isnan(norm_high_pass_target)] = 1.0 + inf_number1 = torch.where(torch.isinf(norm_high_pass_pred) == 1)[0] + inf_number2 = torch.where( + torch.isinf(norm_high_pass_target) == 1)[0] + + if len(inf_number1) > 0 or len(inf_number2) > 0: + norm_high_pass_pred[torch.isinf(norm_high_pass_pred)] = 0.0 + norm_high_pass_target[torch.isinf(norm_high_pass_target)] = 0.0 + + if self.loss_type == 'l1': + loss = F.l1_loss( + norm_high_pass_pred, + norm_high_pass_target, + reduction='mean') + else: + loss = F.mse_loss( + norm_high_pass_pred, + norm_high_pass_target, + reduction='mean') + losses.append(loss) + total_loss = sum(_loss.mean() for _loss in losses) + return total_loss * self.loss_weight + + @staticmethod + def get_pass_img(img, mask: Tensor) -> Tensor: + """Get the FFT filtered image.""" + channel_img_list = [] + for i in range(img.size(0)): + channel_img = img[i, ...] 
+ f = torch.fft.fft2(channel_img) + fshift = torch.fft.fftshift(f) + filter_fshift = fshift * mask + + ishift = torch.fft.ifftshift(filter_fshift) + high_pass_img = torch.fft.ifft2(ishift) + high_pass_img = torch.abs(high_pass_img).clip_(min=1e-7, max=255) + channel_img_list.append(high_pass_img[None, ...]) + result_img = torch.cat(channel_img_list, dim=0) + return result_img + + +@MODELS.register_module() +class GuidedFilter2d(nn.Module): + """Guided filter for 2D image.""" + + def __init__(self, + radius: int = 30, + eps: float = 1e-4, + fast_s: Optional[int] = None, + channel_wise: bool = True): + super().__init__() + self.r = radius + self.eps = eps + self.fast_s = fast_s + self.channel_wise = channel_wise + + def forward(self, x: Tensor, guide: Tensor) -> Tensor: + """Forward function.""" + if guide.shape[1] == 3: + if self.channel_wise: + assert x.shape == guide.shape + channel_result = [] + for i in range(3): + result = self.guidedfilter2d_gray(guide[:, i:i + 1, ...], + x[:, i:i + 1, + ...], self.r, self.eps, + self.fast_s) + channel_result.append(result) + + results = torch.cat(channel_result, dim=1) + return results + else: + return self.guidedfilter2d_color(guide, x, self.r, self.eps, + self.fast_s) + elif guide.shape[1] == 1: + return self.guidedfilter2d_gray(guide, x, self.r, self.eps, + self.fast_s) + else: + raise NotImplementedError + + def guidedfilter2d_color(self, + guide: Tensor, + src: Tensor, + radius: int, + eps: float, + scale: Optional[int] = None) -> Tensor: + """guided filter for a color guide image. + + Args: + guide (Tensor): Guide imageof shape (B, 3, H, W). + src (Tensor): Filtering image of shape (B, C, H, W). + radius (int): Filter radius. + eps (float): Regularization coefficient. + scale (int, optional): Scale factor of the image. Defaults to None. + + Returns: + Tensor: Filtered image. + """ + assert guide.shape[1] == 3 + if src.ndim == 3: + src = src[:, None] + if scale is not None: + guide_sub = guide.clone() + src = F.interpolate(src, scale_factor=1. / scale, mode='nearest') + guide = F.interpolate( + guide, scale_factor=1. 
/ scale, mode='nearest') + radius = radius // scale + + # b x 1 x H x W + guide_r, guide_g, guide_b = torch.chunk(guide, 3, 1) + ones = torch.ones_like(guide_r) + N = self.boxfilter2d(ones, radius) + + # b x 3 x H x W + mean_I = self.boxfilter2d(guide, radius) / N + mean_p = self.boxfilter2d(src, radius) / N + # b x 1 x H x W + mean_I_r, mean_I_g, mean_I_b = torch.chunk(mean_I, 3, 1) + + # b x C x H x W + mean_Ip_r = self.boxfilter2d(guide_r * src, radius) / N + mean_Ip_g = self.boxfilter2d(guide_g * src, radius) / N + mean_Ip_b = self.boxfilter2d(guide_b * src, radius) / N + + # b x C x H x W + cov_Ip_r = mean_Ip_r - mean_I_r * mean_p + cov_Ip_g = mean_Ip_g - mean_I_g * mean_p + cov_Ip_b = mean_Ip_b - mean_I_b * mean_p + + # b x 1 x H x W + var_I_rr = self.boxfilter2d(guide_r * guide_r, radius) / N \ + - mean_I_r * mean_I_r + eps + var_I_rg = self.boxfilter2d(guide_r * guide_g, radius) / N \ + - mean_I_r * mean_I_g + var_I_rb = self.boxfilter2d(guide_r * guide_b, radius) / N \ + - mean_I_r * mean_I_b + var_I_gg = self.boxfilter2d(guide_g * guide_g, radius) / N \ + - mean_I_g * mean_I_g + eps + var_I_gb = self.boxfilter2d(guide_g * guide_b, radius) / N \ + - mean_I_g * mean_I_b + var_I_bb = self.boxfilter2d(guide_b * guide_b, radius) / N \ + - mean_I_b * mean_I_b + eps + + # determinant, b x 1 x H x W + cov_det = var_I_rr * var_I_gg * var_I_bb \ + + var_I_rg * var_I_gb * var_I_rb \ + + var_I_rb * var_I_rg * var_I_gb \ + - var_I_rb * var_I_gg * var_I_rb \ + - var_I_rg * var_I_rg * var_I_bb \ + - var_I_rr * var_I_gb * var_I_gb + + # inverse, b x 1 x H x W + inv_var_I_rr = (var_I_gg * var_I_bb - var_I_gb * var_I_gb) / cov_det + inv_var_I_rg = -(var_I_rg * var_I_bb - var_I_rb * var_I_gb) / cov_det + inv_var_I_rb = (var_I_rg * var_I_gb - var_I_rb * var_I_gg) / cov_det + inv_var_I_gg = (var_I_rr * var_I_bb - var_I_rb * var_I_rb) / cov_det + inv_var_I_gb = -(var_I_rr * var_I_gb - var_I_rb * var_I_rg) / cov_det + inv_var_I_bb = (var_I_rr * var_I_gg - var_I_rg * var_I_rg) / cov_det + + # b x 3 x 3 x H x W + inv_sigma = torch.stack([ + torch.stack([inv_var_I_rr, inv_var_I_rg, inv_var_I_rb], 1), + torch.stack([inv_var_I_rg, inv_var_I_gg, inv_var_I_gb], 1), + torch.stack([inv_var_I_rb, inv_var_I_gb, inv_var_I_bb], 1) + ], 1).squeeze(-3) + + # b x 3 x C x H x W + cov_Ip = torch.stack([cov_Ip_r, cov_Ip_g, cov_Ip_b], 1) + + a = torch.einsum('bichw,bijhw->bjchw', (cov_Ip, inv_sigma)) + # b x C x H x W + b = mean_p - a[:, 0] * mean_I_r - \ + a[:, 1] * mean_I_g - \ + a[:, 2] * mean_I_b + + mean_a = torch.stack( + [self.boxfilter2d(a[:, i], radius) / N for i in range(3)], 1) + mean_b = self.boxfilter2d(b, radius) / N + + if scale is not None: + guide = guide_sub + mean_a = torch.stack([ + F.interpolate(mean_a[:, i], guide.shape[-2:], mode='bilinear') + for i in range(3) + ], 1) + mean_b = F.interpolate(mean_b, guide.shape[-2:], mode='bilinear') + + q = torch.einsum('bichw,bihw->bchw', (mean_a, guide)) + mean_b + + return q + + def guidedfilter2d_gray(self, + guide: Tensor, + src: Tensor, + radius: int, + eps: float, + scale: Optional[int] = None) -> Tensor: + """guided filter for a gray scale guide image. + + Args: + guide (Tensor): Guide imageof shape (B, 3, H, W). + src (Tensor): Filtering image of shape (B, C, H, W). + radius (int): Filter radius. + eps (float): Regularization coefficient. + scale (int, optional): Scale factor of the image. Defaults to None. + + Returns: + Tensor: Filtered image.ficient. 
+ """ + + if guide.ndim == 3: + guide = guide[:, None] + if src.ndim == 3: + src = src[:, None] + + if scale is not None: + guide_sub = guide.clone() + src = F.interpolate(src, scale_factor=1. / scale, mode='nearest') + guide = F.interpolate( + guide, scale_factor=1. / scale, mode='nearest') + radius = radius // scale + + ones = torch.ones_like(guide) + N = self.boxfilter2d(ones, radius) + + mean_I = self.boxfilter2d(guide, radius) / N + mean_p = self.boxfilter2d(src, radius) / N + mean_Ip = self.boxfilter2d(guide * src, radius) / N + cov_Ip = mean_Ip - mean_I * mean_p + + mean_II = self.boxfilter2d(guide * guide, radius) / N + var_I = mean_II - mean_I * mean_I + + a = cov_Ip / (var_I + eps) + b = mean_p - a * mean_I + + mean_a = self.boxfilter2d(a, radius) / N + mean_b = self.boxfilter2d(b, radius) / N + + if scale is not None: + guide = guide_sub + mean_a = F.interpolate(mean_a, guide.shape[-2:], mode='bilinear') + mean_b = F.interpolate(mean_b, guide.shape[-2:], mode='bilinear') + + q = mean_a * guide + mean_b + return q + + def boxfilter2d(self, src: Tensor, radius: int) -> Tensor: + """Box filter for 2D image.""" + return self._diff_y(self._diff_x(src, radius), radius) + + @staticmethod + def _diff_x(src: Tensor, r: int) -> Tensor: + """Difference along x axis.""" + cum_src = src.cumsum(-2) + + left = cum_src[..., r:2 * r + 1, :] + middle = cum_src[..., 2 * r + 1:, :] - \ + cum_src[..., :-2 * r - 1, :] + right = cum_src[..., -1:, :] - \ + cum_src[..., -2 * r - 1:-r - 1, :] + + output = torch.cat([left, middle, right], -2) + return output + + @staticmethod + def _diff_y(src: Tensor, r: int) -> Tensor: + """Difference along y axis.""" + cum_src = src.cumsum(-1) + + left = cum_src[..., r:2 * r + 1] + middle = cum_src[..., 2 * r + 1:] - \ + cum_src[..., :-2 * r - 1] + right = cum_src[..., -1:] - \ + cum_src[..., -2 * r - 1:-r - 1] + + output = torch.cat([left, middle, right], -1) + return output diff --git a/tools/test.py b/tools/test.py index 9b00051..11436a1 100644 --- a/tools/test.py +++ b/tools/test.py @@ -7,6 +7,7 @@ from mmengine.runner import Runner from lqit.common.utils.lark_manager import (context_monitor_manager, + get_error_message, initialize_monitor_manager) from lqit.common.utils.process_lark_hook import process_lark_hook from lqit.registry import RUNNERS @@ -165,3 +166,5 @@ def main(args): except Exception: if monitor_manager is not None: monitor_manager.monitor_exception() + else: + get_error_message() diff --git a/tools/train.py b/tools/train.py index c8bc541..bad72b9 100644 --- a/tools/train.py +++ b/tools/train.py @@ -10,6 +10,7 @@ from mmengine.runner import Runner from lqit.common.utils.lark_manager import (context_monitor_manager, + get_error_message, initialize_monitor_manager) from lqit.common.utils.process_lark_hook import process_lark_hook from lqit.utils import print_colored_log, setup_cache_size_limit_of_dynamo @@ -156,7 +157,9 @@ def main(args): monitor_manager = None - if args.lark: + if not args.lark: + main(args) + else: lark_file = args.lark_file if not osp.exists(lark_file): warnings.warn(f'{lark_file} not exists, skip.') @@ -186,3 +189,5 @@ def main(args): except Exception: if monitor_manager is not None: monitor_manager.monitor_exception() + else: + get_error_message() From 31666f778add2a80023f95bd1d86cb771c4fea54 Mon Sep 17 00:00:00 2001 From: BIGWangYuDong Date: Thu, 26 Oct 2023 16:59:48 +0800 Subject: [PATCH 2/5] add uod-air model --- .../atss_r50_fpn_1x_urpc-coco.py | 72 ++++ .../faster-rcnn_r50_fpn_1x_urpc-coco.py | 4 +- 
...fcos_r50-caffe_fpn_gn-head_1x_urpc-coco.py | 77 ++++ .../base_detector/paa_r50_fpn_1x_urpc-coco.py | 93 ++++ .../tood_r50_fpn_1x_urpc-coco.py | 81 ++++ .../base_editor/tienet_enhance_model.py | 2 +- ..._faster-rcnn_r50_fpn_1x_4xbs4_urpc-coco.py | 3 + ...tienet_faster-rcnn_r50_fpn_1x_urpc-coco.py | 35 ++ ...et_retinanet_r50_fpn_1x_4xbs4_urpc-coco.py | 3 + .../faster-rcnn_r50_fpn_1x_urpc-coco.py | 121 ++++++ .../retinanet_r50_ufpn_1x_urpc-coco.py} | 12 +- .../uod_air/base_ehance_head/enhance_head.py | 14 + ...od-air_faster-rcnn_r50_fpn_1x_urpc-coco.py | 23 + ...uod-air_retinanet_r50_ufpn_1x_urpc-coco.py | 41 ++ lqit/detection/models/detectors/__init__.py | 4 +- .../detectors/detector_with_enhance_head.py | 400 ++++++++++-------- .../detectors/detector_with_enhance_model.py | 6 +- lqit/detection/models/necks/__init__.py | 3 +- lqit/detection/models/necks/ufpn.py | 249 +++++++++++ .../models/editor_heads/basic_enhance_head.py | 5 +- 20 files changed, 1061 insertions(+), 187 deletions(-) create mode 100644 configs/detection/tienet/base_detector/atss_r50_fpn_1x_urpc-coco.py rename configs/detection/{uod_air => tienet/base_detector}/faster-rcnn_r50_fpn_1x_urpc-coco.py (96%) create mode 100644 configs/detection/tienet/base_detector/fcos_r50-caffe_fpn_gn-head_1x_urpc-coco.py create mode 100644 configs/detection/tienet/base_detector/paa_r50_fpn_1x_urpc-coco.py create mode 100644 configs/detection/tienet/base_detector/tood_r50_fpn_1x_urpc-coco.py create mode 100644 configs/detection/tienet/tienet_faster-rcnn_r50_fpn_1x_4xbs4_urpc-coco.py create mode 100644 configs/detection/tienet/tienet_faster-rcnn_r50_fpn_1x_urpc-coco.py create mode 100644 configs/detection/tienet/tienet_retinanet_r50_fpn_1x_4xbs4_urpc-coco.py create mode 100644 configs/detection/uod_air/base_detector/faster-rcnn_r50_fpn_1x_urpc-coco.py rename configs/detection/uod_air/{retinanet_r50_fpn_1x_urpc-coco.py => base_detector/retinanet_r50_ufpn_1x_urpc-coco.py} (90%) create mode 100644 configs/detection/uod_air/base_ehance_head/enhance_head.py create mode 100644 configs/detection/uod_air/uod-air_faster-rcnn_r50_fpn_1x_urpc-coco.py create mode 100644 configs/detection/uod_air/uod-air_retinanet_r50_ufpn_1x_urpc-coco.py create mode 100644 lqit/detection/models/necks/ufpn.py diff --git a/configs/detection/tienet/base_detector/atss_r50_fpn_1x_urpc-coco.py b/configs/detection/tienet/base_detector/atss_r50_fpn_1x_urpc-coco.py new file mode 100644 index 0000000..190eb71 --- /dev/null +++ b/configs/detection/tienet/base_detector/atss_r50_fpn_1x_urpc-coco.py @@ -0,0 +1,72 @@ +_base_ = [ + '../../_base_/datasets/urpc2020/urpc2020-validation_coco_detection.py', + '../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py' +] + +# model settings +model = dict( + type='ATSS', + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32), + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_output', + num_outs=5), + bbox_head=dict( + type='ATSSHead', + num_classes=4, + in_channels=256, + stacked_convs=4, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + ratios=[1.0], + 
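+            # a single square anchor per location, following the ATSS setting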
octave_base_scale=8, + scales_per_octave=1, + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[0.1, 0.1, 0.2, 0.2]), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='GIoULoss', loss_weight=2.0), + loss_centerness=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)), + # training and testing settings + train_cfg=dict( + assigner=dict(type='ATSSAssigner', topk=9), + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.6), + max_per_img=100)) + +# optimizer +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)) diff --git a/configs/detection/uod_air/faster-rcnn_r50_fpn_1x_urpc-coco.py b/configs/detection/tienet/base_detector/faster-rcnn_r50_fpn_1x_urpc-coco.py similarity index 96% rename from configs/detection/uod_air/faster-rcnn_r50_fpn_1x_urpc-coco.py rename to configs/detection/tienet/base_detector/faster-rcnn_r50_fpn_1x_urpc-coco.py index 2cb93f6..a28ff6f 100644 --- a/configs/detection/uod_air/faster-rcnn_r50_fpn_1x_urpc-coco.py +++ b/configs/detection/tienet/base_detector/faster-rcnn_r50_fpn_1x_urpc-coco.py @@ -1,6 +1,6 @@ _base_ = [ - '../_base_/datasets/urpc2020/urpc2020-validation_coco_detection.py', - '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' + '../../_base_/datasets/urpc2020/urpc2020-validation_coco_detection.py', + '../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py' ] # model settings diff --git a/configs/detection/tienet/base_detector/fcos_r50-caffe_fpn_gn-head_1x_urpc-coco.py b/configs/detection/tienet/base_detector/fcos_r50-caffe_fpn_gn-head_1x_urpc-coco.py new file mode 100644 index 0000000..8f27672 --- /dev/null +++ b/configs/detection/tienet/base_detector/fcos_r50-caffe_fpn_gn-head_1x_urpc-coco.py @@ -0,0 +1,77 @@ +_base_ = [ + '../../_base_/datasets/urpc2020/urpc2020-validation_coco_detection.py', + '../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py' +] + +# model settings +model = dict( + type='FCOS', + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[102.9801, 115.9465, 122.7717], + std=[1.0, 1.0, 1.0], + bgr_to_rgb=False, + pad_size_divisor=32), + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=False), + norm_eval=True, + style='caffe', + init_cfg=dict( + type='Pretrained', + checkpoint='open-mmlab://detectron/resnet50_caffe')), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_output', # use P5 + num_outs=5, + relu_before_extra_convs=True), + bbox_head=dict( + type='FCOSHead', + num_classes=4, + in_channels=256, + stacked_convs=4, + feat_channels=256, + strides=[8, 16, 32, 64, 128], + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='IoULoss', loss_weight=1.0), + loss_centerness=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)), + # testing settings + test_cfg=dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100)) + +# learning rate +param_scheduler = [ + dict( + type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, + end=1000), + 
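+    # step decay: multiply the learning rate by 0.1 at epochs 8 and 11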
dict( + type='MultiStepLR', + begin=0, + end=12, + by_epoch=True, + milestones=[8, 11], + gamma=0.1) +] + +# optimizer +optim_wrapper = dict( + optimizer=dict(lr=0.01), + paramwise_cfg=dict(bias_lr_mult=2., bias_decay_mult=0.), + clip_grad=dict(max_norm=35, norm_type=2)) # loss may NaN without clip_grad diff --git a/configs/detection/tienet/base_detector/paa_r50_fpn_1x_urpc-coco.py b/configs/detection/tienet/base_detector/paa_r50_fpn_1x_urpc-coco.py new file mode 100644 index 0000000..a229916 --- /dev/null +++ b/configs/detection/tienet/base_detector/paa_r50_fpn_1x_urpc-coco.py @@ -0,0 +1,93 @@ +_base_ = [ + '../../_base_/datasets/urpc2020/urpc2020-validation_coco_detection.py', + '../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py' +] + +# model settings +model = dict( + type='PAA', + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32), + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_output', + num_outs=5), + bbox_head=dict( + type='PAAHead', + reg_decoded_bbox=True, + score_voting=True, + topk=9, + num_classes=5, + in_channels=256, + stacked_convs=4, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + ratios=[1.0], + octave_base_scale=8, + scales_per_octave=1, + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[0.1, 0.1, 0.2, 0.2]), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='GIoULoss', loss_weight=1.3), + loss_centerness=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5)), + # training and testing settings + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.1, + neg_iou_thr=0.1, + min_pos_iou=0, + ignore_iof_thr=-1), + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.6), + max_per_img=100)) + +# optimizer +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)) + +# learning rate +param_scheduler = [ + dict( + type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500), + dict( + type='MultiStepLR', + begin=0, + end=12, + by_epoch=True, + milestones=[8, 11], + gamma=0.1) +] diff --git a/configs/detection/tienet/base_detector/tood_r50_fpn_1x_urpc-coco.py b/configs/detection/tienet/base_detector/tood_r50_fpn_1x_urpc-coco.py new file mode 100644 index 0000000..66cd58f --- /dev/null +++ b/configs/detection/tienet/base_detector/tood_r50_fpn_1x_urpc-coco.py @@ -0,0 +1,81 @@ +_base_ = [ + '../../_base_/datasets/urpc2020/urpc2020-validation_coco_detection.py', + '../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py' +] + +# model settings +model = dict( + type='TOOD', + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32), + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + 
frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_output', + num_outs=5), + bbox_head=dict( + type='TOODHead', + num_classes=4, + in_channels=256, + stacked_convs=6, + feat_channels=256, + anchor_type='anchor_free', + anchor_generator=dict( + type='AnchorGenerator', + ratios=[1.0], + octave_base_scale=8, + scales_per_octave=1, + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[0.1, 0.1, 0.2, 0.2]), + initial_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + activated=True, # use probability instead of logit as input + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_cls=dict( + type='QualityFocalLoss', + use_sigmoid=True, + activated=True, # use probability instead of logit as input + beta=2.0, + loss_weight=1.0), + loss_bbox=dict(type='GIoULoss', loss_weight=2.0)), + train_cfg=dict( + initial_epoch=4, + initial_assigner=dict(type='ATSSAssigner', topk=9), + assigner=dict(type='TaskAlignedAssigner', topk=13), + alpha=1, + beta=6, + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.6), + max_per_img=100)) + +# optimizer +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)) diff --git a/configs/detection/tienet/base_editor/tienet_enhance_model.py b/configs/detection/tienet/base_editor/tienet_enhance_model.py index 70cdd39..dc8f5bb 100644 --- a/configs/detection/tienet/base_editor/tienet_enhance_model.py +++ b/configs/detection/tienet/base_editor/tienet_enhance_model.py @@ -31,7 +31,7 @@ type='StructureFFTLoss', radius=4, pass_type='high', - channel_mean=True, + channel_mean=False, loss_type='mse', guid_filter=dict( type='GuidedFilter2d', radius=32, eps=1e-4, fast_s=2), diff --git a/configs/detection/tienet/tienet_faster-rcnn_r50_fpn_1x_4xbs4_urpc-coco.py b/configs/detection/tienet/tienet_faster-rcnn_r50_fpn_1x_4xbs4_urpc-coco.py new file mode 100644 index 0000000..17910b9 --- /dev/null +++ b/configs/detection/tienet/tienet_faster-rcnn_r50_fpn_1x_4xbs4_urpc-coco.py @@ -0,0 +1,3 @@ +_base_ = ['./tienet_faster-rcnn_r50_fpn_1x_urpc-coco.py'] + +train_dataloader = dict(batch_size=4, num_workers=4) diff --git a/configs/detection/tienet/tienet_faster-rcnn_r50_fpn_1x_urpc-coco.py b/configs/detection/tienet/tienet_faster-rcnn_r50_fpn_1x_urpc-coco.py new file mode 100644 index 0000000..f503b19 --- /dev/null +++ b/configs/detection/tienet/tienet_faster-rcnn_r50_fpn_1x_urpc-coco.py @@ -0,0 +1,35 @@ +# default scope is mmdet +_base_ = [ + './base_editor/tienet_enhance_model.py', + './base_detector/faster-rcnn_r50_fpn_1x_urpc-coco.py' +] + +model = dict( + _delete_=True, + type='lqit.DetectorWithEnhanceModel', + detector={{_base_.model}}, + enhance_model={{_base_.enhance_model}}, + train_mode='enhance', + pred_mode='enhance', + detach_enhance_img=False) + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001), + clip_grad=dict(max_norm=35, norm_type=2)) + +# dataset settings +train_pipeline = [ + dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='Resize', scale=(1333, 800), 
keep_ratio=True), + dict(type='RandomFlip', prob=0.5), + dict(type='lqit.SetInputImageAsGT'), + dict(type='lqit.PackInputs') +] +train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) + +model_wrapper_cfg = dict( + type='lqit.SelfEnhanceModelDDP', + broadcast_buffers=False, + find_unused_parameters=False) diff --git a/configs/detection/tienet/tienet_retinanet_r50_fpn_1x_4xbs4_urpc-coco.py b/configs/detection/tienet/tienet_retinanet_r50_fpn_1x_4xbs4_urpc-coco.py new file mode 100644 index 0000000..3c5ab3b --- /dev/null +++ b/configs/detection/tienet/tienet_retinanet_r50_fpn_1x_4xbs4_urpc-coco.py @@ -0,0 +1,3 @@ +_base_ = ['./tienet_retinanet_r50_fpn_1x_urpc-coco.py'] + +train_dataloader = dict(batch_size=4, num_workers=4) diff --git a/configs/detection/uod_air/base_detector/faster-rcnn_r50_fpn_1x_urpc-coco.py b/configs/detection/uod_air/base_detector/faster-rcnn_r50_fpn_1x_urpc-coco.py new file mode 100644 index 0000000..512d9d7 --- /dev/null +++ b/configs/detection/uod_air/base_detector/faster-rcnn_r50_fpn_1x_urpc-coco.py @@ -0,0 +1,121 @@ +_base_ = [ + '../../_base_/datasets/urpc2020/urpc2020-validation_coco_detection.py', + '../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py' +] + +# model settings +model = dict( + type='FasterRCNN', + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32), + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='lqit.UFPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=0, + add_extra_convs='on_output', + num_outs=5), + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=4, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=False, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + 
min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + )) diff --git a/configs/detection/uod_air/retinanet_r50_fpn_1x_urpc-coco.py b/configs/detection/uod_air/base_detector/retinanet_r50_ufpn_1x_urpc-coco.py similarity index 90% rename from configs/detection/uod_air/retinanet_r50_fpn_1x_urpc-coco.py rename to configs/detection/uod_air/base_detector/retinanet_r50_ufpn_1x_urpc-coco.py index eef8d7a..9dd042d 100644 --- a/configs/detection/uod_air/retinanet_r50_fpn_1x_urpc-coco.py +++ b/configs/detection/uod_air/base_detector/retinanet_r50_ufpn_1x_urpc-coco.py @@ -1,6 +1,6 @@ _base_ = [ - '../_base_/datasets/urpc2020/urpc2020-validation_coco_detection.py', - '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' + '../../_base_/datasets/urpc2020/urpc2020-validation_coco_detection.py', + '../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py' ] # model settings @@ -23,12 +23,12 @@ style='pytorch', init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), neck=dict( - type='FPN', + type='lqit.UFPN', in_channels=[256, 512, 1024, 2048], out_channels=256, - start_level=1, - add_extra_convs='on_input', - num_outs=5), + start_level=0, + add_extra_convs='on_output', + num_outs=6), bbox_head=dict( type='RetinaHead', num_classes=4, diff --git a/configs/detection/uod_air/base_ehance_head/enhance_head.py b/configs/detection/uod_air/base_ehance_head/enhance_head.py new file mode 100644 index 0000000..0d308f2 --- /dev/null +++ b/configs/detection/uod_air/base_ehance_head/enhance_head.py @@ -0,0 +1,14 @@ +enhance_head = dict( + _scope_='lqit', + type='BasicEnhanceHead', + in_channels=256, + feat_channels=256, + num_convs=2, + loss_enhance=dict(type='L1Loss', loss_weight=1.0), + gt_preprocessor=dict( + type='GTPixelPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32, + element_name='img')) diff --git a/configs/detection/uod_air/uod-air_faster-rcnn_r50_fpn_1x_urpc-coco.py b/configs/detection/uod_air/uod-air_faster-rcnn_r50_fpn_1x_urpc-coco.py new file mode 100644 index 0000000..c8a8d94 --- /dev/null +++ b/configs/detection/uod_air/uod-air_faster-rcnn_r50_fpn_1x_urpc-coco.py @@ -0,0 +1,23 @@ +_base_ = [ + './base_detector/faster-rcnn_r50_fpn_1x_urpc-coco.py', + './base_ehance_head/enhance_head.py' +] + +# model settings +model = dict( + _delete_=True, + type='lqit.DetectorWithEnhanceHead', + detector={{_base_.model}}, + enhance_head={{_base_.enhance_head}}, + vis_enhance=False) + +# dataset settings +train_pipeline = [ + dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='Resize', scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', prob=0.5), + dict(type='lqit.SetInputImageAsGT'), + dict(type='lqit.PackInputs') +] +train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) diff --git a/configs/detection/uod_air/uod-air_retinanet_r50_ufpn_1x_urpc-coco.py 
b/configs/detection/uod_air/uod-air_retinanet_r50_ufpn_1x_urpc-coco.py new file mode 100644 index 0000000..2b03a40 --- /dev/null +++ b/configs/detection/uod_air/uod-air_retinanet_r50_ufpn_1x_urpc-coco.py @@ -0,0 +1,41 @@ +_base_ = [ + './base_detector/retinanet_r50_ufpn_1x_urpc-coco.py', + './base_ehance_head/enhance_head.py' +] + +# model settings +model = dict( + _delete_=True, + type='lqit.DetectorWithEnhanceHead', + detector={{_base_.model}}, + enhance_head={{_base_.enhance_head}}, + vis_enhance=False) + +# dataset settings +train_pipeline = [ + dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='Resize', scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', prob=0.5), + dict(type='lqit.SetInputImageAsGT'), + dict(type='lqit.PackInputs') +] +train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) + +# optimizer +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001), + clip_grad=dict(max_norm=35, norm_type=2)) # loss may NaN without clip_grad + +param_scheduler = [ + dict( + type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, + end=1000), + dict( + type='MultiStepLR', + begin=0, + end=12, + by_epoch=True, + milestones=[8, 11], + gamma=0.1) +] diff --git a/lqit/detection/models/detectors/__init__.py b/lqit/detection/models/detectors/__init__.py index 9e25cb0..b9fe038 100644 --- a/lqit/detection/models/detectors/__init__.py +++ b/lqit/detection/models/detectors/__init__.py @@ -1,3 +1,4 @@ +from .detector_with_enhance_head import DetectorWithEnhanceHead from .detector_with_enhance_model import DetectorWithEnhanceModel from .edffnet import EDFFNet from .multi_input_wrapper import MultiInputDetectorWrapper @@ -6,5 +7,6 @@ __all__ = [ 'TwoStageWithEnhanceHead', 'MultiInputDetectorWrapper', - 'SingleStageDetector', 'EDFFNet', 'DetectorWithEnhanceModel' + 'SingleStageDetector', 'EDFFNet', 'DetectorWithEnhanceModel', + 'DetectorWithEnhanceHead' ] diff --git a/lqit/detection/models/detectors/detector_with_enhance_head.py b/lqit/detection/models/detectors/detector_with_enhance_head.py index 6bcf7bc..92e110d 100644 --- a/lqit/detection/models/detectors/detector_with_enhance_head.py +++ b/lqit/detection/models/detectors/detector_with_enhance_head.py @@ -1,78 +1,162 @@ import copy -from typing import Optional +import warnings +from typing import Dict, Optional, Tuple, Union import torch from mmdet.models import SingleStageDetector, TwoStageDetector -from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from mmengine.model import BaseModel from torch import Tensor -from lqit.common.structures import SampleList +from lqit.common.structures import OptSampleList, SampleList from lqit.edit.models import add_pixel_pred_to_datasample from lqit.registry import MODELS +from lqit.utils import ConfigType, OptConfigType, OptMultiConfig +ForwardResults = Union[Dict[str, Tensor], SampleList, Tuple[Tensor], Tensor] -@MODELS.register_module() -class SingleStageWithEnhanceHead(SingleStageDetector): - """Base class for two-stage detectors with enhance head. - Two-stage detectors typically consisting of a region proposal network and a - task-specific regression head. +@MODELS.register_module() +class DetectorWithEnhanceHead(BaseModel): + """Detector with enhance head. + + Args: + detector (dict or ConfigDict): Config for detector. + enhance_head (dict or ConfigDict, optional): Config for enhance head. 
+ vis_enhance (bool): Whether to visualize the enhanced image during + inference. Defaults to False. + init_cfg (dict or ConfigDict, optional): The config to control the + initialization. Defaults to None. """ def __init__(self, - backbone: ConfigType, - neck: OptConfigType = None, - bbox_head: OptConfigType = None, + detector: ConfigType, enhance_head: OptConfigType = None, vis_enhance: Optional[bool] = False, - train_cfg: OptConfigType = None, - test_cfg: OptConfigType = None, - data_preprocessor: OptConfigType = None, init_cfg: OptMultiConfig = None) -> None: - super().__init__( - backbone=backbone, - neck=neck, - bbox_head=bbox_head, - train_cfg=train_cfg, - test_cfg=test_cfg, - data_preprocessor=data_preprocessor, - init_cfg=init_cfg) + super().__init__(init_cfg=init_cfg) + + # process gt_preprocessor + if enhance_head is not None: + enhance_head = self.process_gt_preprocessor(detector, enhance_head) + + # build data_preprocessor + self.data_preprocessor = MODELS.build(detector['data_preprocessor']) + # build detector + self.detector = MODELS.build(detector) + if isinstance(self.detector, SingleStageDetector): + self.detector_type = 'SingleStage' + elif isinstance(self.detector, TwoStageDetector): + self.detector_type = 'TwoStage' + else: + raise TypeError( + f'Only support SingleStageDetector and TwoStageDetector, ' + f'but got {type(self.detector)}.') + # build enhance head if enhance_head is not None: self.enhance_head = MODELS.build(enhance_head) + else: + self.enhance_head = None + if vis_enhance: + assert self.with_enhance_head self.vis_enhance = vis_enhance + @staticmethod + def process_gt_preprocessor(detector, enhance_head): + """Process the gt_preprocessor of enhance head.""" + data_preprocessor = detector.get('data_preprocessor', None) + data_preprocessor_mean = data_preprocessor['mean'] + data_preprocessor_std = data_preprocessor['std'] + data_preprocessor_bgr_to_rgb = data_preprocessor['bgr_to_rgb'] + data_preprocessor_pad_size_divisor = \ + data_preprocessor['pad_size_divisor'] + + gt_preprocessor = enhance_head.get('gt_preprocessor', None) + gt_preprocessor_mean = gt_preprocessor['mean'] + gt_preprocessor_std = gt_preprocessor['std'] + gt_preprocessor_bgr_to_rgb = gt_preprocessor['bgr_to_rgb'] + gt_preprocessor_pad_size_divisor = gt_preprocessor['pad_size_divisor'] + + if data_preprocessor_mean != gt_preprocessor_mean: + warnings.warn( + 'the `mean` of data_preprocessor and gt_preprocessor' + 'are different, force to use the `mean` of data_preprocessor.') + enhance_head['data_preprocessor']['mean'] = data_preprocessor_mean + if data_preprocessor_std != gt_preprocessor_std: + warnings.warn( + 'the `std` of data_preprocessor and gt_preprocessor' + 'are different, force to use the `std` of data_preprocessor.') + enhance_head['data_preprocessor']['std'] = data_preprocessor_std + if data_preprocessor_bgr_to_rgb != gt_preprocessor_bgr_to_rgb: + warnings.warn( + 'the `bgr_to_rgb` of data_preprocessor and gt_preprocessor' + 'are different, force to use the `bgr_to_rgb` of ' + 'data_preprocessor.') + enhance_head['data_preprocessor']['bgr_to_rgb'] = \ + data_preprocessor_bgr_to_rgb + if data_preprocessor_pad_size_divisor != \ + gt_preprocessor_pad_size_divisor: + warnings.warn('the `pad_size_divisor` of data_preprocessor and ' + 'gt_preprocessor are different, force to use the ' + '`pad_size_divisor` of data_preprocessor.') + enhance_head['data_preprocessor']['pad_size_divisor'] = \ + data_preprocessor_pad_size_divisor + return enhance_head + @property def 
with_enhance_head(self) -> bool: - """bool: whether the detector has a RoI head""" + """Whether has enhance head.""" return hasattr(self, 'enhance_head') and self.enhance_head is not None - def _forward(self, batch_inputs: Tensor, - batch_data_samples: SampleList) -> tuple: - """Network forward process. Usually includes backbone, neck and head - forward without any post-processing. + def forward(self, + inputs: Tensor, + data_samples: OptSampleList = None, + mode: str = 'tensor') -> ForwardResults: + """The unified entry for a forward process in both training and test. + + The method should accept three modes: "tensor", "predict" and "loss": + + - "tensor": Forward the whole network and return tensor or tuple of + tensor without any post-processing, same as a common nn.Module. + - "predict": Forward and return the predictions, which are fully + processed to a list of :obj:`DetDataSample`. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle either back propagation or + parameter update, which are supposed to be done in :meth:`train_step`. Args: - batch_inputs (Tensor): Inputs with shape (N, C, H, W). + inputs (Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (list[:obj:`DetDataSample`], optional): A batch of + data samples that contain annotations and predictions. + Defaults to None. + mode (str): Return what kind of value. Defaults to 'tensor'. Returns: - tuple: A tuple of features from ``rpn_head`` and ``roi_head`` - forward. + The return type depends on ``mode``. + + - If ``mode="tensor"``, return a tensor or a tuple of tensor. + - If ``mode="predict"``, return a list of :obj:`DetDataSample`. + - If ``mode="loss"``, return a dict of tensor. """ - x = self.extract_feat(batch_inputs) - results = self.bbox_head.forward(x) - if self.with_enhance_head: - enhance_outs = self.enhance_head.forward(x) - results = results + (enhance_outs, ) - return results + if mode == 'loss': + return self.loss(inputs, data_samples) + elif mode == 'predict': + return self.predict(inputs, data_samples) + elif mode == 'tensor': + return self._forward(inputs, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}". ' + 'Only supports loss, predict and tensor mode') - def loss(self, batch_inputs: Tensor, - batch_data_samples: SampleList) -> dict: - """Calculate losses from a batch of inputs and data samples. + def calculate_det_loss(self, x: Tuple[Tensor], + batch_data_samples: SampleList) -> dict: + """Calculate detection loss. Args: - batch_inputs (Tensor): Input images of shape (N, C, H, W). - These should usually be mean centered and std scaled. + x (tuple[Tensor]): Tuple of multi-level img features. batch_data_samples (List[:obj:`DetDataSample`]): The batch data samples. It usually includes information such as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. 
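
The detection helpers introduced in the following hunks share one convention: if the neck returns more than five feature maps (for example, a UFPN with `start_level=0` and `num_outs=6` also emits a stride-4 map, which the detection head does not use), the extra level is dropped before calling the bbox head. A minimal sketch of that trimming, with illustrative input size and strides that are assumptions rather than part of this patch:

```python
# Sketch of the level-trimming convention used by the detection helpers below:
# when the neck returns more than five maps (e.g. UFPN with start_level=0,
# num_outs=6), the first, highest-resolution level is dropped before the
# detection head. The 640x640 input and stride list are illustrative only.
import torch

neck_outs = tuple(
    torch.rand(2, 256, 640 // s, 640 // s) for s in (4, 8, 16, 32, 64, 128))
det_feats = neck_outs[1:] if len(neck_outs) > 5 else neck_outs
assert len(det_feats) == 5  # strides 8-128, matching the heads' anchor strides
```
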
@@ -80,29 +164,53 @@ def loss(self, batch_inputs: Tensor, Returns: dict: A dictionary of loss components """ - x = self.extract_feat(batch_inputs) - - losses = dict() - if self.with_enhance_head: - - enhance_loss = self.enhance_head.loss(x, batch_data_samples) - # avoid loss override - assert not set(enhance_loss.keys()) & set(losses.keys()) - losses.update(enhance_loss) + if len(x) > 5: + x = x[1:] + if self.detector_type == 'SingleStage': + losses = self.detector.bbox_head.loss(x, batch_data_samples) + else: + losses = dict() + # RPN forward and loss + if self.detector.with_rpn: + proposal_cfg = self.detector.train_cfg.get( + 'rpn_proposal', self.detector.test_cfg.rpn) + rpn_data_samples = copy.deepcopy(batch_data_samples) + # set cat_id of gt_labels to 0 in RPN + for data_sample in rpn_data_samples: + data_sample.gt_instances.labels = \ + torch.zeros_like(data_sample.gt_instances.labels) + + rpn_losses, rpn_results_list = \ + self.detector.rpn_head.loss_and_predict( + x, rpn_data_samples, proposal_cfg=proposal_cfg) + # avoid get same name with roi_head loss + keys = rpn_losses.keys() + for key in list(keys): + if 'loss' in key and 'rpn' not in key: + rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key) + losses.update(rpn_losses) + else: + assert batch_data_samples[0].get('proposals', None) is not None + # use pre-defined proposals in InstanceData for the second + # stage to extract ROI features. + rpn_results_list = [ + data_sample.proposals for data_sample in batch_data_samples + ] + + roi_losses = self.detector.roi_head.loss(x, rpn_results_list, + batch_data_samples) + losses.update(roi_losses) - det_losses = self.bbox_head.loss(x, batch_data_samples) - losses.update(det_losses) return losses - def predict(self, - batch_inputs: Tensor, - batch_data_samples: SampleList, - rescale: bool = True) -> SampleList: - """Predict results from a batch of inputs and data samples with post- - processing. + def predict_det_results(self, + x: Tuple[Tensor], + batch_data_samples: SampleList, + rescale: bool = True) -> SampleList: + """Predict detection results. Args: - batch_inputs (Tensor): Inputs with shape (N, C, H, W). + x (tuple[Tensor]): Tuple of multi-level img features. batch_data_samples (List[:obj:`DetDataSample`]): The Data Samples. It usually includes information such as `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. @@ -110,7 +218,7 @@ def predict(self, Defaults to True. Returns: - list[:obj:`DataSample`]: Return the detection results of the + list[:obj:`DetDataSample`]: Return the detection results of the input images. The returns value is DetDataSample, which usually contain 'pred_instances'. And the ``pred_instances`` usually contains following keys. @@ -123,58 +231,62 @@ def predict(self, the last dimension 4 arrange as (x1, y1, x2, y2). - masks (Tensor): Has a shape (num_instances, H, W). """ - x = self.extract_feat(batch_inputs) - results_list = self.bbox_head.predict( - x, batch_data_samples, rescale=rescale) - - if self.vis_enhance and self.with_enhance_head: - enhance_list = self.enhance_head.predict( + if len(x) > 5: + x = x[1:] + if self.detector_type == 'SingleStage': + results_list = self.detector.bbox_head.predict( x, batch_data_samples, rescale=rescale) - batch_data_samples = add_pixel_pred_to_datasample( - data_samples=batch_data_samples, pixel_list=enhance_list) - - batch_data_samples = self.add_pred_to_datasample( + else: + assert self.detector.with_bbox, 'Bbox head must be implemented.' 
+ # If there are no pre-defined proposals, use RPN to get proposals + if batch_data_samples[0].get('proposals', None) is None: + rpn_results_list = self.detector.rpn_head.predict( + x, batch_data_samples, rescale=False) + else: + rpn_results_list = [ + data_sample.proposals for data_sample in batch_data_samples + ] + + results_list = self.detector.roi_head.predict( + x, rpn_results_list, batch_data_samples, rescale=rescale) + + batch_data_samples = self.detector.add_pred_to_datasample( batch_data_samples, results_list) return batch_data_samples + def det_head_forward(self, x: Tuple[Tensor], + batch_data_samples: SampleList) -> tuple: + """Forward process of detection head. -@MODELS.register_module() -class TwoStageWithEnhanceHead(TwoStageDetector): - """Base class for two-stage detectors with enhance head. - - Two-stage detectors typically consisting of a region proposal network and a - task-specific regression head. - """ - - def __init__(self, - backbone: ConfigType, - neck: OptConfigType = None, - rpn_head: OptConfigType = None, - roi_head: OptConfigType = None, - enhance_head: OptConfigType = None, - vis_enhance: Optional[bool] = False, - train_cfg: OptConfigType = None, - test_cfg: OptConfigType = None, - data_preprocessor: OptConfigType = None, - init_cfg: OptMultiConfig = None) -> None: - super().__init__( - backbone=backbone, - neck=neck, - rpn_head=rpn_head, - roi_head=roi_head, - train_cfg=train_cfg, - test_cfg=test_cfg, - data_preprocessor=data_preprocessor, - init_cfg=init_cfg) - - if enhance_head is not None: - self.enhance_head = MODELS.build(enhance_head) - self.vis_enhance = vis_enhance + Args: + x (tuple[Tensor]): Tuple of multi-level img features. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. - @property - def with_enhance_head(self) -> bool: - """bool: whether the detector has a RoI head""" - return hasattr(self, 'enhance_head') and self.enhance_head is not None + Returns: + tuple: A tuple of features from detector head (`bbox_head` in + single-stage detector or `rpn_head` and `roi_head` in + two-stage detector). + """ + if len(x) > 5: + x = x[1:] + if self.detector_type == 'SingleStage': + results = self.detector.bbox_head.forward(x) + else: + results = () + if self.detector.with_rpn: + rpn_results_list = self.detector.rpn_head.predict( + x, batch_data_samples, rescale=False) + else: + assert batch_data_samples[0].get('proposals', None) is not None + rpn_results_list = [ + data_sample.proposals for data_sample in batch_data_samples + ] + roi_outs = self.detector.roi_head.forward(x, rpn_results_list, + batch_data_samples) + results = results + (roi_outs, ) + return results def _forward(self, batch_inputs: Tensor, batch_data_samples: SampleList) -> tuple: @@ -188,24 +300,12 @@ def _forward(self, batch_inputs: Tensor, tuple: A tuple of features from ``rpn_head`` and ``roi_head`` forward. 
""" - results = () - x = self.extract_feat(batch_inputs) - - if self.with_rpn: - rpn_results_list = self.rpn_head.predict( - x, batch_data_samples, rescale=False) - else: - assert batch_data_samples[0].get('proposals', None) is not None - rpn_results_list = [ - data_sample.proposals for data_sample in batch_data_samples - ] + x = self.detector.extract_feat(batch_inputs) - if self.with_enhance_head: + results = self.det_head_forward(x, batch_data_samples) + if self.vis_enhance: enhance_outs = self.enhance_head.forward(x) results = results + (enhance_outs, ) - - roi_outs = self.roi_head.forward(x, rpn_results_list) - results = results + (roi_outs, ) return results def loss(self, batch_inputs: Tensor, @@ -222,10 +322,9 @@ def loss(self, batch_inputs: Tensor, Returns: dict: A dictionary of loss components """ - x = self.extract_feat(batch_inputs) + x = self.detector.extract_feat(batch_inputs) losses = dict() - if self.with_enhance_head: enhance_loss = self.enhance_head.loss(x, batch_data_samples) @@ -233,36 +332,10 @@ def loss(self, batch_inputs: Tensor, assert not set(enhance_loss.keys()) & set(losses.keys()) losses.update(enhance_loss) - # RPN forward and loss - if self.with_rpn: - proposal_cfg = self.train_cfg.get('rpn_proposal', - self.test_cfg.rpn) - rpn_data_samples = copy.deepcopy(batch_data_samples) - # set cat_id of gt_labels to 0 in RPN - for data_sample in rpn_data_samples: - data_sample.gt_instances.labels = \ - torch.zeros_like(data_sample.gt_instances.labels) - - rpn_losses, rpn_results_list = self.rpn_head.loss_and_predict( - x, rpn_data_samples, proposal_cfg=proposal_cfg) - # avoid get same name with roi_head loss - keys = rpn_losses.keys() - for key in keys: - if 'loss' in key and 'rpn' not in key: - rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key) - losses.update(rpn_losses) - else: - # TODO: Not support currently, should have a check at Fast R-CNN - assert batch_data_samples[0].get('proposals', None) is not None - # use pre-defined proposals in InstanceData for the second stage - # to extract ROI features. - rpn_results_list = [ - data_sample.proposals for data_sample in batch_data_samples - ] - roi_losses = self.roi_head.loss(x, rpn_results_list, - batch_data_samples) - losses.update(roi_losses) - + det_losses = self.calculate_det_loss(x, batch_data_samples) + # avoid loss override + assert not set(det_losses.keys()) & set(losses.keys()) + losses.update(det_losses) return losses def predict(self, @@ -294,27 +367,14 @@ def predict(self, the last dimension 4 arrange as (x1, y1, x2, y2). - masks (Tensor): Has a shape (num_instances, H, W). """ - assert self.with_bbox, 'Bbox head must be implemented.' 
- x = self.extract_feat(batch_inputs) - - # If there are no pre-defined proposals, use RPN to get proposals - if batch_data_samples[0].get('proposals', None) is None: - rpn_results_list = self.rpn_head.predict( - x, batch_data_samples, rescale=False) - else: - rpn_results_list = [ - data_sample.proposals for data_sample in batch_data_samples - ] + x = self.detector.extract_feat(batch_inputs) + batch_data_samples = self.predict_det_results( + x, batch_data_samples, rescale=rescale) - if self.vis_enhance and self.with_enhance_head: + if self.vis_enhance: enhance_list = self.enhance_head.predict( x, batch_data_samples, rescale=rescale) batch_data_samples = add_pixel_pred_to_datasample( data_samples=batch_data_samples, pixel_list=enhance_list) - results_list = self.roi_head.predict( - x, rpn_results_list, batch_data_samples, rescale=rescale) - - batch_data_samples = self.add_pred_to_datasample( - batch_data_samples, results_list) return batch_data_samples diff --git a/lqit/detection/models/detectors/detector_with_enhance_model.py b/lqit/detection/models/detectors/detector_with_enhance_model.py index 892668f..d10ecb4 100644 --- a/lqit/detection/models/detectors/detector_with_enhance_model.py +++ b/lqit/detection/models/detectors/detector_with_enhance_model.py @@ -1,7 +1,6 @@ import copy from typing import Any, Dict, Optional, Tuple, Union -from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig from mmengine.model import BaseModel from mmengine.model.wrappers import MMDistributedDataParallel as MMDDP from mmengine.utils import is_list_of @@ -11,6 +10,7 @@ from lqit.detection.utils import merge_det_results from lqit.edit.models.post_processor import add_pixel_pred_to_datasample from lqit.registry import MODEL_WRAPPERS, MODELS +from lqit.utils import ConfigType, OptConfigType, OptMultiConfig ForwardResults = Union[Dict[str, Tensor], SampleList, Tuple[Tensor], Tensor] @@ -30,7 +30,7 @@ class DetectorWithEnhanceModel(BaseModel): enhance_model (dict or ConfigDict, optional): Config for enhance model. loss_weight (list): Detection loss weight for raw and enhanced image. Only used when `train_mode` is `both`. - vis_enhance (bool): Whether visualize enhance image during inference. + vis_enhance (bool): Whether visualize enhanced image during inference. Defaults to False. train_mode (str): Train mode of detector, support `raw`, `enhance` and `both`. Defaults to `enhance`. 
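
The arguments documented above also admit a `train_mode='both'` setup, where `loss_weight` balances the detection losses computed on the raw and on the enhanced image. A hypothetical config fragment illustrating this (no such config is shipped in this patch; the `_base_` references and the 0.5/0.5 weighting are assumptions):

```python
# Hypothetical config fragment (not included in this patch): train the
# detector on both the raw and the enhanced image, as documented above.
model = dict(
    type='lqit.DetectorWithEnhanceModel',
    detector={{_base_.model}},              # base detector from _base_ config
    enhance_model={{_base_.enhance_model}},  # base enhance model from _base_
    train_mode='both',       # detection loss on raw + enhanced images
    pred_mode='enhance',     # run inference on the enhanced image
    loss_weight=[0.5, 0.5],  # [raw, enhanced] detection loss weights
    detach_enhance_img=False)
```
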
@@ -99,7 +99,7 @@ def __init__(self, @property def with_enhance_model(self) -> bool: - """bool: whether the detector has a Enhance Model""" + """Whether has a enhance model.""" return (hasattr(self, 'enhance_model') and self.enhance_model is not None) diff --git a/lqit/detection/models/necks/__init__.py b/lqit/detection/models/necks/__init__.py index d463b99..35ca8e3 100644 --- a/lqit/detection/models/necks/__init__.py +++ b/lqit/detection/models/necks/__init__.py @@ -1,3 +1,4 @@ from .dffpn import DFFPN +from .ufpn import UFPN -__all__ = ['DFFPN'] +__all__ = ['DFFPN', 'UFPN'] diff --git a/lqit/detection/models/necks/ufpn.py b/lqit/detection/models/necks/ufpn.py new file mode 100644 index 0000000..74b5531 --- /dev/null +++ b/lqit/detection/models/necks/ufpn.py @@ -0,0 +1,249 @@ +from typing import List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmdet.models.necks.fpn import FPN +from torch import Tensor + +from lqit.registry import MODELS +from lqit.utils import ConfigType, MultiConfig, OptConfigType + + +@MODELS.register_module() +class UFPN(FPN): + """UNet-based Feature Pyramid Network, UFPN. + + This is an implementation of paper `Underwater Object Detection Aided + by Image Reconstruction + `_. + + Args: + in_channels (list[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale). + num_outs (int): Number of output scales. + start_level (int): Index of the start input backbone level used to + build the feature pyramid. Defaults to 0. + end_level (int): Index of the end input backbone level (exclusive) to + build the feature pyramid. Defaults to -1, which means the + last level. + add_extra_convs (bool | str): If bool, it decides whether to add conv + layers on top of the original feature maps. Defaults to False. + If True, it is equivalent to `add_extra_convs='on_input'`. + If str, it specifies the source feature map of the extra convs. + Only the following options are allowed + + - 'on_input': Last feat map of neck inputs (i.e. backbone feature). + - 'on_lateral': Last feature map after lateral convs. + - 'on_output': The last output feature map after fpn convs. + relu_before_extra_convs (bool): Whether to apply relu before the extra + conv. Defaults to False. + no_norm_on_lateral (bool): Whether to apply norm on lateral. + Defaults to False. + conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for + convolution layer. Defaults to None. + norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for + normalization layer. Defaults to None. + act_cfg (:obj:`ConfigDict` or dict, optional): Config dict for + activation layer in ConvModule. Defaults to None. + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ + dict]): Initialization config dict. 
+ """ + + def __init__( + self, + in_channels: List[int], + out_channels: int, + num_outs: int = 6, + start_level: int = 0, + end_level: int = -1, + add_extra_convs: str = 'on_output', + relu_before_extra_convs: bool = False, + no_norm_on_lateral: bool = False, + conv_cfg: OptConfigType = None, + norm_cfg: OptConfigType = None, + act_cfg: OptConfigType = None, + upsample_cfg: ConfigType = dict(mode='nearest'), + init_cfg: MultiConfig = dict( + type='Xavier', layer='Conv2d', distribution='uniform') + ) -> None: + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + num_outs=num_outs, + start_level=start_level, + end_level=end_level, + add_extra_convs=add_extra_convs, + relu_before_extra_convs=relu_before_extra_convs, + no_norm_on_lateral=no_norm_on_lateral, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + upsample_cfg=upsample_cfg, + init_cfg=init_cfg) + # add encoder pathway + self.encode_convs = nn.ModuleList() + self.connect_convs = nn.ModuleList() + for i in range(self.start_level, self.num_outs + self.start_level): + connect_conv = ConvModule( + out_channels, + out_channels, + 3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + if i < self.num_outs + self.start_level - 1: + encode_conv = ConvModule( + out_channels, + out_channels, + 3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + + self.encode_convs.append(encode_conv) + self.connect_convs.append(connect_conv) + + # add decoder pathway + self.decode_convs = nn.ModuleList() + decode_conv = ConvModule( + out_channels, + out_channels, + 3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, # may use ReLU or LeakyReLU + inplace=False) + + self.decode_convs.append(decode_conv) + + for _ in range(self.start_level + 1, self.num_outs + self.start_level): + conv1 = ConvModule( + out_channels * 2, # concat with other feature map + out_channels, + 3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, # may use ReLU or LeakyReLU + inplace=False) + conv2 = ConvModule( + out_channels, + out_channels, + 3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, # may use ReLU or LeakyReLU + inplace=False) + decode_conv = nn.Sequential(conv1, conv2) + self.decode_convs.append(decode_conv) + + def forward(self, inputs: Tuple[Tensor]) -> tuple: + """Forward function. + + Args: + inputs (tuple[Tensor]): Features from the upstream network, each + is a 4D-tensor. + + Returns: + tuple: Feature maps, each is a 4D-tensor. + """ + assert len(inputs) == len(self.in_channels) + + # build laterals + laterals = [ + lateral_conv(inputs[i + self.start_level]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + # In some cases, fixing `scale factor` (e.g. 2) is preferred, but + # it cannot co-exist with `size` in `F.interpolate`. 
+ if 'scale_factor' in self.upsample_cfg: + # fix runtime error of "+=" inplace operation in PyTorch 1.10 + laterals[i - 1] = laterals[i - 1] + F.interpolate( + laterals[i], **self.upsample_cfg) + else: + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] = laterals[i - 1] + F.interpolate( + laterals[i], size=prev_shape, **self.upsample_cfg) + + # build outputs + # part 1: from original levels + inter_outs = [ + self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels) + ] + # build extra outputs + + if self.num_outs > len(inter_outs): + # use max pool to get more levels on top of outputs + # (e.g., Faster R-CNN, Mask R-CNN) + if not self.add_extra_convs: + for i in range(self.num_outs - used_backbone_levels): + inter_outs.append( + F.max_pool2d(inter_outs[-1], 1, stride=2)) + # add conv layers on top of original feature maps (RetinaNet) + else: + if self.add_extra_convs == 'on_input': + extra_source = inputs[self.backbone_end_level - 1] + elif self.add_extra_convs == 'on_lateral': + extra_source = laterals[-1] + elif self.add_extra_convs == 'on_output': + extra_source = inter_outs[-1] + else: + raise NotImplementedError + inter_outs.append( + self.fpn_convs[used_backbone_levels](extra_source)) + for i in range(used_backbone_levels + 1, self.num_outs): + if self.relu_before_extra_convs: + inter_outs.append(self.fpn_convs[i](F.relu( + inter_outs[-1]))) + else: + inter_outs.append(self.fpn_convs[i](inter_outs[-1])) + + # part 2: add encoder path + connect_feats = [ + self.connect_convs[i](inter_outs[i]) for i in range(self.num_outs) + ] + + encode_outs = [] + encode_outs.append(connect_feats[0]) + for i in range(0, self.num_outs - 1): + encode_outs.append(self.encode_convs[i](connect_feats[i]) + + connect_feats[i + 1]) + + # part 3: add decoder levels + decode_outs = [ + torch.zeros_like(encode_outs[i]) for i in range(self.num_outs) + ] + decode_outs[-1] = self.decode_convs[0](encode_outs[-1]) + for i in range(1, self.num_outs): + reverse_i = self.num_outs - i + if 'scale_factor' in self.upsample_cfg: + up_feat = F.interpolate(decode_outs[reverse_i], + **self.upsample_cfg) + else: + prev_shape = encode_outs[reverse_i - 1].shape[2:] + up_feat = F.interpolate( + decode_outs[reverse_i], + size=prev_shape, + **self.upsample_cfg) + + decode_outs[reverse_i - 1] = self.decode_convs[i]( + torch.cat((encode_outs[reverse_i - 1], up_feat), dim=1)) + + return tuple(decode_outs) diff --git a/lqit/edit/models/editor_heads/basic_enhance_head.py b/lqit/edit/models/editor_heads/basic_enhance_head.py index 3a8cae8..6034ad1 100644 --- a/lqit/edit/models/editor_heads/basic_enhance_head.py +++ b/lqit/edit/models/editor_heads/basic_enhance_head.py @@ -232,13 +232,12 @@ def loss_by_feat_single(self, enhance_img, gt_img, img_meta): @MODELS.register_module() class BasicEnhanceHead(BaseEnhanceHead): - """[(convs)+ShufflePixes] * 2 - """ + """[(convs)+ShufflePixes] * 2""" def __init__(self, in_channels=256, feat_channels=256, - num_convs=5, + num_convs=2, conv_cfg=None, norm_cfg=None, act_cfg=dict(type='ReLU'), From 02b97db03b2e80816cc5f7798e6a800684b2abaa Mon Sep 17 00:00:00 2001 From: BIGWangYuDong Date: Sat, 28 Oct 2023 00:00:02 +0800 Subject: [PATCH 3/5] fully support uod-air --- configs/detection/tienet/README.md | 11 +- configs/detection/uod_air/README.md | 12 +- .../faster-rcnn_r50_fpn_1x_urpc-coco.py | 4 +- .../faster-rcnn_r50_ufpn_1x_urpc-coco.py | 121 ++++++++++++++++++ .../retinanet_r50_fpn_1x_urpc-coco.py | 91 +++++++++++++ .../retinanet_r50_ufpn-p2-p7_1x_urpc-coco.py | 91 +++++++++++++ 
.../retinanet_r50_ufpn_1x_urpc-coco.py | 4 +- .../uod_air/base_ehance_head/enhance_head.py | 4 +- ...-air_faster-rcnn_r50_ufpn_1x_urpc-coco.py} | 2 +- ...r_retinanet_r50_ufpn_1x_4xbs4_urpc-coco.py | 4 + ...r_retinanet_r50_ufpn_1x_urpc-coco_lr002.py | 7 + .../detectors/detector_with_enhance_head.py | 6 +- .../models/editor_heads/basic_enhance_head.py | 5 +- tools/test.py | 30 ++--- tools/train.py | 28 ++-- 15 files changed, 376 insertions(+), 44 deletions(-) create mode 100644 configs/detection/uod_air/base_detector/faster-rcnn_r50_ufpn_1x_urpc-coco.py create mode 100644 configs/detection/uod_air/base_detector/retinanet_r50_fpn_1x_urpc-coco.py create mode 100644 configs/detection/uod_air/base_detector/retinanet_r50_ufpn-p2-p7_1x_urpc-coco.py rename configs/detection/uod_air/{uod-air_faster-rcnn_r50_fpn_1x_urpc-coco.py => uod-air_faster-rcnn_r50_ufpn_1x_urpc-coco.py} (91%) create mode 100644 configs/detection/uod_air/uod-air_retinanet_r50_ufpn_1x_4xbs4_urpc-coco.py create mode 100644 configs/detection/uod_air/uod-air_retinanet_r50_ufpn_1x_urpc-coco_lr002.py diff --git a/configs/detection/tienet/README.md b/configs/detection/tienet/README.md index f0864ce..90d2d64 100644 --- a/configs/detection/tienet/README.md +++ b/configs/detection/tienet/README.md @@ -16,4 +16,13 @@ Coming soon ## Citation -Coming soon +```latex +@article{wang2023tienet, + title={{TIENet}: task-oriented image enhancement network for degraded object detection}, + author={Wang, Yudong and Guo, Jichang and Wang, Ruining and He, Wanru and Li, Chongyi}, + journal={Signal, Image and Video Processing}, + pages={1--8}, + year={2023}, + publisher={Springer} +} +``` diff --git a/configs/detection/uod_air/README.md b/configs/detection/uod_air/README.md index 9b00c99..4748c8a 100644 --- a/configs/detection/uod_air/README.md +++ b/configs/detection/uod_air/README.md @@ -14,7 +14,17 @@ Underwater object detection plays an important role in a variety of marine appli ## Results and Analysis -Coming soon +| Architecture | Neck | Lr schd | lr | box AP | Config | Download | +| :------------------------------------: | :--: | :-----: | :--: | :----: | :------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| Faster R-CNN | FPN | 1x | 0.02 | 43.5 | [config](./base_detector/faster-rcnn_r50_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_r50_fpn_1x_urpc-coco_20220226_105840-09ef8403.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_r50_fpn_1x_urpc-coco_20220226_105840.log.json) | +| Faster R-CNN | UFPN | 1x | 0.02 | 44.0 | [config](./base_detector/faster-rcnn_r50_ufpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/untagged-13f3dfa124d975df43f5/faster-rcnn_r50_ufpn_1x_urpc-coco_20231027_211425-61d901bb.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/untagged-13f3dfa124d975df43f5/faster-rcnn_r50_ufpn_1x_urpc-coco_20231027_211425.log.json) | +| Faster R-CNN with Image Reconstruction | UFPN | 1x | 0.02 | 44.3 | [config](./uod-air_faster-rcnn_r50_ufpn_1x_urpc-coco.py) | 
[model](https://github.com/BIGWangYuDong/lqit/releases/download/untagged-13f3dfa124d975df43f5/uod-air_faster-rcnn_r50_ufpn_1x_urpc-coco_20231027_145407-6ae6d373.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/untagged-13f3dfa124d975df43f5/uod-air_faster-rcnn_r50_ufpn_1x_urpc-coco_20231027_145407.log.json) | +| RetinaNet | FPN | 1x | 0.01 | 40.7 | [config](./base_detector/retinanet_r50_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/retinanet_r50_fpn_1x_urpc-coco_20220405_214951-a39f054e.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/retinanet_r50_fpn_1x_urpc-coco_20220405_214951.log.json) | +| RetinaNet | UFPN | 1x | 0.01 | 41.8 | [config](./base_detector/retinanet_r50_ufpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/untagged-13f3dfa124d975df43f5/retinanet_r50_ufpn_1x_urpc-coco_20231027_215756-7803a5f9.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/untagged-13f3dfa124d975df43f5/retinanet_r50_ufpn_1x_urpc-coco_20231027_215756.log.json) | +| RetinaNet with Image Reconstruction | UFPN | 1x | 0.01 | 42.3 | [config](./uod-air_retinanet_r50_ufpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/untagged-13f3dfa124d975df43f5/uod-air_retinanet_r50_ufpn_1x_urpc-coco_20231027_224724-fe3acfba.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/untagged-13f3dfa124d975df43f5/uod-air_retinanet_r50_ufpn_1x_urpc-coco_20231027_224724.log.json) | +| RetinaNet with Image Reconstruction | UFPN | 1x | 0.02 | 43.3 | [config](./uod-air_retinanet_r50_ufpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/untagged-13f3dfa124d975df43f5/uod-air_retinanet_r50_ufpn_1x_urpc-coco_lr002_20231027_215752-b727baaf.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/untagged-13f3dfa124d975df43f5/uod-air_retinanet_r50_ufpn_1x_urpc-coco_lr002_20231027_215752.log.json) | + +**Note:** The original paper was developed based on MMDetection 2.0. LQIT optimized the network structure. LQIT has aligned the AP results on Faster R-CNN, but got 0.1 AP fluctuation on RetinaNet. 
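
Any row above can be re-evaluated from its config and the released checkpoint. A minimal sketch with the MMEngine runner, assuming LQIT is installed so its registered modules resolve; paths are placeholders, and in practice the repo's `tools/test.py` entry point is the usual way to run this:

```python
# Minimal evaluation sketch: load a config from the table above and a
# downloaded checkpoint, then run the test loop. Paths are placeholders.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
    'configs/detection/uod_air/uod-air_retinanet_r50_ufpn_1x_urpc-coco.py')
cfg.load_from = 'path/to/uod-air_retinanet_r50_ufpn_1x_urpc-coco.pth'
cfg.work_dir = 'work_dirs/uod-air_retinanet_r50_ufpn_1x_urpc-coco'

Runner.from_cfg(cfg).test()
```
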
## Citation diff --git a/configs/detection/uod_air/base_detector/faster-rcnn_r50_fpn_1x_urpc-coco.py b/configs/detection/uod_air/base_detector/faster-rcnn_r50_fpn_1x_urpc-coco.py index 512d9d7..a28ff6f 100644 --- a/configs/detection/uod_air/base_detector/faster-rcnn_r50_fpn_1x_urpc-coco.py +++ b/configs/detection/uod_air/base_detector/faster-rcnn_r50_fpn_1x_urpc-coco.py @@ -23,11 +23,9 @@ style='pytorch', init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), neck=dict( - type='lqit.UFPN', + type='FPN', in_channels=[256, 512, 1024, 2048], out_channels=256, - start_level=0, - add_extra_convs='on_output', num_outs=5), rpn_head=dict( type='RPNHead', diff --git a/configs/detection/uod_air/base_detector/faster-rcnn_r50_ufpn_1x_urpc-coco.py b/configs/detection/uod_air/base_detector/faster-rcnn_r50_ufpn_1x_urpc-coco.py new file mode 100644 index 0000000..512d9d7 --- /dev/null +++ b/configs/detection/uod_air/base_detector/faster-rcnn_r50_ufpn_1x_urpc-coco.py @@ -0,0 +1,121 @@ +_base_ = [ + '../../_base_/datasets/urpc2020/urpc2020-validation_coco_detection.py', + '../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py' +] + +# model settings +model = dict( + type='FasterRCNN', + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32), + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='lqit.UFPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=0, + add_extra_convs='on_output', + num_outs=5), + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=4, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=False, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + 
sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + )) diff --git a/configs/detection/uod_air/base_detector/retinanet_r50_fpn_1x_urpc-coco.py b/configs/detection/uod_air/base_detector/retinanet_r50_fpn_1x_urpc-coco.py new file mode 100644 index 0000000..bfc825d --- /dev/null +++ b/configs/detection/uod_air/base_detector/retinanet_r50_fpn_1x_urpc-coco.py @@ -0,0 +1,91 @@ +_base_ = [ + '../../_base_/datasets/urpc2020/urpc2020-validation_coco_detection.py', + '../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py' +] + +# model settings +model = dict( + type='RetinaNet', + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32), + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_input', + num_outs=5), + bbox_head=dict( + type='RetinaHead', + num_classes=4, + in_channels=256, + stacked_convs=4, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + octave_base_scale=4, + scales_per_octave=3, + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + # model training and testing settings + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.4, + min_pos_iou=0, + ignore_iof_thr=-1), + sampler=dict( + type='PseudoSampler'), # Focal loss should use PseudoSampler + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100)) + +# optimizer +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001), + clip_grad=dict(max_norm=35, norm_type=2)) # loss may NaN without clip_grad + +param_scheduler = [ + dict( + type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, + end=1000), + dict( + type='MultiStepLR', + begin=0, + end=12, + by_epoch=True, + milestones=[8, 11], + gamma=0.1) +] diff --git a/configs/detection/uod_air/base_detector/retinanet_r50_ufpn-p2-p7_1x_urpc-coco.py b/configs/detection/uod_air/base_detector/retinanet_r50_ufpn-p2-p7_1x_urpc-coco.py new file mode 100644 index 0000000..9dd042d --- /dev/null +++ b/configs/detection/uod_air/base_detector/retinanet_r50_ufpn-p2-p7_1x_urpc-coco.py @@ -0,0 +1,91 @@ +_base_ = [ + '../../_base_/datasets/urpc2020/urpc2020-validation_coco_detection.py', + '../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py' +] + +# model settings +model = dict( + 
type='RetinaNet', + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32), + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='lqit.UFPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=0, + add_extra_convs='on_output', + num_outs=6), + bbox_head=dict( + type='RetinaHead', + num_classes=4, + in_channels=256, + stacked_convs=4, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + octave_base_scale=4, + scales_per_octave=3, + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + # model training and testing settings + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.4, + min_pos_iou=0, + ignore_iof_thr=-1), + sampler=dict( + type='PseudoSampler'), # Focal loss should use PseudoSampler + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100)) + +# optimizer +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001), + clip_grad=dict(max_norm=35, norm_type=2)) # loss may NaN without clip_grad + +param_scheduler = [ + dict( + type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, + end=1000), + dict( + type='MultiStepLR', + begin=0, + end=12, + by_epoch=True, + milestones=[8, 11], + gamma=0.1) +] diff --git a/configs/detection/uod_air/base_detector/retinanet_r50_ufpn_1x_urpc-coco.py b/configs/detection/uod_air/base_detector/retinanet_r50_ufpn_1x_urpc-coco.py index 9dd042d..f51c2fd 100644 --- a/configs/detection/uod_air/base_detector/retinanet_r50_ufpn_1x_urpc-coco.py +++ b/configs/detection/uod_air/base_detector/retinanet_r50_ufpn_1x_urpc-coco.py @@ -26,9 +26,9 @@ type='lqit.UFPN', in_channels=[256, 512, 1024, 2048], out_channels=256, - start_level=0, + start_level=1, add_extra_convs='on_output', - num_outs=6), + num_outs=5), bbox_head=dict( type='RetinaHead', num_classes=4, diff --git a/configs/detection/uod_air/base_ehance_head/enhance_head.py b/configs/detection/uod_air/base_ehance_head/enhance_head.py index 0d308f2..a1e4e1d 100644 --- a/configs/detection/uod_air/base_ehance_head/enhance_head.py +++ b/configs/detection/uod_air/base_ehance_head/enhance_head.py @@ -3,8 +3,8 @@ type='BasicEnhanceHead', in_channels=256, feat_channels=256, - num_convs=2, - loss_enhance=dict(type='L1Loss', loss_weight=1.0), + num_convs=5, + loss_enhance=dict(type='L1Loss', loss_weight=0.5), gt_preprocessor=dict( type='GTPixelPreprocessor', mean=[123.675, 116.28, 103.53], diff --git a/configs/detection/uod_air/uod-air_faster-rcnn_r50_fpn_1x_urpc-coco.py b/configs/detection/uod_air/uod-air_faster-rcnn_r50_ufpn_1x_urpc-coco.py similarity index 91% rename from configs/detection/uod_air/uod-air_faster-rcnn_r50_fpn_1x_urpc-coco.py rename to configs/detection/uod_air/uod-air_faster-rcnn_r50_ufpn_1x_urpc-coco.py index c8a8d94..6b98aba 100644 --- 
a/configs/detection/uod_air/uod-air_faster-rcnn_r50_fpn_1x_urpc-coco.py +++ b/configs/detection/uod_air/uod-air_faster-rcnn_r50_ufpn_1x_urpc-coco.py @@ -1,5 +1,5 @@ _base_ = [ - './base_detector/faster-rcnn_r50_fpn_1x_urpc-coco.py', + './base_detector/faster-rcnn_r50_ufpn_1x_urpc-coco.py', './base_ehance_head/enhance_head.py' ] diff --git a/configs/detection/uod_air/uod-air_retinanet_r50_ufpn_1x_4xbs4_urpc-coco.py b/configs/detection/uod_air/uod-air_retinanet_r50_ufpn_1x_4xbs4_urpc-coco.py new file mode 100644 index 0000000..1f0e4a9 --- /dev/null +++ b/configs/detection/uod_air/uod-air_retinanet_r50_ufpn_1x_4xbs4_urpc-coco.py @@ -0,0 +1,4 @@ +_base_ = [ + './uod-air_retinanet_r50_ufpn_1x_urpc-coco.py', +] +train_dataloader = dict(batch_size=4, num_workers=4) diff --git a/configs/detection/uod_air/uod-air_retinanet_r50_ufpn_1x_urpc-coco_lr002.py b/configs/detection/uod_air/uod-air_retinanet_r50_ufpn_1x_urpc-coco_lr002.py new file mode 100644 index 0000000..ea7b8b8 --- /dev/null +++ b/configs/detection/uod_air/uod-air_retinanet_r50_ufpn_1x_urpc-coco_lr002.py @@ -0,0 +1,7 @@ +_base_ = [ + './uod-air_retinanet_r50_ufpn_1x_urpc-coco.py', +] +# optimizer +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001), + clip_grad=dict(max_norm=35, norm_type=2)) # loss may NaN without clip_grad diff --git a/lqit/detection/models/detectors/detector_with_enhance_head.py b/lqit/detection/models/detectors/detector_with_enhance_head.py index 92e110d..5a99884 100644 --- a/lqit/detection/models/detectors/detector_with_enhance_head.py +++ b/lqit/detection/models/detectors/detector_with_enhance_head.py @@ -22,6 +22,8 @@ class DetectorWithEnhanceHead(BaseModel): Args: detector (dict or ConfigDict): Config for detector. enhance_head (dict or ConfigDict, optional): Config for enhance head. + process_gt_preprocessor (bool): Whether process `gt_preprocessor` same + as the `data_preprocessor` in detector. Defaults to True. vis_enhance (bool): Whether to visualize the enhanced image during inference. Defaults to False. init_cfg (dict or ConfigDict, optional): The config to control the @@ -31,12 +33,13 @@ class DetectorWithEnhanceHead(BaseModel): def __init__(self, detector: ConfigType, enhance_head: OptConfigType = None, + process_gt_preprocessor: bool = True, vis_enhance: Optional[bool] = False, init_cfg: OptMultiConfig = None) -> None: super().__init__(init_cfg=init_cfg) # process gt_preprocessor - if enhance_head is not None: + if enhance_head is not None and process_gt_preprocessor: enhance_head = self.process_gt_preprocessor(detector, enhance_head) # build data_preprocessor @@ -326,7 +329,6 @@ def loss(self, batch_inputs: Tensor, losses = dict() if self.with_enhance_head: - enhance_loss = self.enhance_head.loss(x, batch_data_samples) # avoid loss override assert not set(enhance_loss.keys()) & set(losses.keys()) diff --git a/lqit/edit/models/editor_heads/basic_enhance_head.py b/lqit/edit/models/editor_heads/basic_enhance_head.py index 6034ad1..f1eb616 100644 --- a/lqit/edit/models/editor_heads/basic_enhance_head.py +++ b/lqit/edit/models/editor_heads/basic_enhance_head.py @@ -232,7 +232,10 @@ def loss_by_feat_single(self, enhance_img, gt_img, img_meta): @MODELS.register_module() class BasicEnhanceHead(BaseEnhanceHead): - """[(convs)+ShufflePixes] * 2""" + """Basic enhance head. 
+ + [Conv-BN-ReLU] * (num_convs - 1) + Conv + """ def __init__(self, in_channels=256, diff --git a/tools/test.py b/tools/test.py index 11436a1..536ca2b 100644 --- a/tools/test.py +++ b/tools/test.py @@ -135,7 +135,10 @@ def main(args): monitor_manager = None - if args.lark: + if not args.lark: + main(args) + else: + # report the running status to lark bot lark_file = args.lark_file if not osp.exists(lark_file): warnings.warn(f'{lark_file} not exists, skip.') @@ -145,21 +148,18 @@ def main(args): lark_url = lark.get('lark', None) if lark_url is None: warnings.warn(f'{lark_file} does not have `lark`, skip.') + else: + monitor_interval_seconds = lark.get('monitor_interval_seconds', + 300) + user_name = lark.get('user_name', None) + monitor_manager = initialize_monitor_manager( + cfg_file=args.config, + url=lark_url, + task_type='test', + user_name=user_name, + monitor_interval_seconds=monitor_interval_seconds, + ckpt_path=args.checkpoint) - monitor_interval_seconds = lark.get('monitor_interval_seconds', - None) - if monitor_interval_seconds is None: - monitor_interval_seconds = 300 - - user_name = lark.get('user_name', None) - - monitor_manager = initialize_monitor_manager( - cfg_file=args.config, - url=lark_url, - task_type='test', - user_name=user_name, - monitor_interval_seconds=monitor_interval_seconds, - ckpt_path=args.checkpoint) with context_monitor_manager(monitor_manager): try: main(args) diff --git a/tools/train.py b/tools/train.py index bad72b9..fc1db95 100644 --- a/tools/train.py +++ b/tools/train.py @@ -157,9 +157,8 @@ def main(args): monitor_manager = None - if not args.lark: - main(args) - else: + if args.lark: + # report the running status to lark bot lark_file = args.lark_file if not osp.exists(lark_file): warnings.warn(f'{lark_file} not exists, skip.') @@ -169,20 +168,17 @@ def main(args): lark_url = lark.get('lark', None) if lark_url is None: warnings.warn(f'{lark_file} does not have `lark`, skip.') + else: + monitor_interval_seconds = lark.get('monitor_interval_seconds', + 300) + user_name = lark.get('user_name', None) + monitor_manager = initialize_monitor_manager( + cfg_file=args.config, + url=lark_url, + task_type='train', + user_name=user_name, + monitor_interval_seconds=monitor_interval_seconds) - monitor_interval_seconds = lark.get('monitor_interval_seconds', - None) - if monitor_interval_seconds is None: - monitor_interval_seconds = 300 - - user_name = lark.get('user_name', None) - - monitor_manager = initialize_monitor_manager( - cfg_file=args.config, - url=lark_url, - task_type='train', - user_name=user_name, - monitor_interval_seconds=monitor_interval_seconds) with context_monitor_manager(monitor_manager): try: main(args) From 855da6d2e5a616cefea6e5c0449800282acd044d Mon Sep 17 00:00:00 2001 From: BIGWangYuDong Date: Sat, 28 Oct 2023 00:53:05 +0800 Subject: [PATCH 4/5] Support TIENet --- configs/detection/edffnet/README.md | 30 ++++ ...c-coco.py => atss_r50_fpn_1x_rtts-coco.py} | 45 +++--- .../faster-rcnn_r50_fpn_1x_rtts-coco.py | 133 ++++++++++++++++++ ...fcos_r50-caffe_fpn_gn-head_1x_urpc-coco.py | 77 ---------- .../retinanet_r50_fpn_1x_rtts-coco.py | 105 ++++++++++++++ .../tood_r50_fpn_1x_rtts-coco.py | 95 +++++++++++++ .../tienet_atss_r50_fpn_1x_rtts-coco.py | 35 +++++ .../tienet_atss_r50_fpn_1x_urpc-coco.py | 35 +++++ ..._faster-rcnn_r50_fpn_1x_4xbs4_urpc-coco.py | 3 - ...tienet_faster-rcnn_r50_fpn_1x_rtts-coco.py | 35 +++++ ...et_retinanet_r50_fpn_1x_4xbs4_urpc-coco.py | 3 - .../tienet_retinanet_r50_fpn_1x_rtts-coco.py | 35 +++++ 
.../tienet_tood_r50_fpn_1x_rtts-coco.py | 35 +++++ .../tienet_tood_r50_fpn_1x_urpc-coco.py | 35 +++++ 14 files changed, 592 insertions(+), 109 deletions(-) create mode 100644 configs/detection/edffnet/README.md rename configs/detection/tienet/base_detector/{paa_r50_fpn_1x_urpc-coco.py => atss_r50_fpn_1x_rtts-coco.py} (73%) create mode 100644 configs/detection/tienet/base_detector/faster-rcnn_r50_fpn_1x_rtts-coco.py delete mode 100644 configs/detection/tienet/base_detector/fcos_r50-caffe_fpn_gn-head_1x_urpc-coco.py create mode 100644 configs/detection/tienet/base_detector/retinanet_r50_fpn_1x_rtts-coco.py create mode 100644 configs/detection/tienet/base_detector/tood_r50_fpn_1x_rtts-coco.py create mode 100644 configs/detection/tienet/tienet_atss_r50_fpn_1x_rtts-coco.py create mode 100644 configs/detection/tienet/tienet_atss_r50_fpn_1x_urpc-coco.py delete mode 100644 configs/detection/tienet/tienet_faster-rcnn_r50_fpn_1x_4xbs4_urpc-coco.py create mode 100644 configs/detection/tienet/tienet_faster-rcnn_r50_fpn_1x_rtts-coco.py delete mode 100644 configs/detection/tienet/tienet_retinanet_r50_fpn_1x_4xbs4_urpc-coco.py create mode 100644 configs/detection/tienet/tienet_retinanet_r50_fpn_1x_rtts-coco.py create mode 100644 configs/detection/tienet/tienet_tood_r50_fpn_1x_rtts-coco.py create mode 100644 configs/detection/tienet/tienet_tood_r50_fpn_1x_urpc-coco.py diff --git a/configs/detection/edffnet/README.md b/configs/detection/edffnet/README.md new file mode 100644 index 0000000..dd25991 --- /dev/null +++ b/configs/detection/edffnet/README.md @@ -0,0 +1,30 @@ +# Edge-Guided Dynamic Feature Fusion Network for Object Detection under Foggy Conditions + + + +## Abstract + +Hazy images are often subject to blurring, low contrast and other visible quality degradation, making it challenging to solve object detection tasks. Most methods solve the domain shift problem by deep domain adaptive technology, ignoring the inaccurate object classification and localization caused by quality degradation. Different from common methods, we present an edge-guided dynamic feature fusion network (EDFFNet), which formulates the edge head as a guide to the localization task. Despite the edge head being straightforward, we demonstrate that it makes the model pay attention to the edge of object instances and improves the generalization and localization ability of the network. Considering the fuzzy details and the multi-scale problem of hazy images, we propose a dynamic fusion feature pyramid network (DF-FPN) to enhance the feature representation ability of the whole model. A unique advantage of DF-FPN is that the contribution to the fused feature map will dynamically adjust with the learning of the network. Extensive experiments verify that EDFFNet achieves 2.4% AP and 3.6% AP gains over the ATSS baseline on RTTS and Foggy Cityscapes, respectively. + +
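The edge head mentioned in the abstract guides localization by making the model attend to instance edges, and the edge-head config later in this series consumes an `edge` ground-truth element. As a rough illustration of how such a single-channel edge target can be produced, the `sobel_edge_map` helper below uses a plain Sobel filter; it is a hypothetical sketch, not the preprocessing code added by this patch.

```python
# Illustrative sketch only (not code from this patch): one way to build the
# single-channel edge map that an edge head could be supervised with.
import torch
import torch.nn.functional as F


def sobel_edge_map(img: torch.Tensor) -> torch.Tensor:
    """img: (N, 3, H, W) tensor in [0, 255]; returns (N, 1, H, W) edge magnitude."""
    gray = img.mean(dim=1, keepdim=True)
    kx = torch.tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]],
                      device=img.device).view(1, 1, 3, 3)
    ky = kx.transpose(2, 3).contiguous()
    gx = F.conv2d(gray, kx, padding=1)
    gy = F.conv2d(gray, ky, padding=1)
    return torch.sqrt(gx ** 2 + gy ** 2 + 1e-6)


edges = sobel_edge_map(torch.rand(2, 3, 64, 64) * 255)
print(edges.shape)  # torch.Size([2, 1, 64, 64])
```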
+ +## Results and Analysis + +Coming soon + +## Citation + +```latex +@article{he2023edge, + title={Edge-guided dynamic feature fusion network for object detection under foggy conditions}, + author={He, Wanru and Guo, Jichang and Wang, Yudong and Zheng, Sida}, + journal={Signal, Image and Video Processing}, + volume={17}, + number={5}, + pages={1975--1983}, + year={2023}, + publisher={Springer} +} +``` diff --git a/configs/detection/tienet/base_detector/paa_r50_fpn_1x_urpc-coco.py b/configs/detection/tienet/base_detector/atss_r50_fpn_1x_rtts-coco.py similarity index 73% rename from configs/detection/tienet/base_detector/paa_r50_fpn_1x_urpc-coco.py rename to configs/detection/tienet/base_detector/atss_r50_fpn_1x_rtts-coco.py index a229916..9bc8ef8 100644 --- a/configs/detection/tienet/base_detector/paa_r50_fpn_1x_urpc-coco.py +++ b/configs/detection/tienet/base_detector/atss_r50_fpn_1x_rtts-coco.py @@ -1,11 +1,11 @@ _base_ = [ - '../../_base_/datasets/urpc2020/urpc2020-validation_coco_detection.py', + '../../_base_/datasets/rtts_coco.py', '../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py' ] # model settings model = dict( - type='PAA', + type='ATSS', data_preprocessor=dict( type='DetDataPreprocessor', mean=[123.675, 116.28, 103.53], @@ -30,10 +30,7 @@ add_extra_convs='on_output', num_outs=5), bbox_head=dict( - type='PAAHead', - reg_decoded_bbox=True, - score_voting=True, - topk=9, + type='ATSSHead', num_classes=5, in_channels=256, stacked_convs=4, @@ -54,17 +51,12 @@ gamma=2.0, alpha=0.25, loss_weight=1.0), - loss_bbox=dict(type='GIoULoss', loss_weight=1.3), + loss_bbox=dict(type='GIoULoss', loss_weight=2.0), loss_centerness=dict( - type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5)), + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)), # training and testing settings train_cfg=dict( - assigner=dict( - type='MaxIoUAssigner', - pos_iou_thr=0.1, - neg_iou_thr=0.1, - min_pos_iou=0, - ignore_iof_thr=-1), + assigner=dict(type='ATSSAssigner', topk=9), allowed_border=-1, pos_weight=-1, debug=False), @@ -79,15 +71,16 @@ optim_wrapper = dict( optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)) -# learning rate -param_scheduler = [ - dict( - type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500), - dict( - type='MultiStepLR', - begin=0, - end=12, - by_epoch=True, - milestones=[8, 11], - gamma=0.1) -] +# add WandbVisBackend +# vis_backends = [ +# dict(type='LocalVisBackend'), +# dict(type='WandbVisBackend', +# init_kwargs=dict( +# project='rtts_detection', +# name='atss_r50_fpn_1x_rtts', +# entity='lqit', +# ) +# ) +# ] +# visualizer = dict( +# type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer') diff --git a/configs/detection/tienet/base_detector/faster-rcnn_r50_fpn_1x_rtts-coco.py b/configs/detection/tienet/base_detector/faster-rcnn_r50_fpn_1x_rtts-coco.py new file mode 100644 index 0000000..d9fff27 --- /dev/null +++ b/configs/detection/tienet/base_detector/faster-rcnn_r50_fpn_1x_rtts-coco.py @@ -0,0 +1,133 @@ +_base_ = [ + '../../_base_/datasets/rtts_coco.py', + '../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py' +] + +# model settings +model = dict( + type='FasterRCNN', + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32), + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', 
requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5), + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=5, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=False, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + )) + +# add WandbVisBackend +# vis_backends = [ +# dict(type='LocalVisBackend'), +# dict(type='WandbVisBackend', +# init_kwargs=dict( +# project='rtts_detection', +# name='faster-rcnn_r50_fpn_1x_rtts', +# entity='lqit', +# ) +# ) +# ] +# visualizer = dict( +# type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer') diff --git a/configs/detection/tienet/base_detector/fcos_r50-caffe_fpn_gn-head_1x_urpc-coco.py b/configs/detection/tienet/base_detector/fcos_r50-caffe_fpn_gn-head_1x_urpc-coco.py deleted file mode 100644 index 8f27672..0000000 --- a/configs/detection/tienet/base_detector/fcos_r50-caffe_fpn_gn-head_1x_urpc-coco.py +++ /dev/null @@ -1,77 +0,0 @@ -_base_ = [ - '../../_base_/datasets/urpc2020/urpc2020-validation_coco_detection.py', - '../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py' -] - -# model settings -model = dict( - type='FCOS', - data_preprocessor=dict( - type='DetDataPreprocessor', - mean=[102.9801, 115.9465, 122.7717], - std=[1.0, 1.0, 1.0], - bgr_to_rgb=False, - pad_size_divisor=32), - backbone=dict( - 
type='ResNet', - depth=50, - num_stages=4, - out_indices=(0, 1, 2, 3), - frozen_stages=1, - norm_cfg=dict(type='BN', requires_grad=False), - norm_eval=True, - style='caffe', - init_cfg=dict( - type='Pretrained', - checkpoint='open-mmlab://detectron/resnet50_caffe')), - neck=dict( - type='FPN', - in_channels=[256, 512, 1024, 2048], - out_channels=256, - start_level=1, - add_extra_convs='on_output', # use P5 - num_outs=5, - relu_before_extra_convs=True), - bbox_head=dict( - type='FCOSHead', - num_classes=4, - in_channels=256, - stacked_convs=4, - feat_channels=256, - strides=[8, 16, 32, 64, 128], - loss_cls=dict( - type='FocalLoss', - use_sigmoid=True, - gamma=2.0, - alpha=0.25, - loss_weight=1.0), - loss_bbox=dict(type='IoULoss', loss_weight=1.0), - loss_centerness=dict( - type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)), - # testing settings - test_cfg=dict( - nms_pre=1000, - min_bbox_size=0, - score_thr=0.05, - nms=dict(type='nms', iou_threshold=0.5), - max_per_img=100)) - -# learning rate -param_scheduler = [ - dict( - type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, - end=1000), - dict( - type='MultiStepLR', - begin=0, - end=12, - by_epoch=True, - milestones=[8, 11], - gamma=0.1) -] - -# optimizer -optim_wrapper = dict( - optimizer=dict(lr=0.01), - paramwise_cfg=dict(bias_lr_mult=2., bias_decay_mult=0.), - clip_grad=dict(max_norm=35, norm_type=2)) # loss may NaN without clip_grad diff --git a/configs/detection/tienet/base_detector/retinanet_r50_fpn_1x_rtts-coco.py b/configs/detection/tienet/base_detector/retinanet_r50_fpn_1x_rtts-coco.py new file mode 100644 index 0000000..040e035 --- /dev/null +++ b/configs/detection/tienet/base_detector/retinanet_r50_fpn_1x_rtts-coco.py @@ -0,0 +1,105 @@ +_base_ = [ + '../../_base_/datasets/rtts_coco.py', + '../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py' +] + +# model settings +model = dict( + type='RetinaNet', + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32), + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_input', + num_outs=5), + bbox_head=dict( + type='RetinaHead', + num_classes=5, + in_channels=256, + stacked_convs=4, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + octave_base_scale=4, + scales_per_octave=3, + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + # model training and testing settings + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.4, + min_pos_iou=0, + ignore_iof_thr=-1), + sampler=dict( + type='PseudoSampler'), # Focal loss should use PseudoSampler + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100)) + +# optimizer +optim_wrapper = dict( + 
optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001), + clip_grad=dict(max_norm=35, norm_type=2)) # loss may NaN without clip_grad + +param_scheduler = [ + dict( + type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, + end=1000), + dict( + type='MultiStepLR', + begin=0, + end=12, + by_epoch=True, + milestones=[8, 11], + gamma=0.1) +] + +# add WandbVisBackend +# vis_backends = [ +# dict(type='LocalVisBackend'), +# dict(type='WandbVisBackend', +# init_kwargs=dict( +# project='rtts_detection', +# name='retinanet_r50_fpn_1x_rtts', +# entity='lqit', +# ) +# ) +# ] +# visualizer = dict( +# type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer') diff --git a/configs/detection/tienet/base_detector/tood_r50_fpn_1x_rtts-coco.py b/configs/detection/tienet/base_detector/tood_r50_fpn_1x_rtts-coco.py new file mode 100644 index 0000000..1de9e4a --- /dev/null +++ b/configs/detection/tienet/base_detector/tood_r50_fpn_1x_rtts-coco.py @@ -0,0 +1,95 @@ +_base_ = [ + '../../_base_/datasets/rtts_coco.py', + '../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py' +] + +# model settings +model = dict( + type='TOOD', + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32), + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_output', + num_outs=5), + bbox_head=dict( + type='TOODHead', + num_classes=5, + in_channels=256, + stacked_convs=6, + feat_channels=256, + anchor_type='anchor_free', + anchor_generator=dict( + type='AnchorGenerator', + ratios=[1.0], + octave_base_scale=8, + scales_per_octave=1, + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[0.1, 0.1, 0.2, 0.2]), + initial_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + activated=True, # use probability instead of logit as input + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_cls=dict( + type='QualityFocalLoss', + use_sigmoid=True, + activated=True, # use probability instead of logit as input + beta=2.0, + loss_weight=1.0), + loss_bbox=dict(type='GIoULoss', loss_weight=2.0)), + train_cfg=dict( + initial_epoch=4, + initial_assigner=dict(type='ATSSAssigner', topk=9), + assigner=dict(type='TaskAlignedAssigner', topk=13), + alpha=1, + beta=6, + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.6), + max_per_img=100)) + +# optimizer +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)) + +# add WandbVisBackend +# vis_backends = [ +# dict(type='LocalVisBackend'), +# dict(type='WandbVisBackend', +# init_kwargs=dict( +# project='rtts_detection', +# name='tood_r50_fpn_1x_rtts', +# entity='lqit', +# ) +# ) +# ] +# visualizer = dict( +# type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer') diff --git a/configs/detection/tienet/tienet_atss_r50_fpn_1x_rtts-coco.py b/configs/detection/tienet/tienet_atss_r50_fpn_1x_rtts-coco.py new file mode 100644 index 0000000..5ab9ac7 --- /dev/null +++ 
b/configs/detection/tienet/tienet_atss_r50_fpn_1x_rtts-coco.py @@ -0,0 +1,35 @@ +# default scope is mmdet +_base_ = [ + './base_editor/tienet_enhance_model.py', + './base_detector/atss_r50_fpn_1x_rtts-coco.py' +] + +model = dict( + _delete_=True, + type='lqit.DetectorWithEnhanceModel', + detector={{_base_.model}}, + enhance_model={{_base_.enhance_model}}, + train_mode='enhance', + pred_mode='enhance', + detach_enhance_img=False) + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001), + clip_grad=dict(max_norm=35, norm_type=2)) + +# dataset settings +train_pipeline = [ + dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='Resize', scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', prob=0.5), + dict(type='lqit.SetInputImageAsGT'), + dict(type='lqit.PackInputs') +] +train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) + +model_wrapper_cfg = dict( + type='lqit.SelfEnhanceModelDDP', + broadcast_buffers=False, + find_unused_parameters=False) diff --git a/configs/detection/tienet/tienet_atss_r50_fpn_1x_urpc-coco.py b/configs/detection/tienet/tienet_atss_r50_fpn_1x_urpc-coco.py new file mode 100644 index 0000000..3d61804 --- /dev/null +++ b/configs/detection/tienet/tienet_atss_r50_fpn_1x_urpc-coco.py @@ -0,0 +1,35 @@ +# default scope is mmdet +_base_ = [ + './base_editor/tienet_enhance_model.py', + './base_detector/atss_r50_fpn_1x_urpc-coco.py' +] + +model = dict( + _delete_=True, + type='lqit.DetectorWithEnhanceModel', + detector={{_base_.model}}, + enhance_model={{_base_.enhance_model}}, + train_mode='enhance', + pred_mode='enhance', + detach_enhance_img=False) + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001), + clip_grad=dict(max_norm=35, norm_type=2)) + +# dataset settings +train_pipeline = [ + dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='Resize', scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', prob=0.5), + dict(type='lqit.SetInputImageAsGT'), + dict(type='lqit.PackInputs') +] +train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) + +model_wrapper_cfg = dict( + type='lqit.SelfEnhanceModelDDP', + broadcast_buffers=False, + find_unused_parameters=False) diff --git a/configs/detection/tienet/tienet_faster-rcnn_r50_fpn_1x_4xbs4_urpc-coco.py b/configs/detection/tienet/tienet_faster-rcnn_r50_fpn_1x_4xbs4_urpc-coco.py deleted file mode 100644 index 17910b9..0000000 --- a/configs/detection/tienet/tienet_faster-rcnn_r50_fpn_1x_4xbs4_urpc-coco.py +++ /dev/null @@ -1,3 +0,0 @@ -_base_ = ['./tienet_faster-rcnn_r50_fpn_1x_urpc-coco.py'] - -train_dataloader = dict(batch_size=4, num_workers=4) diff --git a/configs/detection/tienet/tienet_faster-rcnn_r50_fpn_1x_rtts-coco.py b/configs/detection/tienet/tienet_faster-rcnn_r50_fpn_1x_rtts-coco.py new file mode 100644 index 0000000..2c0384c --- /dev/null +++ b/configs/detection/tienet/tienet_faster-rcnn_r50_fpn_1x_rtts-coco.py @@ -0,0 +1,35 @@ +# default scope is mmdet +_base_ = [ + './base_editor/tienet_enhance_model.py', + './base_detector/faster-rcnn_r50_fpn_1x_rtts-coco.py' +] + +model = dict( + _delete_=True, + type='lqit.DetectorWithEnhanceModel', + detector={{_base_.model}}, + enhance_model={{_base_.enhance_model}}, + train_mode='enhance', + pred_mode='enhance', + detach_enhance_img=False) + +optim_wrapper = dict( 
+ type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001), + clip_grad=dict(max_norm=35, norm_type=2)) + +# dataset settings +train_pipeline = [ + dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='Resize', scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', prob=0.5), + dict(type='lqit.SetInputImageAsGT'), + dict(type='lqit.PackInputs') +] +train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) + +model_wrapper_cfg = dict( + type='lqit.SelfEnhanceModelDDP', + broadcast_buffers=False, + find_unused_parameters=False) diff --git a/configs/detection/tienet/tienet_retinanet_r50_fpn_1x_4xbs4_urpc-coco.py b/configs/detection/tienet/tienet_retinanet_r50_fpn_1x_4xbs4_urpc-coco.py deleted file mode 100644 index 3c5ab3b..0000000 --- a/configs/detection/tienet/tienet_retinanet_r50_fpn_1x_4xbs4_urpc-coco.py +++ /dev/null @@ -1,3 +0,0 @@ -_base_ = ['./tienet_retinanet_r50_fpn_1x_urpc-coco.py'] - -train_dataloader = dict(batch_size=4, num_workers=4) diff --git a/configs/detection/tienet/tienet_retinanet_r50_fpn_1x_rtts-coco.py b/configs/detection/tienet/tienet_retinanet_r50_fpn_1x_rtts-coco.py new file mode 100644 index 0000000..49b40d8 --- /dev/null +++ b/configs/detection/tienet/tienet_retinanet_r50_fpn_1x_rtts-coco.py @@ -0,0 +1,35 @@ +# default scope is mmdet +_base_ = [ + './base_editor/tienet_enhance_model.py', + './base_detector/retinanet_r50_fpn_1x_rtts-coco.py' +] + +model = dict( + _delete_=True, + type='lqit.DetectorWithEnhanceModel', + detector={{_base_.model}}, + enhance_model={{_base_.enhance_model}}, + train_mode='enhance', + pred_mode='enhance', + detach_enhance_img=False) + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001), + clip_grad=dict(max_norm=35, norm_type=2)) + +# dataset settings +train_pipeline = [ + dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='Resize', scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', prob=0.5), + dict(type='lqit.SetInputImageAsGT'), + dict(type='lqit.PackInputs') +] +train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) + +model_wrapper_cfg = dict( + type='lqit.SelfEnhanceModelDDP', + broadcast_buffers=False, + find_unused_parameters=False) diff --git a/configs/detection/tienet/tienet_tood_r50_fpn_1x_rtts-coco.py b/configs/detection/tienet/tienet_tood_r50_fpn_1x_rtts-coco.py new file mode 100644 index 0000000..88c2aa3 --- /dev/null +++ b/configs/detection/tienet/tienet_tood_r50_fpn_1x_rtts-coco.py @@ -0,0 +1,35 @@ +# default scope is mmdet +_base_ = [ + './base_editor/tienet_enhance_model.py', + './base_detector/tood_r50_fpn_1x_rtts-coco.py' +] + +model = dict( + _delete_=True, + type='lqit.DetectorWithEnhanceModel', + detector={{_base_.model}}, + enhance_model={{_base_.enhance_model}}, + train_mode='enhance', + pred_mode='enhance', + detach_enhance_img=False) + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001), + clip_grad=dict(max_norm=35, norm_type=2)) + +# dataset settings +train_pipeline = [ + dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='Resize', scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', prob=0.5), + dict(type='lqit.SetInputImageAsGT'), + dict(type='lqit.PackInputs') +] 
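The `lqit.SetInputImageAsGT` step in the pipeline above is what keeps training self-contained: the raw input image is copied into the ground-truth slot, so the enhancement branch can be optimized without a paired clean reference. A minimal sketch of the assumed behaviour follows; the standalone `set_input_image_as_gt` helper is hypothetical and is not the registered transform itself.

```python
# Minimal sketch of the assumed behaviour of a transform like
# `lqit.SetInputImageAsGT`; this standalone helper is hypothetical.
import numpy as np


def set_input_image_as_gt(results: dict) -> dict:
    """Copy the loaded image into the ground-truth slot of the sample."""
    results['gt_img'] = results['img'].copy()
    return results


sample = {'img': np.zeros((800, 1333, 3), dtype=np.uint8)}
sample = set_input_image_as_gt(sample)
assert sample['gt_img'].shape == sample['img'].shape
```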
+train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) + +model_wrapper_cfg = dict( + type='lqit.SelfEnhanceModelDDP', + broadcast_buffers=False, + find_unused_parameters=False) diff --git a/configs/detection/tienet/tienet_tood_r50_fpn_1x_urpc-coco.py b/configs/detection/tienet/tienet_tood_r50_fpn_1x_urpc-coco.py new file mode 100644 index 0000000..ea8eebf --- /dev/null +++ b/configs/detection/tienet/tienet_tood_r50_fpn_1x_urpc-coco.py @@ -0,0 +1,35 @@ +# default scope is mmdet +_base_ = [ + './base_editor/tienet_enhance_model.py', + './base_detector/tood_r50_fpn_1x_urpc-coco.py' +] + +model = dict( + _delete_=True, + type='lqit.DetectorWithEnhanceModel', + detector={{_base_.model}}, + enhance_model={{_base_.enhance_model}}, + train_mode='enhance', + pred_mode='enhance', + detach_enhance_img=False) + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001), + clip_grad=dict(max_norm=35, norm_type=2)) + +# dataset settings +train_pipeline = [ + dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='Resize', scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', prob=0.5), + dict(type='lqit.SetInputImageAsGT'), + dict(type='lqit.PackInputs') +] +train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) + +model_wrapper_cfg = dict( + type='lqit.SelfEnhanceModelDDP', + broadcast_buffers=False, + find_unused_parameters=False) From 20d613f5e21afc12c62400d2fff848f5133bd15a Mon Sep 17 00:00:00 2001 From: BIGWangYuDong Date: Sat, 28 Oct 2023 22:33:14 +0800 Subject: [PATCH 5/5] update readme and add download links --- ...ster-rcnn_r50_fpn_basic-enhance_1x_coco.py | 56 ----- configs/detection/duo_dataset/README.md | 60 +----- configs/detection/edffnet/README.md | 13 +- .../atss_r50_dffpn_1x_rtts-coco_lr002.py | 12 ++ ....py => atss_r50_fpn_1x_rtts-coco_lr002.py} | 14 ++ ...fnet_atss_r50_dffpn_1x_rtts-coco_lr002.py} | 34 +-- configs/detection/rtts_dataset/README.md | 40 ++++ configs/detection/ruod_dataset/README.md | 2 +- configs/detection/tienet/README.md | 34 ++- configs/detection/uod_air/README.md | 24 ++- configs/detection/urpc2020_dataset/README.md | 77 +++---- lqit/detection/models/detectors/__init__.py | 5 +- .../detectors/detector_with_enhance_head.py | 4 +- .../detectors/detector_with_enhance_model.py | 4 +- lqit/detection/models/detectors/edffnet.py | 37 ++-- .../detectors/single_stage_enhance_head.py | 134 ------------ .../detectors/two_stage_enhance_head.py | 193 ------------------ lqit/edit/models/editor_heads/edge_head.py | 2 +- 18 files changed, 189 insertions(+), 556 deletions(-) delete mode 100644 configs/detection/detector_with_enhance_head/faster-rcnn_r50_fpn_basic-enhance_1x_coco.py create mode 100644 configs/detection/edffnet/atss_r50_dffpn_1x_rtts-coco_lr002.py rename configs/detection/edffnet/{atss_r50_fpn_1x_2xb8_rtts.py => atss_r50_fpn_1x_rtts-coco_lr002.py} (88%) rename configs/detection/edffnet/{edffnet.py => edffnet_atss_r50_dffpn_1x_rtts-coco_lr002.py} (55%) delete mode 100644 lqit/detection/models/detectors/single_stage_enhance_head.py delete mode 100644 lqit/detection/models/detectors/two_stage_enhance_head.py diff --git a/configs/detection/detector_with_enhance_head/faster-rcnn_r50_fpn_basic-enhance_1x_coco.py b/configs/detection/detector_with_enhance_head/faster-rcnn_r50_fpn_basic-enhance_1x_coco.py deleted file mode 100644 index fe63228..0000000 --- 
a/configs/detection/detector_with_enhance_head/faster-rcnn_r50_fpn_basic-enhance_1x_coco.py +++ /dev/null @@ -1,56 +0,0 @@ -_base_ = 'mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py' - -model = dict( - type='lqit.TwoStageWithEnhanceHead', - backbone=dict(norm_eval=False), - enhance_head=dict( - _scope_='lqit', - type='BasicEnhanceHead', - in_channels=256, - feat_channels=256, - num_convs=5, - loss_enhance=dict(type='L1Loss', loss_weight=0.1), - gt_preprocessor=dict( - type='GTPixelPreprocessor', - mean=[123.675, 116.28, 103.53], - std=[58.395, 57.12, 57.375], - bgr_to_rgb=True, - pad_size_divisor=32, - element_name='img')), -) -# dataset settings -dataset_type = 'CocoDataset' -data_root = 'data/coco/' - -train_pipeline = [ - dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}), - dict( - type='lqit.LoadGTImageFromFile', backend_args={{_base_.backend_args}}), - dict(type='LoadAnnotations', with_bbox=True), - dict( - type='lqit.TransBroadcaster', - src_key='img', - dst_key='gt_img', - transforms=[ - dict(type='Resize', scale=(1333, 800), keep_ratio=True), - dict(type='RandomFlip', prob=0.5), - ]), - dict(type='lqit.PackInputs') -] -train_dataloader = dict( - batch_size=2, - num_workers=2, - dataset=dict( - _delete_=True, - type='lqit.DatasetWithGTImageWrapper', - dataset=dict( - type=dataset_type, - data_root=data_root, - ann_file='annotations/instances_train2017.json', - data_prefix=dict(img='train2017/', gt_img_path='train2017/'), - filter_cfg=dict(filter_empty_gt=True, min_size=32), - pipeline=train_pipeline))) - -optim_wrapper = dict( - type='OptimWrapper', - optimizer=dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)) diff --git a/configs/detection/duo_dataset/README.md b/configs/detection/duo_dataset/README.md index 589f424..82154b0 100644 --- a/configs/detection/duo_dataset/README.md +++ b/configs/detection/duo_dataset/README.md @@ -12,57 +12,19 @@ Underwater object detection for robot picking has attracted a lot of interest. H -**Note:** DUO contains URPC2020, the categories of both datasets are same. DUO introduced URPC2020 and other underwater object detection datasets in the paper. - -**TODO:** - -- [ ] Support DUO Dataset and release models. 
-- [ ] Unify Dataset name in `LQIT` - -## Results and Models - -### URPC2020 - -| Architecture | Backbone | Style | Lr schd | box AP | Config | Download | -| :-----------: | :---------: | :-----: | :-----: | :----: | :----------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | -| Faster R-CNN | R-50 | pytorch | 1x | 43.5 | [config](./faster-rcnn_r50_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_r50_fpn_1x_urpc-coco_20220226_105840-09ef8403.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_r50_fpn_1x_urpc-coco_20220226_105840.log.json) | -| Faster R-CNN | R-101 | pytorch | 1x | 44.8 | [config](./faster-rcnn_r101_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_r101_fpn_1x_urpc-coco_20220227_182523-de4a666c.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_r101_fpn_1x_urpc-coco_20220227_182523.log.json) | -| Faster R-CNN | X-101-32x4d | pytorch | 1x | 44.6 | [config](./faster-rcnn_x101-32x4d_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_x101-32x4d_fpn_1x_urpc-coco_20230511_190905-7074a9f7.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_x101-32x4d_fpn_1x_urpc-coco_20230511_190905.log.json) | -| Faster R-CNN | X-101-64x4d | pytorch | 1x | 45.3 | [config](./faster-rcnn_x101-64x4d_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_x101-64x4d_fpn_1x_urpc-coco_20220405_193758-5d2a37e4.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_x101-64x4d_fpn_1x_urpc-coco_20220405_193758.log.json) | -| Cascade R-CNN | R-50 | pytorch | 1x | 44.3 | [config](./cascade-rcnn_r50_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/cascade-rcnn_r50_fpn_1x_urpc-coco_20220405_160342-044e6858.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/cascade-rcnn_r50_fpn_1x_urpc-coco_20220405_160342.log.json) | -| RetinaNet | R-50 | pytorch | 1x | 40.7 | [config](./retinanet_r50_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/retinanet_r50_fpn_1x_urpc-coco_20220405_214951-a39f054e.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/retinanet_r50_fpn_1x_urpc-coco_20220405_214951.log.json) | -| FCOS | R-50 | caffe | 1x | 41.4 | [config](./fcos_r50-caffe_fpn_gn-head_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/fcos_r50-caffe_fpn_gn-head_1x_urpc-coco_20220227_204555-305ab6aa.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/fcos_r50-caffe_fpn_gn-head_1x_urpc-coco_20220227_204555.log.json) | -| ATSS | R-50 | pytorch | 1x | 44.8 | [config](./atss_r50_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/atss_r50_fpn_1x_urpc-coco_20220405_160345-cf776917.pth) \| 
[log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/atss_r50_fpn_1x_urpc-coco_20220405_160345.log.json) | -| TOOD | R-50 | pytorch | 1x | 45.4 | [config](./tood_r50_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/tood_r50_fpn_1x_urpc-coco_20220405_164450-1fbf815b.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/tood_r50_fpn_1x_urpc-coco_20220405_164450.log.json) | -| SSD300 | VGG16 | - | 120e | 35.1 | [config](./ssd300_120e_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/ssd300_120e_urpc-coco_20230426_122625-b6f0b01e.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/ssd512_120e_urpc-coco_20220405_185511.log.json) | -| SSD512 | VGG16 | - | 120e | 38.6 | [config](./ssd300_120e_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/ssd512_120e_urpc-coco_20220405_185511-88c18764.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/ssd512_120e_urpc-coco_20220405_185511.log.json) | - -### DUO +## Results Coming soon ## Citation -- If you use `URPC2020` or other `URPC` series dataset in your research, please cite it as below: - - **Note:** The URL may not be valid, but this link is cited by many papers. - - ```latex - @online{urpc, - title = {Underwater Robot Professional Contest}, - url = {http://uodac.pcl.ac.cn/}, - } - ``` - -- If you use `DUO` dataset in your research, please cite it as below: - - ```latex - @inproceedings{liu2021dataset, - title={A dataset and benchmark of underwater object detection for robot picking}, - author={Liu, Chongwei and Li, Haojie and Wang, Shuchang and Zhu, Ming and Wang, Dong and Fan, Xin and Wang, Zhihui}, - booktitle={2021 IEEE International Conference on Multimedia \& Expo Workshops (ICMEW)}, - pages={1--6}, - year={2021}, - organization={IEEE} - } - ``` +```latex +@inproceedings{liu2021dataset, + title={A dataset and benchmark of underwater object detection for robot picking}, + author={Liu, Chongwei and Li, Haojie and Wang, Shuchang and Zhu, Ming and Wang, Dong and Fan, Xin and Wang, Zhihui}, + booktitle={2021 IEEE International Conference on Multimedia \& Expo Workshops (ICMEW)}, + pages={1--6}, + year={2021}, + organization={IEEE} +} +``` diff --git a/configs/detection/edffnet/README.md b/configs/detection/edffnet/README.md index dd25991..60abd22 100644 --- a/configs/detection/edffnet/README.md +++ b/configs/detection/edffnet/README.md @@ -1,18 +1,27 @@ # Edge-Guided Dynamic Feature Fusion Network for Object Detection under Foggy Conditions +> [Edge-Guided Dynamic Feature Fusion Network for Object Detection under Foggy Conditions](https://link.springer.com/article/10.1007/s11760-022-02410-0) + ## Abstract Hazy images are often subject to blurring, low contrast and other visible quality degradation, making it challenging to solve object detection tasks. Most methods solve the domain shift problem by deep domain adaptive technology, ignoring the inaccurate object classification and localization caused by quality degradation. Different from common methods, we present an edge-guided dynamic feature fusion network (EDFFNet), which formulates the edge head as a guide to the localization task. Despite the edge head being straightforward, we demonstrate that it makes the model pay attention to the edge of object instances and improves the generalization and localization ability of the network. 
Considering the fuzzy details and the multi-scale problem of hazy images, we propose a dynamic fusion feature pyramid network (DF-FPN) to enhance the feature representation ability of the whole model. A unique advantage of DF-FPN is that the contribution to the fused feature map will dynamically adjust with the learning of the network. Extensive experiments verify that EDFFNet achieves 2.4% AP and 3.6% AP gains over the ATSS baseline on RTTS and Foggy Cityscapes, respectively. + +
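DF-FPN is described above as fusing multi-level features with contributions that adjust as the network learns. The module below is a minimal sketch of that idea under simple assumptions (input-dependent softmax weights over resized levels); it is not the `lqit.DFFPN` implementation referenced by the configs in this patch.

```python
# Minimal sketch of dynamic feature fusion: per-level weights are predicted
# from the features themselves, so each level's contribution changes during
# training. This is an illustration, not the actual lqit.DFFPN neck.
import torch
import torch.nn as nn
import torch.nn.functional as F


class DynamicFuse(nn.Module):
    """Fuse multi-level features with input-dependent, learned weights."""

    def __init__(self, channels: int = 256, num_levels: int = 5):
        super().__init__()
        self.weight_pred = nn.Conv2d(channels * num_levels, num_levels, kernel_size=1)

    def forward(self, feats):
        size = feats[0].shape[-2:]
        aligned = [F.interpolate(f, size=size, mode='nearest') for f in feats]
        weights = torch.softmax(self.weight_pred(torch.cat(aligned, dim=1)), dim=1)
        return sum(weights[:, i:i + 1] * aligned[i] for i in range(len(aligned)))


feats = [torch.rand(2, 256, 64 // 2 ** i, 64 // 2 ** i) for i in range(5)]
print(DynamicFuse()(feats).shape)  # torch.Size([2, 256, 64, 64])
```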
-## Results and Analysis +## Results on RTTS -Coming soon +| Architecture | Neck | Lr schd | Edge Head | lr | box AP | Config | Download | +| :----------: | :---: | :-----: | :-------: | :--: | :----: | :------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| ATSS | FPN | 1x | - | 0.01 | 48.2 | [config](../rtts_dataset/atss_r50_fpn_1x_rtts-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/atss_r50_fpn_1x_rtts-coco_20231023_210916-98b5356b.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/atss_r50_fpn_1x_rtts-coco_20231023_210916.log.json) | +| ATSS | FPN | 1x | - | 0.02 | 49.6 | [config](./atss_r50_fpn_1x_rtts-coco_lr002.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/atss_r50_fpn_1x_rtts-coco_lr002_20231028_104029-114517ae.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/atss_r50_fpn_1x_rtts-coco_lr002_20231028_104029.log.json) | +| ATSS | DFFPN | 1x | - | 0.02 | 50.3 | [config](./atss_r50_dffpn_1x_rtts-coco_lr002.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/atss_r50_dffpn_1x_rtts-coco_lr002_20231028_104638-8f22abd9.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/atss_r50_dffpn_1x_rtts-coco_lr002_20231028_104638.log.json) | +| ATSS | DFFPN | 1x | Y | 0.02 | 50.8 | [config](./edffnet_atss_r50_dffpn_1x_rtts-coco_lr002.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/edffnet_atss_r50_dffpn_1x_rtts-coco_lr002_20231028_111154-89311078.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/edffnet_atss_r50_dffpn_1x_rtts-coco_lr002_20231028_111154.log.json) | ## Citation diff --git a/configs/detection/edffnet/atss_r50_dffpn_1x_rtts-coco_lr002.py b/configs/detection/edffnet/atss_r50_dffpn_1x_rtts-coco_lr002.py new file mode 100644 index 0000000..8cdd85a --- /dev/null +++ b/configs/detection/edffnet/atss_r50_dffpn_1x_rtts-coco_lr002.py @@ -0,0 +1,12 @@ +_base_ = ['./atss_r50_fpn_1x_rtts-coco_lr002.py'] + +# model settings +model = dict( + neck=dict( + type='lqit.DFFPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_input', + shape_level=2, + num_outs=5)) diff --git a/configs/detection/edffnet/atss_r50_fpn_1x_2xb8_rtts.py b/configs/detection/edffnet/atss_r50_fpn_1x_rtts-coco_lr002.py similarity index 88% rename from configs/detection/edffnet/atss_r50_fpn_1x_2xb8_rtts.py rename to configs/detection/edffnet/atss_r50_fpn_1x_rtts-coco_lr002.py index bf66e1e..34df95b 100644 --- a/configs/detection/edffnet/atss_r50_fpn_1x_2xb8_rtts.py +++ b/configs/detection/edffnet/atss_r50_fpn_1x_rtts-coco_lr002.py @@ -67,5 +67,19 @@ nms=dict(type='nms', iou_threshold=0.6), max_per_img=100)) +# optimizer optim_wrapper = dict( optimizer=dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)) + +param_scheduler = [ + dict( + type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, + end=1000), + dict( + type='MultiStepLR', + begin=0, + end=12, + by_epoch=True, + milestones=[8, 11], + gamma=0.1) +] diff --git a/configs/detection/edffnet/edffnet.py b/configs/detection/edffnet/edffnet_atss_r50_dffpn_1x_rtts-coco_lr002.py 
similarity index 55% rename from configs/detection/edffnet/edffnet.py rename to configs/detection/edffnet/edffnet_atss_r50_dffpn_1x_rtts-coco_lr002.py index 858f822..f02d74f 100644 --- a/configs/detection/edffnet/edffnet.py +++ b/configs/detection/edffnet/edffnet_atss_r50_dffpn_1x_rtts-coco_lr002.py @@ -1,17 +1,10 @@ -_base_ = '../edffnet/atss_r50_fpn_1x_2xb8_rtts.py' +_base_ = ['./atss_r50_dffpn_1x_rtts-coco_lr002.py'] model = dict( + _delete_=True, type='lqit.EDFFNet', - backbone=dict(norm_eval=True), - neck=dict( - type='lqit.DFFPN', - in_channels=[256, 512, 1024, 2048], - out_channels=256, - start_level=1, - add_extra_convs='on_input', - shape_level=2, - num_outs=5), - enhance_head=dict( + detector={{_base_.model}}, + edge_head=dict( _scope_='lqit', type='EdgeHead', in_channels=256, @@ -23,7 +16,8 @@ mean=[128], std=[57.12], pad_size_divisor=32, - element_name='edge'))) + element_name='edge')), + vis_enhance=False) # dataset settings train_pipeline = [ @@ -41,19 +35,3 @@ dict(type='lqit.PackInputs', ) ] train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) - -optim_wrapper = dict( - optimizer=dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)) - -param_scheduler = [ - dict( - type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, - end=1000), - dict( - type='MultiStepLR', - begin=0, - end=12, - by_epoch=True, - milestones=[8, 11], - gamma=0.1) -] diff --git a/configs/detection/rtts_dataset/README.md b/configs/detection/rtts_dataset/README.md index e69de29..ce8e58e 100644 --- a/configs/detection/rtts_dataset/README.md +++ b/configs/detection/rtts_dataset/README.md @@ -0,0 +1,40 @@ +# Benchmarking single-image dehazing and beyond + +> [Benchmarking single-image dehazing and beyond](https://ieeexplore.ieee.org/abstract/document/8451944) + + + +We present a comprehensive study and evaluation of existing single-image dehazing algorithms, using a new large-scale benchmark consisting of both synthetic and real-world hazy images, called REalistic Single-Image DEhazing (RESIDE). RESIDE highlights diverse data sources and image contents, and is divided into five subsets, each serving different training or evaluation purposes. We further provide a rich variety of criteria for dehazing algorithm evaluation, ranging from full-reference metrics to no-reference metrics and to subjective evaluation, and the novel task-driven evaluation. Experiments on RESIDE shed light on the comparisons and limitations of the state-of-the-art dehazing algorithms, and suggest promising future directions. + + + +
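RTTS provides detection annotations for five object categories, which is why the RTTS configs in this series set `num_classes=5`. The snippet below spells out that class set in a metainfo-style dict; the variable name, class-name spellings, and palette are assumptions for illustration, not values taken from this repository.

```python
# Hypothetical illustration of the RTTS class set in MMDetection metainfo
# style; names and palette are assumptions, not repository values.
rtts_metainfo = dict(
    classes=('person', 'bicycle', 'car', 'motorbike', 'bus'),
    palette=[(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230),
             (106, 0, 228)])
assert len(rtts_metainfo['classes']) == 5  # matches num_classes=5 above
```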
+ +## Results + +| Architecture | Backbone | Style | Lr schd | box AP | Config | Download | +| :-----------: | :------: | :-----: | :-----: | :----: | :----------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| Faster R-CNN | R-50 | pytorch | 1x | 48.1 | [config](./faster-rcnn_r50_fpn_1x_rtts-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2faster-rcnn_r50_fpn_1x_rtts-coco_20231023_211050-81f577b7.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2faster-rcnn_r50_fpn_1x_rtts-coco_20231023_211050.log.json) | +| Cascade R-CNN | R-50 | pytorch | 1x | 50.8 | [config](./cascade-rcnn_r50_fpn_1x_rtts-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2cascade-rcnn_r50_fpn_1x_rtts-coco_20231023_211029-ebfd7705.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2cascade-rcnn_r50_fpn_1x_rtts-coco_20231023_211029.log.json) | +| RetinaNet | R-50 | pytorch | 1x | 33.7 | [config](./retinanet_r50_fpn_1x_rtts-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2retinanet_r50_fpn_1x_rtts-coco_20231023_211252-594f407a.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2retinanet_r50_fpn_1x_rtts-coco_20231023_211252.log.json) | +| FCOS | R-50 | caffe | 1x | 41.0 | [config](./fcos_r50-caffe_fpn_gn-head_1x_rtts-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2fcos_r50-caffe_fpn_gn-head_1x_rtts-coco_20231023_211216-b7e2e105.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2fcos_r50-caffe_fpn_gn-head_1x_rtts-coco_20231023_211216.log.json) | +| ATSS | R-50 | pytorch | 1x | 48.2 | [config](./atss_r50_fpn_1x_rtts-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2atss_r50_fpn_1x_rtts-coco_20231023_210916-98b5356b.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2atss_r50_fpn_1x_rtts-coco_20231023_210916.log.json) | +| TOOD | R-50 | pytorch | 1x | 50.8 | [config](./tood_r50_fpn_1x_rtts-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2tood_r50_fpn_1x_rtts-coco_20231023_211348-6339a1f6.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2tood_r50_fpn_1x_rtts-coco_20231023_211348.log.json) | +| PAA | R-50 | pytorch | 1x | 49.3 | [config](./paa_r50_fpn_1x_rtts-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2paa_r50_fpn_1x_rtts-coco_20231024_001806-04ca4793.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2paa_r50_fpn_1x_rtts-coco_20231024_001806.log.json) | + +## Citation + +```latex +@article{li2018benchmarking, + title={Benchmarking single-image dehazing and beyond}, + author={Li, Boyi and Ren, Wenqi and Fu, Dengpan and Tao, Dacheng and Feng, Dan and Zeng, Wenjun and Wang, Zhangyang}, + journal={IEEE Transactions on Image Processing}, + volume={28}, + number={1}, + pages={492--505}, + year={2018}, + publisher={IEEE} +} +``` diff --git a/configs/detection/ruod_dataset/README.md b/configs/detection/ruod_dataset/README.md index 6783c3f..0a2e64f 100644 --- a/configs/detection/ruod_dataset/README.md +++ 
b/configs/detection/ruod_dataset/README.md
@@ -28,7 +28,7 @@ In this paper, we conduct a comprehensive study of Underwater Object Detection (
 
 ## Is Underwater Image Enhancement All Object Detectors Need?
 
-TODO
+Coming Soon
 
 ## Citation
 
diff --git a/configs/detection/tienet/README.md b/configs/detection/tienet/README.md
index 90d2d64..bc79c04 100644
--- a/configs/detection/tienet/README.md
+++ b/configs/detection/tienet/README.md
@@ -1,18 +1,46 @@
 # TIENet: Task-oriented Image Enhancement Network for degraded object detection
 
+> [TIENet: Task-oriented Image Enhancement Network for degraded object detection](https://link.springer.com/article/10.1007/s11760-023-02695-9)
+
 ## Abstract
 
 Degraded images often suffer from low contrast, color deviations, and blurring details, which significantly affect the performance of detectors. Many previous works have attempted to obtain high-quality images based on human perception using image enhancement algorithms. However, these enhancement algorithms usually suppress the performance of degraded object detection. In this paper, we propose a task-oriented image enhancement network (TIENet) to directly improve degraded object detection’s performance by enhancing the degraded images. Unlike common human perception-based image-to-image methods, TIENet is a zero-reference enhancement network, which obtains a detection-favorable structure image that is added to the original degraded image. In addition, this paper presents a fast Fourier transform-based structure loss for the enhancement task. With the new loss, our TIENet enables the obtained structure image to carry more useful detection-favorable structural information and suppress irrelevant information. Extensive experiments and comprehensive evaluations on underwater (URPC2020) and foggy (RTTS) datasets show that our proposed framework can achieve 0.5–1.6% AP absolute improvements on classic detectors, including Faster R-CNN, RetinaNet, FCOS, ATSS, PAA, and TOOD. Besides, our method also generalizes well to the PASCAL VOC dataset, on which it achieves 0.2–0.7% gains. We expect this study can draw more attention to high-level task-oriented degraded image enhancement.
+
+
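The fast Fourier transform-based structure loss mentioned above is what makes the enhancement task-oriented: only the high-frequency, structural part of the enhanced image is compared against the reference, so the network is free to change low-frequency appearance. The snippet below is a rough PyTorch sketch of that idea and is not LQIT's actual `StructureFFTLoss`; the `radius` argument and the plain MSE comparison of spectral magnitudes are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F


def high_pass_structure_loss(pred: torch.Tensor,
                             target: torch.Tensor,
                             radius: int = 4) -> torch.Tensor:
    """Illustrative FFT high-pass loss for (N, C, H, W) image tensors."""
    assert pred.shape == target.shape
    _, _, h, w = pred.shape

    # Centered 2D spectra, so low frequencies sit in the middle.
    pred_fft = torch.fft.fftshift(torch.fft.fft2(pred), dim=(-2, -1))
    target_fft = torch.fft.fftshift(torch.fft.fft2(target), dim=(-2, -1))

    # High-pass mask: zero inside a small low-frequency disk, one outside.
    yy, xx = torch.meshgrid(
        torch.arange(h, device=pred.device),
        torch.arange(w, device=pred.device),
        indexing='ij')
    dist = ((yy - h // 2)**2 + (xx - w // 2)**2).float().sqrt()
    mask = (dist > radius).to(pred.dtype)

    # Compare only the retained high-frequency (structural) content.
    return F.mse_loss(pred_fft.abs() * mask, target_fft.abs() * mask)
```

A low-pass variant would simply invert the mask; which frequencies are kept controls what kind of structure the enhancement is pushed to preserve.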
-## Results and Analysis - -Coming soon +## Results + +### URPC2020 + +| Architecture | Lr schd | box AP | Config | Download | +| :-------------------: | :-----: | :----: | :-----------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| Faster R-CNN | 1x | 43.5 | [config](./base_detector/faster-rcnn_r50_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_r50_fpn_1x_urpc-coco_20220226_105840-09ef8403.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_r50_fpn_1x_urpc-coco_20220226_105840.log.json) | +| Faster R-CNN + TIENet | 1x | 44.3 | [config](./tienet_faster-rcnn_r50_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/tienet_faster-rcnn_r50_fpn_1x_urpc-coco_20221121_003439-0eb8ea32.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/tienet_faster-rcnn_r50_fpn_1x_urpc-coco_20221121_003439.log.json) | +| RetinaNet | 1x | 40.7 | [config](./base_detector/retinanet_r50_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/retinanet_r50_fpn_1x_urpc-coco_20220405_214951-a39f054e.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/retinanet_r50_fpn_1x_urpc-coco_20220405_214951.log.json) | +| RetinaNet + TIENet | 1x | 42.2 | [config](./tienet_retinanet_r50_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/tienet_retinanet_r50_fpn_1x_urpc-coco_20221119_190211-2d1f311c.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/tienet_retinanet_r50_fpn_1x_urpc-coco_20221119_190211.log.json) | +| ATSS | 1x | 44.8 | [config](./base_detector/atss_r50_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/atss_r50_fpn_1x_urpc-coco_20220405_160345-cf776917.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/atss_r50_fpn_1x_urpc-coco_20220405_160345.log.json) | +| ATSS + TIENet | 1x | 45.9 | [config](./tienet_atss_r50_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/tienet_atss_r50_fpn_1x_urpc-coco_20230209_181359-473de7c1.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/tienet_atss_r50_fpn_1x_urpc-coco_20230209_181359.log.json) | +| TOOD | 1x | 45.4 | [config](./base_detector/tood_r50_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/tood_r50_fpn_1x_urpc-coco_20220405_164450-1fbf815b.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/tood_r50_fpn_1x_urpc-coco_20220405_164450.log.json) | +| TOOD + TIENet | 1x | 46.7 | [config](./tienet_tood_r50_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/tienet_tood_r50_fpn_1x_urpc-coco_20221119_212831-5dc036d5.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/tienet_tood_r50_fpn_1x_urpc-coco_20221119_212831.log.json) | + +### RTTS + +| Architecture | Lr schd | box AP | Config | Download | +| :-------------------: | :-----: | :----: | 
:-----------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| Faster R-CNN | 1x | 48.1 | [config](./base_detector/faster-rcnn_r50_fpn_1x_rtts-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/untagged-cd79fcb6ab215a0cf240/faster-rcnn_r50_fpn_1x_rtts-coco_20231023_211050-81f577b7.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/untagged-cd79fcb6ab215a0cf240/faster-rcnn_r50_fpn_1x_rtts-coco_20231023_211050.log.json) | +| Faster R-CNN + TIENet | 1x | 49.2 | [config](./tienet_faster-rcnn_r50_fpn_1x_rtts-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/tienet_faster-rcnn_r50_fpn_1x_rtts-coco_20221120_215748-50af5920.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/tienet_faster-rcnn_r50_fpn_1x_rtts-coco_20221120_215748.log.json) | +| RetinaNet | 1x | 33.7 | [config](./base_detector/retinanet_r50_fpn_1x_rtts-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/untagged-cd79fcb6ab215a0cf240/retinanet_r50_fpn_1x_rtts-coco_20231023_211252-594f407a.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/untagged-cd79fcb6ab215a0cf240/retinanet_r50_fpn_1x_rtts-coco_20231023_211252.log.json) | +| RetinaNet + TIENet | 1x | 34.1 | [config](./tienet_retinanet_r50_fpn_1x_rtts-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/tienet_retinanet_r50_fpn_1x_rtts-coco_20221204_213217-b43e333d.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/tienet_retinanet_r50_fpn_1x_rtts-coco_20221204_213217.log.json) | +| ATSS | 1x | 48.2 | [config](./base_detector/atss_r50_fpn_1x_rtts-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/untagged-cd79fcb6ab215a0cf240/atss_r50_fpn_1x_rtts-coco_20231023_210916-98b5356b.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/untagged-cd79fcb6ab215a0cf240/atss_r50_fpn_1x_rtts-coco_20231023_210916.log.json) | +| ATSS + TIENet | 1x | 49.5 | [config](./tienet_atss_r50_fpn_1x_rtts-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/tienet_atss_r50_fpn_1x_rtrs-coco_20221120_105748-ec573a04.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/tienet_atss_r50_fpn_1x_rtrs-coco_20221120_105748.log.json) | +| TOOD | 1x | 50.8 | [config](./base_detector/tood_r50_fpn_1x_rtts-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/untagged-cd79fcb6ab215a0cf240/tood_r50_fpn_1x_rtts-coco_20231023_211348-6339a1f6.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/untagged-cd79fcb6ab215a0cf240/tood_r50_fpn_1x_rtts-coco_20231023_211348.log.json) | +| TOOD + TIENet | 1x | 52.1 | [config](./tienet_tood_r50_fpn_1x_rtts-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/tienet_tood_r50_fpn_1x_rtts-coco_20221119_230205-e028a3bb.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/tienet_tood_r50_fpn_1x_rtts-coco_20221119_230205.log.json) | ## Citation diff --git a/configs/detection/uod_air/README.md b/configs/detection/uod_air/README.md index 4748c8a..028bd8b 100644 --- 
a/configs/detection/uod_air/README.md +++ b/configs/detection/uod_air/README.md @@ -8,21 +8,23 @@ Underwater object detection plays an important role in a variety of marine applications. However, the complexity of the underwater environment (e.g. complex background) and the quality degradation problems (e.g. color deviation) significantly affect the performance of the deep learning-based detector. Many previous works tried to improve the underwater image quality by overcoming the degradation of underwater or designing stronger network structures to enhance the detector feature extraction ability to achieve a higher performance in underwater object detection. However, the former usually inhibits the performance of underwater object detection while the latter does not consider the gap between open-air and underwater domains. This paper presents a novel framework to combine underwater object detection with image reconstruction through a shared backbone and Feature Pyramid Network (FPN). The loss between the reconstructed image and the original image in the reconstruction task is used to make the shared structure have better generalization capability and adaptability to the underwater domain, which can improve the performance of underwater object detection. Moreover, to combine different level features more effectively, UNet-based FPN (UFPN) is proposed to integrate better semantic and texture information obtained from deep and shallow layers, respectively. Extensive experiments and comprehensive evaluation on the URPC2020 dataset show that our approach can lead to 1.4% mAP and 1.0% mAP absolute improvements on RetinaNet and Faster R-CNN baseline with negligible extra overhead. + +
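In other words, the approach is multi-task training: the detection head and an image reconstruction branch share the backbone and the (UNet-based) FPN, and the reconstruction loss on the shared features adapts them to the underwater domain. The sketch below only illustrates that loss bookkeeping; `detector`, `recon_head`, and the methods called on them are hypothetical stand-ins, not the modules LQIT actually uses.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DetectionWithReconstruction(nn.Module):
    """Toy sketch of joint detection and image reconstruction training."""

    def __init__(self, detector: nn.Module, recon_head: nn.Module,
                 recon_weight: float = 1.0) -> None:
        super().__init__()
        self.detector = detector      # any detector exposing the two calls below
        self.recon_head = recon_head  # decodes shared features back to an image
        self.recon_weight = recon_weight

    def loss(self, images: torch.Tensor, targets) -> dict:
        # Backbone + FPN features are computed once and shared by both tasks.
        feats = self.detector.extract_feat(images)

        # Standard detection losses from the shared features.
        losses = dict(self.detector.bbox_head.loss(feats, targets))

        # Reconstruct the input from the same features; the pixel loss pulls
        # the shared layers toward the underwater image statistics.
        recon = self.recon_head(feats)
        losses['loss_recon'] = self.recon_weight * F.l1_loss(recon, images)
        return losses
```

If the reconstruction branch is dropped at inference time, the detector runs as usual, which matches the negligible extra overhead reported above.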
-## Results and Analysis - -| Architecture | Neck | Lr schd | lr | box AP | Config | Download | -| :------------------------------------: | :--: | :-----: | :--: | :----: | :------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | -| Faster R-CNN | FPN | 1x | 0.02 | 43.5 | [config](./base_detector/faster-rcnn_r50_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_r50_fpn_1x_urpc-coco_20220226_105840-09ef8403.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_r50_fpn_1x_urpc-coco_20220226_105840.log.json) | -| Faster R-CNN | UFPN | 1x | 0.02 | 44.0 | [config](./base_detector/faster-rcnn_r50_ufpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/untagged-13f3dfa124d975df43f5/faster-rcnn_r50_ufpn_1x_urpc-coco_20231027_211425-61d901bb.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/untagged-13f3dfa124d975df43f5/faster-rcnn_r50_ufpn_1x_urpc-coco_20231027_211425.log.json) | -| Faster R-CNN with Image Reconstruction | UFPN | 1x | 0.02 | 44.3 | [config](./uod-air_faster-rcnn_r50_ufpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/untagged-13f3dfa124d975df43f5/uod-air_faster-rcnn_r50_ufpn_1x_urpc-coco_20231027_145407-6ae6d373.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/untagged-13f3dfa124d975df43f5/uod-air_faster-rcnn_r50_ufpn_1x_urpc-coco_20231027_145407.log.json) | -| RetinaNet | FPN | 1x | 0.01 | 40.7 | [config](./base_detector/retinanet_r50_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/retinanet_r50_fpn_1x_urpc-coco_20220405_214951-a39f054e.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/retinanet_r50_fpn_1x_urpc-coco_20220405_214951.log.json) | -| RetinaNet | UFPN | 1x | 0.01 | 41.8 | [config](./base_detector/retinanet_r50_ufpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/untagged-13f3dfa124d975df43f5/retinanet_r50_ufpn_1x_urpc-coco_20231027_215756-7803a5f9.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/untagged-13f3dfa124d975df43f5/retinanet_r50_ufpn_1x_urpc-coco_20231027_215756.log.json) | -| RetinaNet with Image Reconstruction | UFPN | 1x | 0.01 | 42.3 | [config](./uod-air_retinanet_r50_ufpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/untagged-13f3dfa124d975df43f5/uod-air_retinanet_r50_ufpn_1x_urpc-coco_20231027_224724-fe3acfba.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/untagged-13f3dfa124d975df43f5/uod-air_retinanet_r50_ufpn_1x_urpc-coco_20231027_224724.log.json) | -| RetinaNet with Image Reconstruction | UFPN | 1x | 0.02 | 43.3 | [config](./uod-air_retinanet_r50_ufpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/untagged-13f3dfa124d975df43f5/uod-air_retinanet_r50_ufpn_1x_urpc-coco_lr002_20231027_215752-b727baaf.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/untagged-13f3dfa124d975df43f5/uod-air_retinanet_r50_ufpn_1x_urpc-coco_lr002_20231027_215752.log.json) | +## Results + 
+| Architecture | Neck | Lr schd | lr | box AP | Config | Download | +| :------------------------------------: | :--: | :-----: | :--: | :----: | :------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| Faster R-CNN | FPN | 1x | 0.02 | 43.5 | [config](./base_detector/faster-rcnn_r50_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_r50_fpn_1x_urpc-coco_20220226_105840-09ef8403.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_r50_fpn_1x_urpc-coco_20220226_105840.log.json) | +| Faster R-CNN | UFPN | 1x | 0.02 | 44.0 | [config](./base_detector/faster-rcnn_r50_ufpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/faster-rcnn_r50_ufpn_1x_urpc-coco_20231027_211425-61d901bb.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/faster-rcnn_r50_ufpn_1x_urpc-coco_20231027_211425.log.json) | +| Faster R-CNN with Image Reconstruction | UFPN | 1x | 0.02 | 44.3 | [config](./uod-air_faster-rcnn_r50_ufpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/uod-air_faster-rcnn_r50_ufpn_1x_urpc-coco_20231027_145407-6ae6d373.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/uod-air_faster-rcnn_r50_ufpn_1x_urpc-coco_20231027_145407.log.json) | +| RetinaNet | FPN | 1x | 0.01 | 40.7 | [config](./base_detector/retinanet_r50_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/retinanet_r50_fpn_1x_urpc-coco_20220405_214951-a39f054e.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/retinanet_r50_fpn_1x_urpc-coco_20220405_214951.log.json) | +| RetinaNet | UFPN | 1x | 0.01 | 41.8 | [config](./base_detector/retinanet_r50_ufpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/retinanet_r50_ufpn_1x_urpc-coco_20231027_215756-7803a5f9.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/retinanet_r50_ufpn_1x_urpc-coco_20231027_215756.log.json) | +| RetinaNet with Image Reconstruction | UFPN | 1x | 0.01 | 42.3 | [config](./uod-air_retinanet_r50_ufpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/uod-air_retinanet_r50_ufpn_1x_urpc-coco_20231027_224724-fe3acfba.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/uod-air_retinanet_r50_ufpn_1x_urpc-coco_20231027_224724.log.json) | +| RetinaNet with Image Reconstruction | UFPN | 1x | 0.02 | 43.3 | [config](./uod-air_retinanet_r50_ufpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/uod-air_retinanet_r50_ufpn_1x_urpc-coco_lr002_20231027_215752-b727baaf.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc2/uod-air_retinanet_r50_ufpn_1x_urpc-coco_lr002_20231027_215752.log.json) | **Note:** The original paper was developed based on MMDetection 2.0. LQIT optimized the network structure. LQIT has aligned the AP results on Faster R-CNN, but got 0.1 AP fluctuation on RetinaNet. 
diff --git a/configs/detection/urpc2020_dataset/README.md b/configs/detection/urpc2020_dataset/README.md index e074880..d9974ec 100644 --- a/configs/detection/urpc2020_dataset/README.md +++ b/configs/detection/urpc2020_dataset/README.md @@ -1,69 +1,48 @@ -# Detecting Underwater Objects +# Underwater Robot Professional Contest 2020 -> [Detecting Underwater Objects](https://arxiv.org/abs/2106.05681) > [Underwater Robot Professional Contest 2020](https://www.heywhale.com/home/competition/5e535a612537a0002ca864ac/content/0) +> +> [Datasets at OpenI](https://openi.pcl.ac.cn/OpenOrcinus_orca/URPC_opticalimage_dataset/datasets) -Underwater object detection for robot picking has attracted a lot of interest. However, it is still an unsolved problem due to several challenges. We take steps towards making it more realistic by addressing the following challenges. Firstly, the currently available datasets basically lack the test set annotations, causing researchers must compare their method with other SOTAs on a self-divided test set (from the training set). Training other methods lead to an increase in workload and different researchers divide different datasets, resulting there is no unified benchmark to compare the performance of different algorithms. Secondly, these datasets also have other shortcomings, e.g., too many similar images or incomplete labels. Towards these challenges we introduce a dataset, Detecting Underwater Objects (DUO), and a corresponding benchmark, based on the collection and re-annotation of all relevant datasets. DUO contains a collection of diverse underwater images with more rational annotations. The corresponding benchmark provides indicators of both efficiency and accuracy of SOTAs (under the MMDtection framework) for academic research and industrial applications, where JETSON AGX XAVIER is used to assess detector speed to simulate the robot-embedded environment. +The Object Detection Algorithm Competition is the first phase of the National Underwater Robotics (Zhanjiang) Competition jointly organized by the National Natural Science Foundation of China, Pengcheng Laboratory, and the People's Government of Zhanjiang. This competition focuses on the field of underwater object detection algorithms and innovatively combines artificial intelligence with underwater robots. It opens up optical and acoustic images of the real underwater environment to a wider community of artificial intelligence and algorithm researchers, establishing a new domain for object detection and recognition. The competition is divided into two categories: "Optical Image Object Detection" and "Acoustic Image Object Detection." This project has collected relevant datasets for the "Optical Image Object Detection" category, with the hope that these data will be of assistance to researchers in related fields.
- +
-**Note:** DUO contains URPC2020, the categories of both datasets are same. DUO introduced URPC2020 and other underwater object detection datasets in the paper. +## Results -**TODO:** +### Validation-set Results -- [ ] Support DUO Dataset and release models. -- [ ] Unify Dataset name in `LQIT` +| Architecture | Backbone | Style | Lr schd | box AP | Config | Download | +| :-----------: | :---------: | :-----: | :-----: | :----: | :---------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| Faster R-CNN | R-50 | pytorch | 1x | 43.5 | [config](./train_validation/faster-rcnn_r50_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_r50_fpn_1x_urpc-coco_20220226_105840-09ef8403.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_r50_fpn_1x_urpc-coco_20220226_105840.log.json) | +| Faster R-CNN | R-101 | pytorch | 1x | 44.8 | [config](./train_validation/faster-rcnn_r101_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_r101_fpn_1x_urpc-coco_20220227_182523-de4a666c.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_r101_fpn_1x_urpc-coco_20220227_182523.log.json) | +| Faster R-CNN | X-101-32x4d | pytorch | 1x | 44.6 | [config](./train_validation/faster-rcnn_x101-32x4d_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_x101-32x4d_fpn_1x_urpc-coco_20230511_190905-7074a9f7.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_x101-32x4d_fpn_1x_urpc-coco_20230511_190905.log.json) | +| Faster R-CNN | X-101-64x4d | pytorch | 1x | 45.3 | [config](./train_validation/faster-rcnn_x101-64x4d_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_x101-64x4d_fpn_1x_urpc-coco_20220405_193758-5d2a37e4.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_x101-64x4d_fpn_1x_urpc-coco_20220405_193758.log.json) | +| Cascade R-CNN | R-50 | pytorch | 1x | 44.3 | [config](./train_validation/cascade-rcnn_r50_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/cascade-rcnn_r50_fpn_1x_urpc-coco_20220405_160342-044e6858.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/cascade-rcnn_r50_fpn_1x_urpc-coco_20220405_160342.log.json) | +| RetinaNet | R-50 | pytorch | 1x | 40.7 | [config](./train_validation/retinanet_r50_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/retinanet_r50_fpn_1x_urpc-coco_20220405_214951-a39f054e.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/retinanet_r50_fpn_1x_urpc-coco_20220405_214951.log.json) | +| FCOS | R-50 | caffe | 1x | 41.4 | [config](./train_validation/fcos_r50-caffe_fpn_gn-head_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/fcos_r50-caffe_fpn_gn-head_1x_urpc-coco_20220227_204555-305ab6aa.pth) \| 
[log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/fcos_r50-caffe_fpn_gn-head_1x_urpc-coco_20220227_204555.log.json) | +| ATSS | R-50 | pytorch | 1x | 44.8 | [config](./train_validation/atss_r50_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/atss_r50_fpn_1x_urpc-coco_20220405_160345-cf776917.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/atss_r50_fpn_1x_urpc-coco_20220405_160345.log.json) | +| TOOD | R-50 | pytorch | 1x | 45.4 | [config](./train_validation/tood_r50_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/tood_r50_fpn_1x_urpc-coco_20220405_164450-1fbf815b.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/tood_r50_fpn_1x_urpc-coco_20220405_164450.log.json) | +| SSD300 | VGG16 | - | 120e | 35.1 | [config](./train_validation/ssd300_120e_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/ssd300_120e_urpc-coco_20230426_122625-b6f0b01e.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/ssd512_120e_urpc-coco_20220405_185511.log.json) | +| SSD512 | VGG16 | - | 120e | 38.6 | [config](./train_validation/ssd300_120e_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/ssd512_120e_urpc-coco_20220405_185511-88c18764.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/ssd512_120e_urpc-coco_20220405_185511.log.json) | -## Results and Models - -### URPC2020 - -| Architecture | Backbone | Style | Lr schd | box AP | Config | Download | -| :-----------: | :---------: | :-----: | :-----: | :----: | :----------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | -| Faster R-CNN | R-50 | pytorch | 1x | 43.5 | [config](./faster-rcnn_r50_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_r50_fpn_1x_urpc-coco_20220226_105840-09ef8403.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_r50_fpn_1x_urpc-coco_20220226_105840.log.json) | -| Faster R-CNN | R-101 | pytorch | 1x | 44.8 | [config](./faster-rcnn_r101_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_r101_fpn_1x_urpc-coco_20220227_182523-de4a666c.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_r101_fpn_1x_urpc-coco_20220227_182523.log.json) | -| Faster R-CNN | X-101-32x4d | pytorch | 1x | 44.6 | [config](./faster-rcnn_x101-32x4d_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_x101-32x4d_fpn_1x_urpc-coco_20230511_190905-7074a9f7.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_x101-32x4d_fpn_1x_urpc-coco_20230511_190905.log.json) | -| Faster R-CNN | X-101-64x4d | pytorch | 1x | 45.3 | [config](./faster-rcnn_x101-64x4d_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_x101-64x4d_fpn_1x_urpc-coco_20220405_193758-5d2a37e4.pth) \| 
[log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/faster-rcnn_x101-64x4d_fpn_1x_urpc-coco_20220405_193758.log.json) | -| Cascade R-CNN | R-50 | pytorch | 1x | 44.3 | [config](./cascade-rcnn_r50_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/cascade-rcnn_r50_fpn_1x_urpc-coco_20220405_160342-044e6858.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/cascade-rcnn_r50_fpn_1x_urpc-coco_20220405_160342.log.json) | -| RetinaNet | R-50 | pytorch | 1x | 40.7 | [config](./retinanet_r50_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/retinanet_r50_fpn_1x_urpc-coco_20220405_214951-a39f054e.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/retinanet_r50_fpn_1x_urpc-coco_20220405_214951.log.json) | -| FCOS | R-50 | caffe | 1x | 41.4 | [config](./fcos_r50-caffe_fpn_gn-head_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/fcos_r50-caffe_fpn_gn-head_1x_urpc-coco_20220227_204555-305ab6aa.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/fcos_r50-caffe_fpn_gn-head_1x_urpc-coco_20220227_204555.log.json) | -| ATSS | R-50 | pytorch | 1x | 44.8 | [config](./atss_r50_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/atss_r50_fpn_1x_urpc-coco_20220405_160345-cf776917.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/atss_r50_fpn_1x_urpc-coco_20220405_160345.log.json) | -| TOOD | R-50 | pytorch | 1x | 45.4 | [config](./tood_r50_fpn_1x_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/tood_r50_fpn_1x_urpc-coco_20220405_164450-1fbf815b.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/tood_r50_fpn_1x_urpc-coco_20220405_164450.log.json) | -| SSD300 | VGG16 | - | 120e | 35.1 | [config](./ssd300_120e_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/ssd300_120e_urpc-coco_20230426_122625-b6f0b01e.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/ssd512_120e_urpc-coco_20220405_185511.log.json) | -| SSD512 | VGG16 | - | 120e | 38.6 | [config](./ssd300_120e_urpc-coco.py) | [model](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/ssd512_120e_urpc-coco_20220405_185511-88c18764.pth) \| [log](https://github.com/BIGWangYuDong/lqit/releases/download/v0.0.1rc1/ssd512_120e_urpc-coco_20220405_185511.log.json) | - -### DUO +## Test-set Results Coming soon ## Citation -- If you use `URPC2020` or other `URPC` series dataset in your research, please cite it as below: - - **Note:** The URL may not be valid, but this link is cited by many papers. - - ```latex - @online{urpc, - title = {Underwater Robot Professional Contest}, - url = {http://uodac.pcl.ac.cn/}, - } - ``` - -- If you use `DUO` dataset in your research, please cite it as below: +**Note:** The URL may not be valid, but this link is cited by many papers. 
- ```latex - @inproceedings{liu2021dataset, - title={A dataset and benchmark of underwater object detection for robot picking}, - author={Liu, Chongwei and Li, Haojie and Wang, Shuchang and Zhu, Ming and Wang, Dong and Fan, Xin and Wang, Zhihui}, - booktitle={2021 IEEE International Conference on Multimedia \& Expo Workshops (ICMEW)}, - pages={1--6}, - year={2021}, - organization={IEEE} - } - ``` +```latex +@online{urpc, +title = {Underwater Robot Professional Contest}, +url = {http://uodac.pcl.ac.cn/}, +} +``` diff --git a/lqit/detection/models/detectors/__init__.py b/lqit/detection/models/detectors/__init__.py index b9fe038..8633783 100644 --- a/lqit/detection/models/detectors/__init__.py +++ b/lqit/detection/models/detectors/__init__.py @@ -2,11 +2,8 @@ from .detector_with_enhance_model import DetectorWithEnhanceModel from .edffnet import EDFFNet from .multi_input_wrapper import MultiInputDetectorWrapper -from .single_stage_enhance_head import SingleStageDetector -from .two_stage_enhance_head import TwoStageWithEnhanceHead __all__ = [ - 'TwoStageWithEnhanceHead', 'MultiInputDetectorWrapper', - 'SingleStageDetector', 'EDFFNet', 'DetectorWithEnhanceModel', + 'MultiInputDetectorWrapper', 'EDFFNet', 'DetectorWithEnhanceModel', 'DetectorWithEnhanceHead' ] diff --git a/lqit/detection/models/detectors/detector_with_enhance_head.py b/lqit/detection/models/detectors/detector_with_enhance_head.py index 5a99884..18bc7ba 100644 --- a/lqit/detection/models/detectors/detector_with_enhance_head.py +++ b/lqit/detection/models/detectors/detector_with_enhance_head.py @@ -1,6 +1,6 @@ import copy import warnings -from typing import Dict, Optional, Tuple, Union +from typing import Dict, Tuple, Union import torch from mmdet.models import SingleStageDetector, TwoStageDetector @@ -34,7 +34,7 @@ def __init__(self, detector: ConfigType, enhance_head: OptConfigType = None, process_gt_preprocessor: bool = True, - vis_enhance: Optional[bool] = False, + vis_enhance: bool = False, init_cfg: OptMultiConfig = None) -> None: super().__init__(init_cfg=init_cfg) diff --git a/lqit/detection/models/detectors/detector_with_enhance_model.py b/lqit/detection/models/detectors/detector_with_enhance_model.py index d10ecb4..d9647eb 100644 --- a/lqit/detection/models/detectors/detector_with_enhance_model.py +++ b/lqit/detection/models/detectors/detector_with_enhance_model.py @@ -1,5 +1,5 @@ import copy -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, Tuple, Union from mmengine.model import BaseModel from mmengine.model.wrappers import MMDistributedDataParallel as MMDDP @@ -48,7 +48,7 @@ def __init__(self, detector: ConfigType, enhance_model: OptConfigType = None, loss_weight: list = [0.5, 0.5], - vis_enhance: Optional[bool] = False, + vis_enhance: bool = False, train_mode: str = 'enhance', pred_mode: str = 'enhance', detach_enhance_img: bool = False, diff --git a/lqit/detection/models/detectors/edffnet.py b/lqit/detection/models/detectors/edffnet.py index cb5c584..c13fbf2 100644 --- a/lqit/detection/models/detectors/edffnet.py +++ b/lqit/detection/models/detectors/edffnet.py @@ -1,31 +1,26 @@ -from typing import Optional - -from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig - from lqit.registry import MODELS -from .single_stage_enhance_head import SingleStageWithEnhanceHead +from lqit.utils import ConfigType, OptMultiConfig +from .detector_with_enhance_head import DetectorWithEnhanceHead @MODELS.register_module() -class EDFFNet(SingleStageWithEnhanceHead): +class 
EDFFNet(DetectorWithEnhanceHead): + """Implementation of EDFFNet. + + `_ + """ def __init__(self, - backbone: ConfigType, - neck: OptConfigType = None, - bbox_head: OptConfigType = None, - enhance_head: OptConfigType = None, - vis_enhance: Optional[bool] = False, - train_cfg: OptConfigType = None, - test_cfg: OptConfigType = None, - data_preprocessor: OptConfigType = None, + detector: ConfigType, + edge_head: ConfigType, + process_gt_preprocessor: bool = False, + vis_enhance: bool = False, init_cfg: OptMultiConfig = None) -> None: + assert not process_gt_preprocessor, \ + 'process_gt_preprocessor is not supported in EDFFNet' super().__init__( - backbone=backbone, - neck=neck, - bbox_head=bbox_head, - enhance_head=enhance_head, + detector=detector, + enhance_head=edge_head, + process_gt_preprocessor=process_gt_preprocessor, vis_enhance=vis_enhance, - train_cfg=train_cfg, - test_cfg=test_cfg, - data_preprocessor=data_preprocessor, init_cfg=init_cfg) diff --git a/lqit/detection/models/detectors/single_stage_enhance_head.py b/lqit/detection/models/detectors/single_stage_enhance_head.py deleted file mode 100644 index a86e1e3..0000000 --- a/lqit/detection/models/detectors/single_stage_enhance_head.py +++ /dev/null @@ -1,134 +0,0 @@ -from mmdet.models import SingleStageDetector -from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig -from torch import Tensor - -from lqit.common.structures import SampleList -from lqit.edit.models import add_pixel_pred_to_datasample -from lqit.registry import MODELS - - -@MODELS.register_module() -class SingleStageWithEnhanceHead(SingleStageDetector): - """Base class for single-stage detectors with enhance head. - - Two-stage detectors typically consisting of a region proposal network and a - task-specific regression head. - """ - - def __init__(self, - backbone: ConfigType, - neck: OptConfigType = None, - bbox_head: OptConfigType = None, - enhance_head: OptConfigType = None, - vis_enhance: bool = False, - train_cfg: OptConfigType = None, - test_cfg: OptConfigType = None, - data_preprocessor: OptConfigType = None, - init_cfg: OptMultiConfig = None) -> None: - super().__init__( - backbone=backbone, - neck=neck, - bbox_head=bbox_head, - train_cfg=train_cfg, - test_cfg=test_cfg, - data_preprocessor=data_preprocessor, - init_cfg=init_cfg) - - if enhance_head is not None: - self.enhance_head = MODELS.build(enhance_head) - self.vis_enhance = vis_enhance - - @property - def with_enhance_head(self) -> bool: - """bool: whether the detector has a RoI head""" - return hasattr(self, 'enhance_head') and self.enhance_head is not None - - def _forward(self, batch_inputs: Tensor, - batch_data_samples: SampleList) -> tuple: - """Network forward process. Usually includes backbone, neck and head - forward without any post-processing. - - Args: - batch_inputs (Tensor): Inputs with shape (N, C, H, W). - - Returns: - tuple: A tuple of features from ``rpn_head`` and ``roi_head`` - forward. - """ - x = self.extract_feat(batch_inputs) - results = self.bbox_head.forward(x) - if self.with_enhance_head: - enhance_outs = self.enhance_head.forward(x) - results = results + (enhance_outs, ) - return results - - def loss(self, batch_inputs: Tensor, - batch_data_samples: SampleList) -> dict: - """Calculate losses from a batch of inputs and data samples. - - Args: - batch_inputs (Tensor): Input images of shape (N, C, H, W). - These should usually be mean centered and std scaled. - batch_data_samples (List[:obj:`DetDataSample`]): The batch - data samples. 
It usually includes information such - as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. - - Returns: - dict: A dictionary of loss components - """ - x = self.extract_feat(batch_inputs) - - losses = dict() - if self.with_enhance_head: - - enhance_loss = self.enhance_head.loss(x, batch_data_samples) - # avoid loss override - assert not set(enhance_loss.keys()) & set(losses.keys()) - losses.update(enhance_loss) - - det_losses = self.bbox_head.loss(x, batch_data_samples) - losses.update(det_losses) - return losses - - def predict(self, - batch_inputs: Tensor, - batch_data_samples: SampleList, - rescale: bool = True) -> SampleList: - """Predict results from a batch of inputs and data samples with post- - processing. - - Args: - batch_inputs (Tensor): Inputs with shape (N, C, H, W). - batch_data_samples (List[:obj:`DetDataSample`]): The Data - Samples. It usually includes information such as - `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. - rescale (bool): Whether to rescale the results. - Defaults to True. - - Returns: - list[:obj:`DataSample`]: Return the detection results of the - input images. The returns value is DetDataSample, - which usually contain 'pred_instances'. And the - ``pred_instances`` usually contains following keys. - - - scores (Tensor): Classification scores, has a shape - (num_instance, ) - - labels (Tensor): Labels of bboxes, has a shape - (num_instances, ). - - bboxes (Tensor): Has a shape (num_instances, 4), - the last dimension 4 arrange as (x1, y1, x2, y2). - - masks (Tensor): Has a shape (num_instances, H, W). - """ - x = self.extract_feat(batch_inputs) - results_list = self.bbox_head.predict( - x, batch_data_samples, rescale=rescale) - - if self.vis_enhance and self.with_enhance_head: - enhance_list = self.enhance_head.predict( - x, batch_data_samples, rescale=rescale) - batch_data_samples = add_pixel_pred_to_datasample( - data_samples=batch_data_samples, pixel_list=enhance_list) - - batch_data_samples = self.add_pred_to_datasample( - batch_data_samples, results_list) - return batch_data_samples diff --git a/lqit/detection/models/detectors/two_stage_enhance_head.py b/lqit/detection/models/detectors/two_stage_enhance_head.py deleted file mode 100644 index 67b8fbe..0000000 --- a/lqit/detection/models/detectors/two_stage_enhance_head.py +++ /dev/null @@ -1,193 +0,0 @@ -import copy -from typing import Optional - -import torch -from mmdet.models import TwoStageDetector -from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig -from torch import Tensor - -from lqit.common.structures import SampleList -from lqit.edit.models import add_pixel_pred_to_datasample -from lqit.registry import MODELS - - -@MODELS.register_module() -class TwoStageWithEnhanceHead(TwoStageDetector): - """Base class for two-stage detectors with enhance head. - - Two-stage detectors typically consisting of a region proposal network and a - task-specific regression head. 
- """ - - def __init__(self, - backbone: ConfigType, - neck: OptConfigType = None, - rpn_head: OptConfigType = None, - roi_head: OptConfigType = None, - enhance_head: OptConfigType = None, - vis_enhance: Optional[bool] = False, - train_cfg: OptConfigType = None, - test_cfg: OptConfigType = None, - data_preprocessor: OptConfigType = None, - init_cfg: OptMultiConfig = None) -> None: - super().__init__( - backbone=backbone, - neck=neck, - rpn_head=rpn_head, - roi_head=roi_head, - train_cfg=train_cfg, - test_cfg=test_cfg, - data_preprocessor=data_preprocessor, - init_cfg=init_cfg) - - if enhance_head is not None: - self.enhance_head = MODELS.build(enhance_head) - self.vis_enhance = vis_enhance - - @property - def with_enhance_head(self) -> bool: - """bool: whether the detector has a RoI head""" - return hasattr(self, 'enhance_head') and self.enhance_head is not None - - def _forward(self, batch_inputs: Tensor, - batch_data_samples: SampleList) -> tuple: - """Network forward process. Usually includes backbone, neck and head - forward without any post-processing. - - Args: - batch_inputs (Tensor): Inputs with shape (N, C, H, W). - - Returns: - tuple: A tuple of features from ``rpn_head`` and ``roi_head`` - forward. - """ - results = () - x = self.extract_feat(batch_inputs) - - if self.with_rpn: - rpn_results_list = self.rpn_head.predict( - x, batch_data_samples, rescale=False) - else: - assert batch_data_samples[0].get('proposals', None) is not None - rpn_results_list = [ - data_sample.proposals for data_sample in batch_data_samples - ] - - if self.with_enhance_head: - enhance_outs = self.enhance_head.forward(x) - results = results + (enhance_outs, ) - - roi_outs = self.roi_head.forward(x, rpn_results_list) - results = results + (roi_outs, ) - return results - - def loss(self, batch_inputs: Tensor, - batch_data_samples: SampleList) -> dict: - """Calculate losses from a batch of inputs and data samples. - - Args: - batch_inputs (Tensor): Input images of shape (N, C, H, W). - These should usually be mean centered and std scaled. - batch_data_samples (List[:obj:`DetDataSample`]): The batch - data samples. It usually includes information such - as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. - - Returns: - dict: A dictionary of loss components - """ - x = self.extract_feat(batch_inputs) - - losses = dict() - - if self.with_enhance_head: - - enhance_loss = self.enhance_head.loss(x, batch_data_samples) - # avoid loss override - assert not set(enhance_loss.keys()) & set(losses.keys()) - losses.update(enhance_loss) - - # RPN forward and loss - if self.with_rpn: - proposal_cfg = self.train_cfg.get('rpn_proposal', - self.test_cfg.rpn) - rpn_data_samples = copy.deepcopy(batch_data_samples) - # set cat_id of gt_labels to 0 in RPN - for data_sample in rpn_data_samples: - data_sample.gt_instances.labels = \ - torch.zeros_like(data_sample.gt_instances.labels) - - rpn_losses, rpn_results_list = self.rpn_head.loss_and_predict( - x, rpn_data_samples, proposal_cfg=proposal_cfg) - # avoid get same name with roi_head loss - keys = rpn_losses.keys() - for key in keys: - if 'loss' in key and 'rpn' not in key: - rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key) - losses.update(rpn_losses) - else: - # TODO: Not support currently, should have a check at Fast R-CNN - assert batch_data_samples[0].get('proposals', None) is not None - # use pre-defined proposals in InstanceData for the second stage - # to extract ROI features. 
- rpn_results_list = [ - data_sample.proposals for data_sample in batch_data_samples - ] - roi_losses = self.roi_head.loss(x, rpn_results_list, - batch_data_samples) - losses.update(roi_losses) - - return losses - - def predict(self, - batch_inputs: Tensor, - batch_data_samples: SampleList, - rescale: bool = True) -> SampleList: - """Predict results from a batch of inputs and data samples with post- - processing. - - Args: - batch_inputs (Tensor): Inputs with shape (N, C, H, W). - batch_data_samples (List[:obj:`DetDataSample`]): The Data - Samples. It usually includes information such as - `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. - rescale (bool): Whether to rescale the results. - Defaults to True. - - Returns: - list[:obj:`DataSample`]: Return the detection results of the - input images. The returns value is DetDataSample, - which usually contain 'pred_instances'. And the - ``pred_instances`` usually contains following keys. - - - scores (Tensor): Classification scores, has a shape - (num_instance, ) - - labels (Tensor): Labels of bboxes, has a shape - (num_instances, ). - - bboxes (Tensor): Has a shape (num_instances, 4), - the last dimension 4 arrange as (x1, y1, x2, y2). - - masks (Tensor): Has a shape (num_instances, H, W). - """ - assert self.with_bbox, 'Bbox head must be implemented.' - x = self.extract_feat(batch_inputs) - - # If there are no pre-defined proposals, use RPN to get proposals - if batch_data_samples[0].get('proposals', None) is None: - rpn_results_list = self.rpn_head.predict( - x, batch_data_samples, rescale=False) - else: - rpn_results_list = [ - data_sample.proposals for data_sample in batch_data_samples - ] - - if self.vis_enhance and self.with_enhance_head: - enhance_list = self.enhance_head.predict( - x, batch_data_samples, rescale=rescale) - batch_data_samples = add_pixel_pred_to_datasample( - data_samples=batch_data_samples, pixel_list=enhance_list) - - results_list = self.roi_head.predict( - x, rpn_results_list, batch_data_samples, rescale=rescale) - - batch_data_samples = self.add_pred_to_datasample( - batch_data_samples, results_list) - return batch_data_samples diff --git a/lqit/edit/models/editor_heads/edge_head.py b/lqit/edit/models/editor_heads/edge_head.py index f0b110f..63fa995 100644 --- a/lqit/edit/models/editor_heads/edge_head.py +++ b/lqit/edit/models/editor_heads/edge_head.py @@ -69,4 +69,4 @@ def loss_by_feat(self, enhance_img, gt_imgs, img_metas): reshape_gt_imgs = F.interpolate( gt_imgs, size=enhance_img.shape[-2:], mode='bilinear') enhance_loss = self.loss_enhance(enhance_img, reshape_gt_imgs) - return dict(loss_enhance=enhance_loss) + return dict(loss_edge=enhance_loss)
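For context on the `loss_edge` rename above: the edge head regresses an edge map whose ground truth is resized to the prediction resolution before the loss is computed, and the new key simply makes the logged loss name match what the head predicts. How the edge ground truth itself is generated is not shown in this diff; a common, purely illustrative choice is a Sobel gradient magnitude, sketched below under that assumption (it is not necessarily the transform LQIT uses).

```python
import torch
import torch.nn.functional as F


def sobel_edge_map(img: torch.Tensor) -> torch.Tensor:
    """Illustrative edge target from a single-channel (N, 1, H, W) image."""
    kx = torch.tensor([[-1., 0., 1.],
                       [-2., 0., 2.],
                       [-1., 0., 1.]], device=img.device).view(1, 1, 3, 3)
    ky = kx.transpose(-1, -2)  # vertical-gradient Sobel kernel
    gx = F.conv2d(img, kx, padding=1)
    gy = F.conv2d(img, ky, padding=1)
    # Gradient magnitude; the epsilon keeps the sqrt stable in flat regions.
    return torch.sqrt(gx * gx + gy * gy + 1e-12)
```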