From c92b3b74b00fb432b7e7cabebbc6419ee90d4ba3 Mon Sep 17 00:00:00 2001 From: hewanru-bit <2558905977@qq.com> Date: Sat, 29 Oct 2022 15:11:01 +0800 Subject: [PATCH 01/10] [Feature] Add EDFFNet --- .../detection/_base_/datasets/rtts_coco.py | 57 +++++ configs/detection/edffnet/atss_r50_fpn_1x.py | 72 ++++++ configs/detection/edffnet/edffnet.py | 39 ++++ lqit/common/transforms/__init__.py | 3 +- lqit/common/transforms/get_edge.py | 220 ++++++++++++++++++ lqit/detection/datasets/__init__.py | 3 +- lqit/detection/datasets/rtts.py | 17 ++ lqit/detection/models/__init__.py | 1 + lqit/detection/models/detectors/__init__.py | 3 +- lqit/detection/models/detectors/edffnet.py | 31 +++ lqit/detection/models/necks/__init__.py | 4 + lqit/detection/models/necks/dffpn.py | 213 +++++++++++++++++ lqit/edit/models/enhance_heads/__init__.py | 4 +- lqit/edit/models/enhance_heads/edge_head.py | 73 ++++++ 14 files changed, 736 insertions(+), 4 deletions(-) create mode 100644 configs/detection/_base_/datasets/rtts_coco.py create mode 100644 configs/detection/edffnet/atss_r50_fpn_1x.py create mode 100644 configs/detection/edffnet/edffnet.py create mode 100644 lqit/common/transforms/get_edge.py create mode 100644 lqit/detection/datasets/rtts.py create mode 100644 lqit/detection/models/detectors/edffnet.py create mode 100644 lqit/detection/models/necks/__init__.py create mode 100644 lqit/detection/models/necks/dffpn.py create mode 100644 lqit/edit/models/enhance_heads/edge_head.py diff --git a/configs/detection/_base_/datasets/rtts_coco.py b/configs/detection/_base_/datasets/rtts_coco.py new file mode 100644 index 0000000..be1d319 --- /dev/null +++ b/configs/detection/_base_/datasets/rtts_coco.py @@ -0,0 +1,57 @@ +# dataset settings +dataset_type = 'RTTSCocoDataset' +data_root = 'data/RESIDE/' +file_client_args = dict(backend='disk') + +train_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='Resize', scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', prob=0.5), + dict(type='PackDetInputs', ) +] +test_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict(type='Resize', scale=(1333, 800), keep_ratio=True), + # avoid bboxes being resized + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] + +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + batch_sampler=dict(type='AspectRatioBatchSampler'), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='RTTS/annotations_json/rtts_train.json', + data_prefix=dict(img='RTTS/'), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline)) +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='RTTS/annotations_json/rtts_test.json', + data_prefix=dict(img='RTTS/'), + test_mode=True, + pipeline=test_pipeline)) +test_dataloader = val_dataloader + +val_evaluator = dict( + type='CocoMetric', + ann_file=data_root + 'RTTS/annotations_json/rtts_test.json', + metric='bbox', + format_only=False) +test_evaluator = val_evaluator diff --git a/configs/detection/edffnet/atss_r50_fpn_1x.py b/configs/detection/edffnet/atss_r50_fpn_1x.py new file mode 100644 
index 0000000..94ed31f --- /dev/null +++ b/configs/detection/edffnet/atss_r50_fpn_1x.py @@ -0,0 +1,72 @@ +_base_ = [ + '../_base_/datasets/rtts_coco.py', '../_base_/schedules/schedule_1x.py', + '../_base_/default_runtime.py' +] + +# model settings +model = dict( + type='ATSS', + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32), + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_output', + num_outs=5), + bbox_head=dict( + type='ATSSHead', + num_classes=5, + in_channels=256, + stacked_convs=4, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + ratios=[1.0], + octave_base_scale=8, + scales_per_octave=1, + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[0.1, 0.1, 0.2, 0.2]), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='GIoULoss', loss_weight=2.0), + loss_centerness=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)), + # training and testing settings + train_cfg=dict( + assigner=dict(type='ATSSAssigner', topk=9), + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.6), + max_per_img=100)) + +# TODO: check the lr = 0.01 or 0.02 +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)) diff --git a/configs/detection/edffnet/edffnet.py b/configs/detection/edffnet/edffnet.py new file mode 100644 index 0000000..f950973 --- /dev/null +++ b/configs/detection/edffnet/edffnet.py @@ -0,0 +1,39 @@ +_base_ = '../edffnet/atss_r50_fpn_1x.py' + +model = dict( + type='EDFFNet', + backbone=dict(norm_eval=False), + neck=dict( + type='DFFPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_output', + shape_level=2, + num_outs=5), + enhance_head=dict( + type='lqit.EdgeHead', + in_channels=256, + feat_channels=256, + num_convs=5, + loss_enhance=dict(type='mmdet.L1Loss', loss_weight=0.7), + gt_preprocessor=dict( + type='lqit.GTPixelPreprocessor', + mean=[123.675], + std=[58.395], + pad_size_divisor=32, + element_name='edge')), +) + +# dataset settings +train_pipeline = [ + dict( + type='LoadImageFromFile', + file_client_args={{_base_.file_client_args}}), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='Resize', scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', prob=0.5), + dict(type='lqit.GetEdgeGTFromImage', method='scharr'), + dict(type='lqit.PackInputs') +] +train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) diff --git a/lqit/common/transforms/__init__.py b/lqit/common/transforms/__init__.py index d5850d3..0785fba 100644 --- a/lqit/common/transforms/__init__.py +++ b/lqit/common/transforms/__init__.py @@ -1,8 +1,9 @@ from .formatting import PackInputs +from .get_edge import GetEdgeGTFromImage from .loading import LoadGTImageFromFile, SetInputImageAsGT from .wrapper import TransBroadcaster __all__ = [ 'PackInputs', 'LoadGTImageFromFile', 
'TransBroadcaster',
-    'SetInputImageAsGT'
+    'SetInputImageAsGT', 'GetEdgeGTFromImage'
 ]
diff --git a/lqit/common/transforms/get_edge.py b/lqit/common/transforms/get_edge.py
new file mode 100644
index 0000000..03244f9
--- /dev/null
+++ b/lqit/common/transforms/get_edge.py
@@ -0,0 +1,220 @@
+from typing import List, Optional, Union
+
+import cv2
+import numpy as np
+from mmcv.transforms import BaseTransform
+from numpy import ndarray
+
+from lqit.registry import TRANSFORMS
+
+
+@TRANSFORMS.register_module()
+class GetEdgeGTFromImage(BaseTransform):
+    """Get the edge gt from the image and set it into the results dict.
+
+    Required Keys:
+
+    - img
+
+    Added Keys:
+
+    - gt_edge
+
+    Args:
+        method (str): The edge extraction method. Defaults to 'scharr'.
+        kernel_size (List[int] or int): The Gaussian blur kernel size.
+            Defaults to [3, 3].
+        threshold_value (int): The threshold value used in 'roberts',
+            'prewitt', 'sobel', and 'laplacian'. Defaults to 127.
+        results_key (str): The key used to save the gt edge image in the
+            results dict. Defaults to 'gt_edge'.
+
+    Note:
+        This transform should be added right before `PackInputs`.
+        Otherwise, subsequent transforms may change `img` without
+        changing `gt_edge`.
+    """
+
+    def __init__(self,
+                 method: str = 'scharr',
+                 kernel_size: Union[List[int], int] = [3, 3],
+                 threshold_value: int = 127,
+                 results_key: str = 'gt_edge') -> None:
+        assert method in [
+            'roberts', 'prewitt', 'canny', 'scharr', 'sobel', 'laplacian',
+            'log'
+        ]
+        if isinstance(kernel_size, int):
+            self.kernel_size = (kernel_size, kernel_size)
+        elif isinstance(kernel_size, (tuple, list)):
+            self.kernel_size = tuple(kernel_size)
+        else:
+            raise TypeError('kernel_size should be a list of int or an int, '
+                            f'but got {type(kernel_size)}')
+        self.threshold_value = threshold_value
+        self.method = method
+        self.results_key = results_key
+
+    def transform(self, results: dict) -> Optional[dict]:
+        """Function to get the edge image from the input image.
+
+        Args:
+            results (dict): Result dict from :obj:`mmcv.BaseDataset`.
+
+        Returns:
+            dict: The dict containing the loaded image, the edge image, and
+                meta information.
+        """
+        assert 'img' in results
+        img = results['img']
+        get_edge_func = getattr(self, self.method)
+        edge_img = get_edge_func(img)
+        results[self.results_key] = edge_img
+        return results
+
+    def _get_gaussian_blur(self,
+                           img: ndarray,
+                           threshold: bool = True) -> ndarray:
+        """Get the Gaussian blur of the image.
+
+        Args:
+            img (np.ndarray): The image to be blurred.
+            threshold (bool): Whether to threshold the blurred image.
+
+        Returns:
+            np.ndarray: The blurred image or the thresholded image.
+        """
+        gray_image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        gaussian_blur = cv2.GaussianBlur(gray_image, self.kernel_size, 0)
+        if threshold:
+            _, binary = cv2.threshold(gaussian_blur, self.threshold_value,
+                                      255, cv2.THRESH_BINARY)
+            return binary
+        else:
+            return gaussian_blur
+
+    def roberts(self, img: ndarray) -> ndarray:
+        """Get the edge image based on the Roberts operator.
+
+        Args:
+            img (np.ndarray): The image to extract edges from.
+
+        Returns:
+            np.ndarray: The edge image.
+        """
+        binary = self._get_gaussian_blur(img=img, threshold=True)
+
+        kernel_x = np.array([[-1, 0], [0, 1]], dtype=int)
+        kernel_y = np.array([[0, -1], [1, 0]], dtype=int)
+        x = cv2.filter2D(binary, cv2.CV_16S, kernel_x)
+        y = cv2.filter2D(binary, cv2.CV_16S, kernel_y)
+        abs_x = cv2.convertScaleAbs(x)
+        abs_y = cv2.convertScaleAbs(y)
+        edge = cv2.addWeighted(abs_x, 0.5, abs_y, 0.5, 0)
+        return edge
+
+    def prewitt(self, img: ndarray) -> ndarray:
+        """Get the edge image based on the Prewitt operator.
+
+        Args:
+            img (np.ndarray): The image to extract edges from.
+
+        Returns:
+            np.ndarray: The edge image.
+        """
+        binary = self._get_gaussian_blur(img=img, threshold=True)
+
+        kernel_x = np.array([[1, 1, 1], [0, 0, 0], [-1, -1, -1]], dtype=int)
+        kernel_y = np.array([[-1, 0, 1], [-1, 0, 1], [-1, 0, 1]], dtype=int)
+        x = cv2.filter2D(binary, cv2.CV_16S, kernel_x)
+        y = cv2.filter2D(binary, cv2.CV_16S, kernel_y)
+        abs_x = cv2.convertScaleAbs(x)
+        abs_y = cv2.convertScaleAbs(y)
+        edge = cv2.addWeighted(abs_x, 0.5, abs_y, 0.5, 0)
+        return edge
+
+    def sobel(self, img: ndarray) -> ndarray:
+        """Get the edge image based on the Sobel operator.
+
+        Args:
+            img (np.ndarray): The image to extract edges from.
+
+        Returns:
+            np.ndarray: The edge image.
+        """
+        binary = self._get_gaussian_blur(img=img, threshold=True)
+
+        x = cv2.Sobel(binary, cv2.CV_16S, 1, 0)
+        y = cv2.Sobel(binary, cv2.CV_16S, 0, 1)
+        abs_x = cv2.convertScaleAbs(x)
+        abs_y = cv2.convertScaleAbs(y)
+        edge = cv2.addWeighted(abs_x, 0.5, abs_y, 0.5, 0)
+        return edge
+
+    def laplacian(self, img: ndarray) -> ndarray:
+        """Get the edge image based on the Laplacian operator.
+
+        Args:
+            img (np.ndarray): The image to extract edges from.
+
+        Returns:
+            np.ndarray: The edge image.
+        """
+        binary = self._get_gaussian_blur(img=img, threshold=True)
+
+        dst = cv2.Laplacian(binary, cv2.CV_16S, ksize=3)
+        edge = cv2.convertScaleAbs(dst)
+        return edge
+
+    def scharr(self, img: ndarray) -> ndarray:
+        """Get the edge image based on the Scharr operator.
+
+        Args:
+            img (np.ndarray): The image to extract edges from.
+
+        Returns:
+            np.ndarray: The edge image.
+        """
+        gaussian_blur = self._get_gaussian_blur(img=img, threshold=False)
+
+        x = cv2.Scharr(gaussian_blur, cv2.CV_32F, 1, 0)
+        y = cv2.Scharr(gaussian_blur, cv2.CV_32F, 0, 1)
+        abs_x = cv2.convertScaleAbs(x)
+        abs_y = cv2.convertScaleAbs(y)
+        edge = cv2.addWeighted(abs_x, 0.5, abs_y, 0.5, 0)
+        return edge
+
+    def canny(self, img: ndarray) -> ndarray:
+        """Get the edge image based on the Canny detector.
+
+        Args:
+            img (np.ndarray): The image to extract edges from.
+
+        Returns:
+            np.ndarray: The edge image.
+        """
+        gaussian_blur = self._get_gaussian_blur(img=img, threshold=False)
+
+        edge = cv2.Canny(gaussian_blur, 50, 150)
+        return edge
+
+    def log(self, img: ndarray) -> ndarray:
+        """Get the edge image based on the LoG (Laplacian of Gaussian)
+        operator.
+
+        Args:
+            img (np.ndarray): The image to extract edges from.
+
+        Returns:
+            np.ndarray: The edge image.
+        """
+        gaussian_blur = self._get_gaussian_blur(img=img, threshold=False)
+
+        dst = cv2.Laplacian(gaussian_blur, cv2.CV_16S, ksize=3)
+        edge = cv2.convertScaleAbs(dst)
+        return edge
+
+    def __repr__(self):
+        repr_str = (f'{self.__class__.__name__}('
+                    f'method={self.method}, '
+                    f'kernel_size={self.kernel_size}, '
+                    f'threshold_value={self.threshold_value}, '
+                    f"results_key='{self.results_key}')")
+        return repr_str
diff --git a/lqit/detection/datasets/__init__.py b/lqit/detection/datasets/__init__.py
index b3b4cd9..1a736c5 100644
--- a/lqit/detection/datasets/__init__.py
+++ b/lqit/detection/datasets/__init__.py
@@ -1,3 +1,4 @@
+from .rtts import RTTSCocoDataset
 from .urpc import URPCCocoDataset, URPCXMLDataset
 
-__all__ = ['URPCCocoDataset', 'URPCXMLDataset']
+__all__ = ['URPCCocoDataset', 'URPCXMLDataset', 'RTTSCocoDataset']
diff --git a/lqit/detection/datasets/rtts.py b/lqit/detection/datasets/rtts.py
new file mode 100644
index 0000000..455b734
--- /dev/null
+++ b/lqit/detection/datasets/rtts.py
@@ -0,0 +1,17 @@
+from mmdet.datasets import CocoDataset
+from mmdet.registry import DATASETS
+
+RTTS_METAINFO = {
+    'CLASSES': ('bicycle', 'bus', 'car', 'motorbike', 'person'),
+    'PALETTE': [(255, 97, 0), (0, 201, 87), (176, 23, 31), (138, 43, 226),
+                (30, 144, 255)]
+}
+
+
+@DATASETS.register_module()
+class RTTSCocoDataset(CocoDataset):
+    """Foggy object detection dataset in RESIDE `RTTS.
+
+    `_
+    """
+    METAINFO = RTTS_METAINFO
diff --git a/lqit/detection/models/__init__.py b/lqit/detection/models/__init__.py
index b104f1a..2519c03 100644
--- a/lqit/detection/models/__init__.py
+++ b/lqit/detection/models/__init__.py
@@ -1 +1,2 @@
 from .detectors import *  # noqa: F401,F403
+from .necks import *  # noqa: F401,F403
diff --git a/lqit/detection/models/detectors/__init__.py b/lqit/detection/models/detectors/__init__.py
index 58c8e25..21c50c9 100644
--- a/lqit/detection/models/detectors/__init__.py
+++ b/lqit/detection/models/detectors/__init__.py
@@ -1,8 +1,9 @@
+from .edffnet import EDFFNet
 from .multi_input_wrapper import MultiInputDetectorWrapper
 from .single_stage_enhance_head import SingleStageDetector
 from .two_stage_enhance_head import TwoStageWithEnhanceHead
 
 __all__ = [
     'TwoStageWithEnhanceHead', 'MultiInputDetectorWrapper',
-    'SingleStageDetector'
+    'SingleStageDetector', 'EDFFNet'
 ]
diff --git a/lqit/detection/models/detectors/edffnet.py b/lqit/detection/models/detectors/edffnet.py
new file mode 100644
index 0000000..161b985
--- /dev/null
+++ b/lqit/detection/models/detectors/edffnet.py
@@ -0,0 +1,31 @@
+from typing import Optional
+
+from mmdet.registry import MODELS
+from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
+
+from .single_stage_enhance_head import SingleStageDetector
+
+
+@MODELS.register_module()
+class EDFFNet(SingleStageDetector):
+
+    def __init__(self,
+                 backbone: ConfigType,
+                 neck: OptConfigType = None,
+                 bbox_head: OptConfigType = None,
+                 enhance_head: OptConfigType = None,
+                 vis_enhance: Optional[bool] = False,
+                 train_cfg: OptConfigType = None,
+                 test_cfg: OptConfigType = None,
+                 data_preprocessor: OptConfigType = None,
+                 init_cfg: OptMultiConfig = None) -> None:
+        super().__init__(
+            backbone=backbone,
+            neck=neck,
+            bbox_head=bbox_head,
+            enhance_head=enhance_head,
+            vis_enhance=vis_enhance,
+            train_cfg=train_cfg,
+            test_cfg=test_cfg,
+            data_preprocessor=data_preprocessor,
+            init_cfg=init_cfg)
diff --git a/lqit/detection/models/necks/__init__.py b/lqit/detection/models/necks/__init__.py
new file mode 100644
index 0000000..3ee293e
--- /dev/null
+++
b/lqit/detection/models/necks/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .dffpn import DFFPN + +__all__ = ['DFFPN'] diff --git a/lqit/detection/models/necks/dffpn.py b/lqit/detection/models/necks/dffpn.py new file mode 100644 index 0000000..9e2af80 --- /dev/null +++ b/lqit/detection/models/necks/dffpn.py @@ -0,0 +1,213 @@ +import copy +from typing import List, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmdet.registry import MODELS +from mmdet.utils import OptConfigType +from mmengine.model import BaseModule + + +@MODELS.register_module() +class DFFPN(BaseModule): + """Dynamic feature fusion pyramid network.""" + + def __init__(self, + in_channels: List[int], + out_channels: int = 256, + num_outs: int = 5, + start_level: int = 0, + end_level: int = -1, + add_extra_convs: Union[bool, str] = False, + shape_level: int = 2, + relu_before_extra_convs: bool = False, + no_norm_on_lateral: bool = False, + conv_cfg: OptConfigType = None, + norm_cfg: OptConfigType = None, + act_cfg: OptConfigType = None, + upsample_cfg: OptConfigType = dict(mode='nearest'), + init_cfg: OptConfigType = dict( + type='Xavier', layer='Conv2d', distribution='uniform')) \ + -> None: + super().__init__(init_cfg=init_cfg) + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + self.num_ins = len(in_channels) + self.num_outs = num_outs + self.shape_level = shape_level + self.no_norm_on_lateral = no_norm_on_lateral + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.upsample_cfg = upsample_cfg.copy() + self.relu_before_extra_convs = relu_before_extra_convs + + if end_level == -1: + self.backbone_end_level = self.num_ins + assert num_outs >= self.num_ins - start_level + else: + # if end_level < inputs, no extra level is allowed + self.backbone_end_level = end_level + assert end_level <= len(in_channels) + assert num_outs == end_level - start_level + self.start_level = start_level + self.end_level = end_level + self.add_extra_convs = add_extra_convs + assert isinstance(add_extra_convs, (str, bool)) + if isinstance(add_extra_convs, str): + # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output' + assert add_extra_convs in ('on_input', 'on_lateral', 'on_output') + elif add_extra_convs: # True + self.add_extra_convs = 'on_input' + + self.la1_convs = nn.ModuleList() + self.d_p = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + self.fpl_convs = nn.ModuleList() + self.dff = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(self.out_channels * self.num_outs, self.num_outs, 1)) + for i in range(self.start_level, self.backbone_end_level): + l1_conv = ConvModule( + in_channels[i], + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg if not self.no_norm_on_lateral else None, + act_cfg=act_cfg, + inplace=False) + f2_conv = ConvModule( + out_channels, + out_channels, + 3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + self.la1_convs.append(l1_conv) + self.fpn_convs.append(f2_conv) + + for j in range(self.num_outs): + fl_conv = ConvModule( + out_channels, + out_channels, + 3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + dp_conv = ConvModule( + out_channels, + out_channels, + 3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + self.fpl_convs.append(fl_conv) + 
self.d_p.append(dp_conv) + self.pooling = F.adaptive_avg_pool2d + + # add extra conv layers (e.g., RetinaNet) + extra_levels = num_outs - self.backbone_end_level + self.start_level + if self.add_extra_convs and extra_levels >= 1: + for i in range(extra_levels): + if i == 0 and self.add_extra_convs == 'on_input': + in_channels = self.in_channels[self.backbone_end_level - 1] + else: + in_channels = out_channels + extra_fpn_conv = ConvModule( + in_channels, + out_channels, + 3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + self.fpn_convs.append(extra_fpn_conv) + + def forward(self, inputs: tuple) -> tuple: + """Forward function.""" + assert len(inputs) == self.num_ins + + # 1. Unify channel through 1*1 Conv layer + lat = [ + la1_conv(inputs[i + self.start_level]) + for i, la1_conv in enumerate(self.la1_convs) + ] + + laterals = copy.copy(lat) + # 2. fpn up to down: + used_backbone_levels = len(lat) + for i in range(used_backbone_levels - 1, 0, -1): + prev_shape = lat[i - 1].shape[2:] + laterals[i - 1] = lat[i - 1] + F.interpolate( + lat[i], size=prev_shape, **self.upsample_cfg) + + laterals = [ + self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels) + ] + + # add extra layers + if self.num_outs > len(laterals): + # use max pool to get more levels on top of outputs + # (e.g., Faster R-CNN, Mask R-CNN) + if not self.add_extra_convs: + for i in range(self.num_outs - self.num_ins): + laterals.append(F.max_pool2d(laterals[-1], 1, stride=2)) + # add conv layers on top of original feature maps (RetinaNet) + else: + if self.add_extra_convs == 'on_input': + extra_source = inputs[self.backbone_end_level - 1] + elif self.add_extra_convs == 'on_lateral': + extra_source = laterals[-1] + elif self.add_extra_convs == 'on_output': + extra_source = laterals[-1] + else: + raise NotImplementedError + laterals.append( + self.fpn_convs[used_backbone_levels](extra_source)) + for i in range(used_backbone_levels + 1, self.num_outs): + if self.relu_before_extra_convs: + laterals.append(self.fpn_convs[i](F.relu( + laterals[-1]))) + else: + laterals.append(self.fpn_convs[i](laterals[-1])) + + # 3. pooling and concat + t_outs = [] + pool_shape = laterals[self.shape_level].size()[2:] + + for i in range(0, self.num_outs): + t_outs.append(self.pooling(laterals[i], pool_shape)) + + t_out = torch.cat(t_outs, dim=1) + # 4. 
get each feature map's weight
+        ws = self.dff(t_out)
+        ws = torch.sigmoid(ws)
+        w = torch.split(ws, 1, dim=1)
+
+        inner_outs = []
+
+        for i in range(0, self.num_outs):
+            inner_outs.append(laterals[i] * w[i])
+
+        for i in range(self.num_outs - 1):
+            prev_shape = inner_outs[i + 1].shape[2:]
+            inner_outs[i + 1] = inner_outs[i + 1] + F.interpolate(
+                inner_outs[i], size=prev_shape, **self.upsample_cfg)
+
+        outs = [
+            self.fpl_convs[i](inner_outs[i] + laterals[i])
+            for i in range(self.num_outs)
+        ]
+
+        return tuple(outs)
diff --git a/lqit/edit/models/enhance_heads/__init__.py b/lqit/edit/models/enhance_heads/__init__.py
index 6cc785a..de58211 100644
--- a/lqit/edit/models/enhance_heads/__init__.py
+++ b/lqit/edit/models/enhance_heads/__init__.py
@@ -1,5 +1,7 @@
 from .basic_enhance_head import (BasicEnhanceHead, SingleEnhanceHead,
                                  UpSingleEnhanceHead)
+from .edge_head import EdgeHead
 
 __all__ = [
-    'SingleEnhanceHead', 'UpSingleEnhanceHead', 'BasicEnhanceHead']
+    'SingleEnhanceHead', 'UpSingleEnhanceHead', 'BasicEnhanceHead', 'EdgeHead'
+]
diff --git a/lqit/edit/models/enhance_heads/edge_head.py b/lqit/edit/models/enhance_heads/edge_head.py
new file mode 100644
index 0000000..866a5b9
--- /dev/null
+++ b/lqit/edit/models/enhance_heads/edge_head.py
@@ -0,0 +1,73 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+
+from lqit.registry import MODELS
+from .base_enhance_head import BaseEnhanceHead
+
+
+@MODELS.register_module()
+class EdgeHead(BaseEnhanceHead):
+    """Edge enhance head: [conv + GN + ReLU] x 4 followed by a 1x1 conv."""
+
+    def __init__(self,
+                 in_channels=256,
+                 feat_channels=256,
+                 num_convs=5,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
+                 act_cfg=dict(type='ReLU'),
+                 gt_preprocessor=None,
+                 loss_enhance=dict(type='mmdet.L1Loss', loss_weight=1.0),
+                 init_cfg=dict(type='Normal', layer='Conv2d', std=0.01)):
+        super().__init__(
+            loss_enhance=loss_enhance,
+            gt_preprocessor=gt_preprocessor,
+            init_cfg=init_cfg)
+        self.in_channels = in_channels
+        self.feat_channels = feat_channels
+        self.num_convs = num_convs
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.act_cfg = act_cfg
+        self._init_layers()
+
+    def _init_layers(self):
+        assert self.num_convs > 0
+        enhance_conv = []
+        for i in range(self.num_convs):
+            in_channels = self.in_channels if i == 0 \
+                else self.feat_channels
+            if i < (self.num_convs - 1):
+                enhance_conv.append(
+                    ConvModule(
+                        in_channels,
+                        self.feat_channels,
+                        3,
+                        stride=1,
+                        padding=1,
+                        conv_cfg=self.conv_cfg,
+                        norm_cfg=self.norm_cfg,
+                        act_cfg=self.act_cfg))
+            else:
+                # 1x1 output conv; padding must be 0 so that the predicted
+                # edge map keeps the same spatial size as the input feature
+                enhance_conv.append(
+                    nn.Conv2d(
+                        in_channels=in_channels,
+                        out_channels=1,
+                        kernel_size=1,
+                        stride=1,
+                        padding=0))
+        self.enhance_conv = nn.Sequential(*enhance_conv)
+
+    def forward(self, x):
+        # if multi-level features are passed in, use the first
+        # (highest-resolution) level
+        if isinstance(x, (tuple, list)):
+            x = x[0]
+        outs = self.enhance_conv(x)
+        return outs
+
+    def loss_by_feat(self, enhance_img, gt_imgs, img_metas):
+        reshape_gt_imgs = F.interpolate(
+            gt_imgs, size=enhance_img.shape[-2:], mode='bilinear')
+        enhance_loss = self.loss_enhance(enhance_img, reshape_gt_imgs)
+        return dict(loss_enhance=enhance_loss)
From 96b1f032e82b5ae4ad366dadabd8714725223583 Mon Sep 17 00:00:00 2001
From: hewanru-bit <2558905977@qq.com>
Date: Sat, 5 Nov 2022 21:30:46 +0800
Subject: [PATCH 02/10] update

---
 configs/detection/edffnet/edffnet_new.py | 62 ++++++++++++++++++++++++
 1 file changed, 62 insertions(+)
 create mode 100644
configs/detection/edffnet/edffnet_new.py diff --git a/configs/detection/edffnet/edffnet_new.py b/configs/detection/edffnet/edffnet_new.py new file mode 100644 index 0000000..d0603b1 --- /dev/null +++ b/configs/detection/edffnet/edffnet_new.py @@ -0,0 +1,62 @@ +_base_ = '../edffnet/atss_r50_fpn_1x.py' + +model = dict( + type='EDFFNet', + # backbone=dict(norm_eval=False), + neck=dict( + type='DFFPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_input', + shape_level=2, + num_outs=5), + enhance_head=dict( + type='lqit.EdgeHead', + in_channels=256, + feat_channels=256, + num_convs=5, + loss_enhance=dict(type='mmdet.L1Loss', loss_weight=0.7), + gt_preprocessor=dict( + type='lqit.GTPixelPreprocessor', + mean=[128], + std=[57.12], + pad_size_divisor=32, + element_name='edge')), +) + +# dataset settings +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='lqit.GetEdgeGTFromImage', method='scharr'), + dict( + type='lqit.TransBroadcaster', + src_key='img', + dst_key='gt_edge', + transforms=[ + dict(type='Resize', scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', prob=0.5), + ]), + dict(type='lqit.PackInputs', ) +] +train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.0025, momentum=0.9, weight_decay=0.0001)) +# learning rate +param_scheduler = [ + dict( + type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, + end=1000), + dict( + type='MultiStepLR', + begin=0, + end=12, + by_epoch=True, + milestones=[8, 11], + gamma=0.1) +] + +# load_from = '/home/test/data2/HWR/mmdet_works/edffnet_50.7.pth' From 1f3c3c8061c5f65d5e98b85e5b9e4c221760aa97 Mon Sep 17 00:00:00 2001 From: hewanru-bit <2558905977@qq.com> Date: Sat, 5 Nov 2022 21:59:34 +0800 Subject: [PATCH 03/10] [feature]AODNet --- .../detection/_base_/datasets/rtts_coco.py | 6 +- configs/detection/edffnet/edffnet.py | 6 +- configs/edit/aodnet/aodnet.py | 77 ++++++++++++ lqit/common/__init__.py | 4 +- lqit/common/dataset_wrappers.py | 114 +++++++++++++++++ .../detectors/single_stage_enhance_model.py | 116 ++++++++++++++++++ lqit/edit/models/editors/__init__.py | 1 + lqit/edit/models/editors/aodnet/__init__.py | 5 + lqit/edit/models/editors/aodnet/aodnet.py | 102 +++++++++++++++ .../models/editors/aodnet/aodnet_generator.py | 55 +++++++++ 10 files changed, 480 insertions(+), 6 deletions(-) create mode 100644 configs/edit/aodnet/aodnet.py create mode 100644 lqit/detection/models/detectors/single_stage_enhance_model.py create mode 100644 lqit/edit/models/editors/aodnet/__init__.py create mode 100644 lqit/edit/models/editors/aodnet/aodnet.py create mode 100644 lqit/edit/models/editors/aodnet/aodnet_generator.py diff --git a/configs/detection/_base_/datasets/rtts_coco.py b/configs/detection/_base_/datasets/rtts_coco.py index be1d319..e6eab80 100644 --- a/configs/detection/_base_/datasets/rtts_coco.py +++ b/configs/detection/_base_/datasets/rtts_coco.py @@ -1,18 +1,18 @@ # dataset settings dataset_type = 'RTTSCocoDataset' -data_root = 'data/RESIDE/' +data_root = '/home/tju531/hwr/Datasets/RESIDE/' file_client_args = dict(backend='disk') train_pipeline = [ dict(type='LoadImageFromFile', file_client_args=file_client_args), dict(type='LoadAnnotations', with_bbox=True), - dict(type='Resize', scale=(1333, 800), keep_ratio=True), + dict(type='Resize', scale=(256, 256), keep_ratio=True), dict(type='RandomFlip', prob=0.5), 
dict(type='PackDetInputs', ) ] test_pipeline = [ dict(type='LoadImageFromFile', file_client_args=file_client_args), - dict(type='Resize', scale=(1333, 800), keep_ratio=True), + dict(type='Resize', scale=(256, 256), keep_ratio=True), # avoid bboxes being resized dict(type='LoadAnnotations', with_bbox=True), dict( diff --git a/configs/detection/edffnet/edffnet.py b/configs/detection/edffnet/edffnet.py index f42aa1c..f648515 100644 --- a/configs/detection/edffnet/edffnet.py +++ b/configs/detection/edffnet/edffnet.py @@ -34,7 +34,7 @@ src_key='img', dst_key='gt_edge', transforms=[ - dict(type='Resize', scale=(1333, 800), keep_ratio=True), + dict(type='Resize', scale=(256, 256), keep_ratio=True), dict(type='RandomFlip', prob=0.5) ]), dict(type='lqit.PackInputs', ) @@ -53,3 +53,7 @@ milestones=[8, 11], gamma=0.1) ] + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.0025, momentum=0.9, weight_decay=0.0001)) \ No newline at end of file diff --git a/configs/edit/aodnet/aodnet.py b/configs/edit/aodnet/aodnet.py new file mode 100644 index 0000000..de060ef --- /dev/null +++ b/configs/edit/aodnet/aodnet.py @@ -0,0 +1,77 @@ +_base_ = [ + '../../detection/_base_/schedules/schedule_1x.py', '../../detection/_base_/default_runtime.py' +] + +model = dict( + type='lqit.BaseEditModel', + data_preprocessor=dict( + type='lqit.EditDataPreprocessor', + mean=[0.0, 0.0, 0.0], + std=[255.0, 255.0, 255.0], + bgr_to_rgb=True, + pad_size_divisor=32, + gt_name='img'), + generator=dict( + _scope_='lqit', + type='AODNetGenerator', + aodnet=dict(type='AODNet'), + pixel_loss=dict(type='MSELoss', loss_weight=1.0) + )) + +# dataset settings +dataset_type = 'CityscapesDataset' +data_root = '/home/tju531/hwr/Datasets/' + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='lqit.LoadGTImageFromFile'), + dict( + type='lqit.TransBroadcaster', + src_key='img', + dst_key='gt_img', + transforms=[ + dict(type='Resize', scale=(256, 256), keep_ratio=True), + dict(type='RandomFlip', prob=0.5) + ]), + dict(type='lqit.PackInputs', ) +] + +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + batch_sampler=dict(type='AspectRatioBatchSampler'), + dataset=dict( + type='lqit.DatasetWithClearImageWrapper', + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='cityscape_foggy/annotations_json/instancesonly_filtered_gtFine_train.json', + data_prefix=dict(img='cityscape_foggy/train/', gt_img_path='cityscape/train/'), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline), + suffix='png' + )) + +val_dataloader = None +val_cfg = None +test_cfg = None + +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=10, val_interval=1) +param_scheduler = [ + dict( + type='LinearLR', start_factor=0.0001, by_epoch=False, begin=0, + end=1000), + dict( + type='MultiStepLR', + begin=0, + end=10, + by_epoch=True, + milestones=[6, 9], + gamma=0.1) +] + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.001, momentum=0.9, weight_decay=0.0001)) \ No newline at end of file diff --git a/lqit/common/__init__.py b/lqit/common/__init__.py index 7639b50..c4a9ae8 100644 --- a/lqit/common/__init__.py +++ b/lqit/common/__init__.py @@ -1,6 +1,6 @@ from .data_preprocessor import * # noqa: F401,F403 -from .dataset_wrappers import DatasetWithGTImageWrapper +from .dataset_wrappers import DatasetWithGTImageWrapper, DatasetWithClearImageWrapper from .structures import * # noqa: 
F401,F403
 from .transforms import *  # noqa: F401,F403
 
-__all__ = ['DatasetWithGTImageWrapper']
+__all__ = ['DatasetWithGTImageWrapper', 'DatasetWithClearImageWrapper']
diff --git a/lqit/common/dataset_wrappers.py b/lqit/common/dataset_wrappers.py
index 2bd276d..f14f06f 100644
--- a/lqit/common/dataset_wrappers.py
+++ b/lqit/common/dataset_wrappers.py
@@ -118,3 +118,117 @@ def parse_gt_img_info(self, data_info: dict) -> Union[dict, List[dict]]:
             f'.{self.suffix}'
         data_info['gt_img_path'] = osp.join(gt_img_root, img_name)
         return data_info
+
+
+@DATASETS.register_module()
+class DatasetWithClearImageWrapper:
+    """Dataset wrapper for the image dehazing task that adds `gt_img_path`
+    to each data information dict.
+
+    Args:
+        dataset (BaseDataset or dict): The dataset to be wrapped.
+        suffix (str): Suffix of the gt (clear) image. Defaults to 'jpg'.
+        lazy_init (bool, optional): Whether to load annotations during
+            instantiation. Defaults to False.
+    """
+
+    def __init__(self,
+                 dataset: Union[BaseDataset, dict],
+                 suffix: str = 'jpg',
+                 lazy_init: bool = False) -> None:
+        self.suffix = suffix
+        if isinstance(dataset, dict):
+            self.dataset = DATASETS.build(dataset)
+        elif isinstance(dataset, BaseDataset):
+            self.dataset = dataset
+        else:
+            raise TypeError(
+                'dataset should be a config dict or a `BaseDataset` '
+                f'instance, but got {type(dataset)}')
+        self._metainfo = self.dataset.metainfo
+
+        self._fully_initialized = False
+        if not lazy_init:
+            self.full_init()
+
+    @property
+    def metainfo(self) -> dict:
+        """Get the meta information of the wrapped dataset.
+
+        Returns:
+            dict: The meta information of the wrapped dataset.
+        """
+        return self._metainfo
+
+    def full_init(self):
+        self.dataset.full_init()
+
+    def get_data_info(self, idx: int) -> dict:
+        return self.dataset.get_data_info(idx)
+
+    def prepare_data(self, idx) -> Any:
+        """Get data processed by ``self.pipeline``.
+
+        Args:
+            idx (int): The index of ``data_info``.
+
+        Returns:
+            Any: Depends on ``self.pipeline``.
+        """
+        data_info = self.get_data_info(idx)
+        data_info = self.parse_gt_img_info(data_info)
+        return self.dataset.pipeline(data_info)
+
+    def __getitem__(self, idx):
+        if not self.dataset._fully_initialized:
+            warnings.warn(
+                'Please call `full_init()` method manually to accelerate '
+                'the speed.')
+            self.dataset.full_init()
+
+        if self.dataset.test_mode:
+            data = self.prepare_data(idx)
+            if data is None:
+                raise Exception('Test-time pipeline should not return `None` '
+                                'data_sample')
+            return data
+
+        for _ in range(self.dataset.max_refetch + 1):
+            data = self.prepare_data(idx)
+            # Broken images or random augmentations may cause the returned
+            # data to be None
+            if data is None:
+                idx = self.dataset._rand_another()
+                continue
+            return data
+
+        raise Exception('Cannot find valid image after '
+                        f'{self.dataset.max_refetch}! Please check your '
+                        'image path and pipeline')
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def parse_gt_img_info(self, data_info: dict) -> Union[dict, List[dict]]:
+        """Parse raw annotation to target format.
+
+        Args:
+            data_info (dict): Raw data information loaded from ``ann_file``.
+
+        Returns:
+            Union[dict, List[dict]]: Parsed annotation.
+ """ + + gt_img_root = self.dataset.data_prefix.get('gt_img_path', None) + + if gt_img_root is None: + warnings.warn( + 'Cannot get gt_img_root, please set `gt_img_path` in ' + '`dataset.data_prefix`') + data_info['gt_img_path'] = data_info['img_path'] + else: + img_name = \ + f"{osp.split(data_info['img_path'])[0].split('/')[-1]}" + '/'\ + f"{osp.split(data_info['img_path'])[-1].split('_foggy_')[0]}" \ + f'.{self.suffix}' + data_info['gt_img_path'] = osp.join(gt_img_root, img_name) + return data_info diff --git a/lqit/detection/models/detectors/single_stage_enhance_model.py b/lqit/detection/models/detectors/single_stage_enhance_model.py new file mode 100644 index 0000000..4c92ee0 --- /dev/null +++ b/lqit/detection/models/detectors/single_stage_enhance_model.py @@ -0,0 +1,116 @@ +from typing import Optional + +from mmdet.models import SingleStageDetector +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from torch import Tensor + +from lqit.common.structures import SampleList +from lqit.edit.models import add_pixel_pred_to_datasample + + +@MODELS.register_module() +class SingleStageWithEnhanceModel(SingleStageDetector): + """Base class for one-stage detectors with enhance model. + + """ + + def __init__(self, + backbone: ConfigType, + enhance_model: OptConfigType = None, + neck: OptConfigType = None, + bbox_head: OptConfigType = None, + vis_enhance: Optional[bool] = False, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + + if enhance_model is not None: + self.enhance_model = MODELS.build(enhance_model) + self.vis_enhance = vis_enhance + + @property + def with_enhance_model(self) -> bool: + """bool: whether the detector has a RoI head""" + return hasattr(self, 'enhance_model') and self.enhance_model is not None + + + def loss(self, batch_inputs: Tensor, + batch_data_samples: SampleList) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + batch_inputs (Tensor): Input images of shape (N, C, H, W). + These should usually be mean centered and std scaled. + batch_data_samples (List[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + dict: A dictionary of loss components + """ + losses = dict() + if self.with_enhance_model: + enhance_loss = self.enhance_model.loss(batch_inputs, batch_data_samples) + # avoid loss override + assert not set(enhance_loss.keys()) & set(losses.keys()) + losses.update(enhance_loss) + + x = self.extract_feat(batch_inputs) + + det_losses = self.bbox_head.loss(x, batch_data_samples) + losses.update(det_losses) + return losses + + def predict(self, + batch_inputs: Tensor, + batch_data_samples: SampleList, + rescale: bool = True) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + batch_inputs (Tensor): Inputs with shape (N, C, H, W). + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + rescale (bool): Whether to rescale the results. + Defaults to True. 
+ + Returns: + list[:obj:`DataSample`]: Return the detection results of the + input images. The returns value is DetDataSample, + which usually contain 'pred_instances'. And the + ``pred_instances`` usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). + """ + + x = self.extract_feat(batch_inputs) + results_list = self.bbox_head.predict( + x, batch_data_samples, rescale=rescale) + + if self.vis_enhance and self.with_enhance_model: + enhance_list = self.enhance_model.predict( + x, batch_data_samples, rescale=rescale) + batch_data_samples = add_pixel_pred_to_datasample( + data_samples=batch_data_samples, pixel_list=enhance_list) + + batch_data_samples = self.add_pred_to_datasample( + batch_data_samples, results_list) + return batch_data_samples diff --git a/lqit/edit/models/editors/__init__.py b/lqit/edit/models/editors/__init__.py index dd58005..307d93e 100644 --- a/lqit/edit/models/editors/__init__.py +++ b/lqit/edit/models/editors/__init__.py @@ -1,2 +1,3 @@ from .unet import * # noqa: F401,F403 from .zero_dce import * # noqa: F401,F403 +from .aodnet import * # noqa: F401,F403 \ No newline at end of file diff --git a/lqit/edit/models/editors/aodnet/__init__.py b/lqit/edit/models/editors/aodnet/__init__.py new file mode 100644 index 0000000..1cd3826 --- /dev/null +++ b/lqit/edit/models/editors/aodnet/__init__.py @@ -0,0 +1,5 @@ +from .aodnet import AODNet +from .aodnet_generator import AODNetGenerator + + +__all__ = ['AODNet', 'AODNetGenerator'] diff --git a/lqit/edit/models/editors/aodnet/aodnet.py b/lqit/edit/models/editors/aodnet/aodnet.py new file mode 100644 index 0000000..86a731d --- /dev/null +++ b/lqit/edit/models/editors/aodnet/aodnet.py @@ -0,0 +1,102 @@ +import torch +import torch.nn as nn +import warnings +from mmcv.cnn import ConvModule +from mmcv.cnn.bricks.activation import build_activation_layer +from mmengine.model import BaseModule + +from lqit.registry import MODELS + + +@MODELS.register_module() +class AODNet(BaseModule): + """AOD-Net: All-in-One Dehazing Network.""" + + def __init__(self, + in_channels=(1, 1, 2, 2, 4), + base_channels=3, + out_channels=(3, 3, 3, 3, 3), + num_stages=5, + kernel_size=(1, 3, 5, 7, 3), + padding=(0, 1, 2, 3, 1), + act_cfg=dict(type='ReLU'), + plugins=None, + pretrained=None, + norm_eval=False, + init_cfg=[ + dict(type='Normal', layer='Conv2d', mean=0, std=0.02), + dict( + type='Normal', + layer='BatchNorm2d', + mean=1.0, + std=0.02, + bias=0), + ]): + super().__init__(init_cfg=init_cfg) + + self.pretrained = pretrained + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + else: + raise TypeError('pretrained must be a str or None') + assert plugins is None, 'Not implemented yet.' 
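+        # AOD-Net predicts a single transformation map K(x) and reconstructs
+        # the clean image as J(x) = K(x) * I(x) - K(x) + b with the bias b
+        # fixed to 1; the conv stages configured below jointly estimate K(x),
+        # and the reconstruction happens at the end of `forward`.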
+ + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.num_stages = num_stages + self.base_channels = base_channels + self.padding = padding + self.with_activation = act_cfg is not None + self.norm_eval = norm_eval + self.act_cfg = act_cfg + # build activation layer + if self.with_activation: + act_cfg_ = act_cfg.copy() + # nn.Tanh has no 'inplace' argument + if act_cfg_['type'] not in [ + 'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish' + ]: + act_cfg_.setdefault('inplace', True) + self.activate = build_activation_layer(act_cfg_) + + self._init_layer() + + def _init_layer(self): + + self.CONVM = nn.ModuleList() + for i in range(self.num_stages): + conv_act = ConvModule( + in_channels=self.in_channels[i] * self.base_channels, out_channels=self.out_channels[i], + kernel_size=self.kernel_size[i], stride=1, padding=self.padding[i], bias=True, act_cfg=self.act_cfg) + self.CONVM.append(conv_act) + + + def forward(self, inputs): + outs = [] + x1 = inputs + for i in range(self.num_stages): + if i > 1 and i != (self.num_stages - 1): # from i=2 concat + x1 = torch.cat((outs[i - 2], outs[i - 1]), 1) + + if i == self.num_stages - 1: # last concat all + x1 = torch.cat([outs[j] for j in range(len(outs))], 1) + + x1 = self.CONVM[i](x1) + outs.append(x1) + result = self.activate((outs[-1] * inputs) - outs[-1] + 1) + + return result diff --git a/lqit/edit/models/editors/aodnet/aodnet_generator.py b/lqit/edit/models/editors/aodnet/aodnet_generator.py new file mode 100644 index 0000000..a7bf7d0 --- /dev/null +++ b/lqit/edit/models/editors/aodnet/aodnet_generator.py @@ -0,0 +1,55 @@ +from typing import List + +from lqit.edit.models.base_models import BaseGenerator +from lqit.edit.structures import BatchPixelData +from lqit.registry import MODELS +from lqit.utils.typing import ConfigType, OptMultiConfig + + +@MODELS.register_module() +class AODNetGenerator(BaseGenerator): + + def __init__(self, + aodnet: ConfigType, + pixel_loss: ConfigType = dict( + type='MSELoss', loss_weight=1.0), + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + pixel_loss=pixel_loss, + init_cfg=init_cfg) + + # build network + self.model = MODELS.build(aodnet) + + + def forward(self, x): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + return self.model(x) + + def loss(self, loss_input: BatchPixelData, batch_img_metas: List[dict]): + """Calculate the loss based on the outputs of generator.""" + batch_outputs = loss_input.output + batch_inputs = loss_input.input + batch_gt = loss_input.gt + + pixel_loss = self.pixel_loss(batch_outputs, batch_gt) + + losses = dict(pixel_loss=pixel_loss) + return losses + + def post_precess(self, outputs): + # ZeroDCE return enhance loss and curve at the same time. 
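+        # (The note above is carried over from the ZeroDCE generator, which
+        # stacks curve maps after the image channels; AODNet returns only
+        # the dehazed image, so the channel slicing below is just a guard.)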
+        assert outputs.dim() in [3, 4]
+        # `AODNet.in_channels` is a tuple of per-stage channel multipliers
+        # and cannot be used as a slice bound; the dehazed image has the
+        # final stage's out_channels (3 by default)
+        out_channels = self.model.out_channels[-1]
+        if outputs.dim() == 4:
+            enhance_img = outputs[:, :out_channels, :, :]
+        else:
+            enhance_img = outputs[:out_channels, :, :]
+        return enhance_img
From 1e3c3ed2771a4394c948037c072d06bf08ea298f Mon Sep 17 00:00:00 2001
From: hewanru-bit <2558905977@qq.com>
Date: Mon, 7 Nov 2022 17:03:26 +0800
Subject: [PATCH 04/10] up

---
 .../cityscape_enhancement_with_anno.py        | 75 +++++++++++++++
 .../datasets/cityscape_enhancet_with_txt.py   | 73 ++++++++++++++
 configs/edit/aodnet/aodnet.py                 | 44 +----------
 3 files changed, 151 insertions(+), 41 deletions(-)
 create mode 100644 configs/edit/_base_/datasets/cityscape_enhancement_with_anno.py
 create mode 100644 configs/edit/_base_/datasets/cityscape_enhancet_with_txt.py

diff --git a/configs/edit/_base_/datasets/cityscape_enhancement_with_anno.py b/configs/edit/_base_/datasets/cityscape_enhancement_with_anno.py
new file mode 100644
index 0000000..f5e2e2f
--- /dev/null
+++ b/configs/edit/_base_/datasets/cityscape_enhancement_with_anno.py
@@ -0,0 +1,75 @@
+# dataset settings
+dataset_type = 'mmdet.CityscapesDataset'
+data_root = '/home/tju531/hwr/Datasets/'
+
+file_client_args = dict(backend='disk')
+
+train_pipeline = [
+    dict(type='LoadImageFromFile', file_client_args=file_client_args),
+    dict(type='LoadGTImageFromFile', file_client_args=file_client_args),
+    dict(
+        type='TransBroadcaster',
+        src_key='img',
+        dst_key='gt_img',
+        transforms=[
+            dict(type='Resize', scale=(512, 512), keep_ratio=True),
+            dict(type='RandomFlip', prob=0.5),
+        ]),
+    dict(type='PackInputs')
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile', file_client_args=file_client_args),
+    dict(type='LoadGTImageFromFile', file_client_args=file_client_args),
+    dict(
+        type='TransBroadcaster',
+        src_key='img',
+        dst_key='gt_img',
+        transforms=[dict(type='Resize', scale=(512, 512), keep_ratio=True)]),
+    dict(
+        type='PackInputs',
+        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
+                   'scale_factor'))
+]
+train_dataloader = dict(
+    batch_size=2,
+    num_workers=2,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    dataset=dict(
+        type='lqit.DatasetWithClearImageWrapper',
+        dataset=dict(
+            type=dataset_type,
+            data_root=data_root,
+            ann_file='cityscape_foggy/annotations_json/instancesonly_filtered_gtFine_train.json',
+            data_prefix=dict(
+                img='cityscape_foggy/train/', gt_img_path='cityscape/train/'),
+            filter_cfg=dict(filter_empty_gt=True, min_size=32),
+            pipeline=train_pipeline),
+        suffix='png'))
+
+val_dataloader = dict(
+    batch_size=1,
+    num_workers=2,
+    persistent_workers=True,
+    drop_last=False,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    dataset=dict(
+        type='lqit.DatasetWithClearImageWrapper',
+        dataset=dict(
+            type=dataset_type,
+            data_root=data_root,
+            test_mode=True,
+            indices=100,
+            ann_file='cityscape_foggy/annotations_json/instancesonly_filtered_gtFine_test.json',
+            data_prefix=dict(
+                img='cityscape_foggy/test/', gt_img_path='cityscape/test/'),
+            filter_cfg=dict(filter_empty_gt=True, min_size=32),
+            pipeline=test_pipeline),
+        suffix='png'))
+
+test_dataloader = val_dataloader
+
+val_evaluator = [
+    dict(type='MSE', gt_key='img', pred_key='pred_img'),
+]
+test_evaluator = val_evaluator
diff --git a/configs/edit/_base_/datasets/cityscape_enhancet_with_txt.py b/configs/edit/_base_/datasets/cityscape_enhancet_with_txt.py
new file mode 100644
index 0000000..74581a7
--- /dev/null
+++ b/configs/edit/_base_/datasets/cityscape_enhancet_with_txt.py
@@ -0,0 +1,73 @@
+#
dataset settings +dataset_type = 'BasicImageDataset' +data_root = '/home/tju531/hwr/Datasets/' + +file_client_args = dict(backend='disk') + +train_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict(type='LoadGTImageFromFile', file_client_args=file_client_args), + dict( + type='TransBroadcaster', + src_key='img', + dst_key='gt_img', + transforms=[ + dict(type='Resize', scale=(512, 512), keep_ratio=True), + dict(type='RandomFlip', prob=0.5), + ]), + dict(type='PackInputs') +] +test_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict(type='LoadGTImageFromFile', file_client_args=file_client_args), + dict( + type='TransBroadcaster', + src_key='img', + dst_key='gt_img', + transforms=[dict(type='Resize', scale=(512, 512), keep_ratio=True)]), + dict( + type='PackInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type=dataset_type, + data_root=data_root, + metainfo=dict( + dataset_type='cityscape_enhancement', task_name='enhancement'), + ann_file='cityscape_foggy/train/train.txt', + data_prefix=dict(img='cityscape_foggy/train/', gt_img_path='cityscape/train/'), + search_key='img', + img_suffix=dict(img='png', gt_img='png'), + file_client_args=file_client_args, + pipeline=train_pipeline)) +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + test_mode=True, + indices=100, + metainfo=dict( + dataset_type='cityscape_enhancement', task_name='enhancement'), + ann_file='cityscape_foggy/test/test.txt', + data_prefix=dict(img='cityscape_foggy/test/', gt_img_path='cityscape/test/'), + search_key='img', + img_suffix=dict(img='png', gt_img='png'), + file_client_args=file_client_args, + pipeline=test_pipeline)) +test_dataloader = val_dataloader + +val_evaluator = [ + dict(type='MSE', gt_key='img', pred_key='pred_img'), +] +test_evaluator = val_evaluator diff --git a/configs/edit/aodnet/aodnet.py b/configs/edit/aodnet/aodnet.py index de060ef..c0b8eb8 100644 --- a/configs/edit/aodnet/aodnet.py +++ b/configs/edit/aodnet/aodnet.py @@ -1,7 +1,8 @@ _base_ = [ - '../../detection/_base_/schedules/schedule_1x.py', '../../detection/_base_/default_runtime.py' + '../_base_/datasets/cityscape_enhancement_with_txt.py', + '../_base_/schedules/schedule_1x.py', + '../_base_/default_runtime.py' ] - model = dict( type='lqit.BaseEditModel', data_preprocessor=dict( @@ -18,45 +19,6 @@ pixel_loss=dict(type='MSELoss', loss_weight=1.0) )) -# dataset settings -dataset_type = 'CityscapesDataset' -data_root = '/home/tju531/hwr/Datasets/' - -train_pipeline = [ - dict(type='LoadImageFromFile'), - dict(type='lqit.LoadGTImageFromFile'), - dict( - type='lqit.TransBroadcaster', - src_key='img', - dst_key='gt_img', - transforms=[ - dict(type='Resize', scale=(256, 256), keep_ratio=True), - dict(type='RandomFlip', prob=0.5) - ]), - dict(type='lqit.PackInputs', ) -] - -train_dataloader = dict( - batch_size=2, - num_workers=2, - persistent_workers=True, - sampler=dict(type='DefaultSampler', shuffle=True), - batch_sampler=dict(type='AspectRatioBatchSampler'), - dataset=dict( - type='lqit.DatasetWithClearImageWrapper', - dataset=dict( - type=dataset_type, - data_root=data_root, - 
ann_file='cityscape_foggy/annotations_json/instancesonly_filtered_gtFine_train.json', - data_prefix=dict(img='cityscape_foggy/train/', gt_img_path='cityscape/train/'), - filter_cfg=dict(filter_empty_gt=True, min_size=32), - pipeline=train_pipeline), - suffix='png' - )) - -val_dataloader = None -val_cfg = None -test_cfg = None train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=10, val_interval=1) param_scheduler = [ From 336bfdf34ee757a593f3483f846080be6050152c Mon Sep 17 00:00:00 2001 From: hewanru-bit <2558905977@qq.com> Date: Mon, 7 Nov 2022 19:20:37 +0800 Subject: [PATCH 05/10] [feature] add AODnet and support cityscape_foggy dataset --- .../cityscape_enhancement_with_anno.py | 2 +- ...t.py => cityscape_enhancement_with_txt.py} | 17 +-- configs/edit/aodnet/aodnet.py | 2 +- .../detectors/single_stage_enhance_model.py | 116 ------------------ lqit/edit/datasets/__init__.py | 3 +- lqit/edit/datasets/cityscape_foggy_dataset.py | 100 +++++++++++++++ .../models/editors/aodnet/aodnet_generator.py | 7 +- 7 files changed, 116 insertions(+), 131 deletions(-) rename configs/edit/_base_/datasets/{cityscape_enhancet_with_txt.py => cityscape_enhancement_with_txt.py} (85%) delete mode 100644 lqit/detection/models/detectors/single_stage_enhance_model.py create mode 100644 lqit/edit/datasets/cityscape_foggy_dataset.py diff --git a/configs/edit/_base_/datasets/cityscape_enhancement_with_anno.py b/configs/edit/_base_/datasets/cityscape_enhancement_with_anno.py index f5e2e2f..48ca2f7 100644 --- a/configs/edit/_base_/datasets/cityscape_enhancement_with_anno.py +++ b/configs/edit/_base_/datasets/cityscape_enhancement_with_anno.py @@ -1,6 +1,6 @@ # dataset settings dataset_type = 'mmdet.CityscapesDataset' -data_root = '/home/tju531/hwr/Datasets/' +data_root = 'data/Datasets/' file_client_args = dict(backend='disk') diff --git a/configs/edit/_base_/datasets/cityscape_enhancet_with_txt.py b/configs/edit/_base_/datasets/cityscape_enhancement_with_txt.py similarity index 85% rename from configs/edit/_base_/datasets/cityscape_enhancet_with_txt.py rename to configs/edit/_base_/datasets/cityscape_enhancement_with_txt.py index 74581a7..fdab24c 100644 --- a/configs/edit/_base_/datasets/cityscape_enhancet_with_txt.py +++ b/configs/edit/_base_/datasets/cityscape_enhancement_with_txt.py @@ -1,6 +1,6 @@ # dataset settings -dataset_type = 'BasicImageDataset' -data_root = '/home/tju531/hwr/Datasets/' +dataset_type = 'CityscapeFoggyImageDataset' +data_root = 'data/Datasets/' file_client_args = dict(backend='disk') @@ -41,11 +41,13 @@ metainfo=dict( dataset_type='cityscape_enhancement', task_name='enhancement'), ann_file='cityscape_foggy/train/train.txt', - data_prefix=dict(img='cityscape_foggy/train/', gt_img_path='cityscape/train/'), + data_prefix=dict(img='cityscape_foggy/train/', gt_img='cityscape/train/'), search_key='img', img_suffix=dict(img='png', gt_img='png'), file_client_args=file_client_args, - pipeline=train_pipeline)) + pipeline=train_pipeline, + split_str='_foggy' + )) val_dataloader = dict( batch_size=1, num_workers=2, @@ -56,15 +58,16 @@ type=dataset_type, data_root=data_root, test_mode=True, - indices=100, metainfo=dict( dataset_type='cityscape_enhancement', task_name='enhancement'), ann_file='cityscape_foggy/test/test.txt', - data_prefix=dict(img='cityscape_foggy/test/', gt_img_path='cityscape/test/'), + data_prefix=dict(img='cityscape_foggy/test/', gt_img='cityscape/test/'), search_key='img', img_suffix=dict(img='png', gt_img='png'), file_client_args=file_client_args, - pipeline=test_pipeline)) + 
pipeline=test_pipeline, + split_str='_foggy' + )) test_dataloader = val_dataloader val_evaluator = [ diff --git a/configs/edit/aodnet/aodnet.py b/configs/edit/aodnet/aodnet.py index c0b8eb8..3ed4ac0 100644 --- a/configs/edit/aodnet/aodnet.py +++ b/configs/edit/aodnet/aodnet.py @@ -15,7 +15,7 @@ generator=dict( _scope_='lqit', type='AODNetGenerator', - aodnet=dict(type='AODNet'), + model=dict(type='AODNet'), pixel_loss=dict(type='MSELoss', loss_weight=1.0) )) diff --git a/lqit/detection/models/detectors/single_stage_enhance_model.py b/lqit/detection/models/detectors/single_stage_enhance_model.py deleted file mode 100644 index 4c92ee0..0000000 --- a/lqit/detection/models/detectors/single_stage_enhance_model.py +++ /dev/null @@ -1,116 +0,0 @@ -from typing import Optional - -from mmdet.models import SingleStageDetector -from mmdet.registry import MODELS -from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig -from torch import Tensor - -from lqit.common.structures import SampleList -from lqit.edit.models import add_pixel_pred_to_datasample - - -@MODELS.register_module() -class SingleStageWithEnhanceModel(SingleStageDetector): - """Base class for one-stage detectors with enhance model. - - """ - - def __init__(self, - backbone: ConfigType, - enhance_model: OptConfigType = None, - neck: OptConfigType = None, - bbox_head: OptConfigType = None, - vis_enhance: Optional[bool] = False, - train_cfg: OptConfigType = None, - test_cfg: OptConfigType = None, - data_preprocessor: OptConfigType = None, - init_cfg: OptMultiConfig = None) -> None: - super().__init__( - backbone=backbone, - neck=neck, - bbox_head=bbox_head, - train_cfg=train_cfg, - test_cfg=test_cfg, - data_preprocessor=data_preprocessor, - init_cfg=init_cfg) - - if enhance_model is not None: - self.enhance_model = MODELS.build(enhance_model) - self.vis_enhance = vis_enhance - - @property - def with_enhance_model(self) -> bool: - """bool: whether the detector has a RoI head""" - return hasattr(self, 'enhance_model') and self.enhance_model is not None - - - def loss(self, batch_inputs: Tensor, - batch_data_samples: SampleList) -> dict: - """Calculate losses from a batch of inputs and data samples. - - Args: - batch_inputs (Tensor): Input images of shape (N, C, H, W). - These should usually be mean centered and std scaled. - batch_data_samples (List[:obj:`DetDataSample`]): The batch - data samples. It usually includes information such - as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. - - Returns: - dict: A dictionary of loss components - """ - losses = dict() - if self.with_enhance_model: - enhance_loss = self.enhance_model.loss(batch_inputs, batch_data_samples) - # avoid loss override - assert not set(enhance_loss.keys()) & set(losses.keys()) - losses.update(enhance_loss) - - x = self.extract_feat(batch_inputs) - - det_losses = self.bbox_head.loss(x, batch_data_samples) - losses.update(det_losses) - return losses - - def predict(self, - batch_inputs: Tensor, - batch_data_samples: SampleList, - rescale: bool = True) -> SampleList: - """Predict results from a batch of inputs and data samples with post- - processing. - - Args: - batch_inputs (Tensor): Inputs with shape (N, C, H, W). - batch_data_samples (List[:obj:`DetDataSample`]): The Data - Samples. It usually includes information such as - `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. - rescale (bool): Whether to rescale the results. - Defaults to True. - - Returns: - list[:obj:`DataSample`]: Return the detection results of the - input images. 
The returns value is DetDataSample, - which usually contain 'pred_instances'. And the - ``pred_instances`` usually contains following keys. - - - scores (Tensor): Classification scores, has a shape - (num_instance, ) - - labels (Tensor): Labels of bboxes, has a shape - (num_instances, ). - - bboxes (Tensor): Has a shape (num_instances, 4), - the last dimension 4 arrange as (x1, y1, x2, y2). - - masks (Tensor): Has a shape (num_instances, H, W). - """ - - x = self.extract_feat(batch_inputs) - results_list = self.bbox_head.predict( - x, batch_data_samples, rescale=rescale) - - if self.vis_enhance and self.with_enhance_model: - enhance_list = self.enhance_model.predict( - x, batch_data_samples, rescale=rescale) - batch_data_samples = add_pixel_pred_to_datasample( - data_samples=batch_data_samples, pixel_list=enhance_list) - - batch_data_samples = self.add_pred_to_datasample( - batch_data_samples, results_list) - return batch_data_samples diff --git a/lqit/edit/datasets/__init__.py b/lqit/edit/datasets/__init__.py index 156046e..f2878d7 100644 --- a/lqit/edit/datasets/__init__.py +++ b/lqit/edit/datasets/__init__.py @@ -1,3 +1,4 @@ from .basic_image_dataset import BasicImageDataset +from .cityscape_foggy_dataset import CityscapeFoggyImageDataset -__all__ = ['BasicImageDataset'] +__all__ = ['BasicImageDataset', 'CityscapeFoggyImageDataset'] diff --git a/lqit/edit/datasets/cityscape_foggy_dataset.py b/lqit/edit/datasets/cityscape_foggy_dataset.py new file mode 100644 index 0000000..9a56e01 --- /dev/null +++ b/lqit/edit/datasets/cityscape_foggy_dataset.py @@ -0,0 +1,100 @@ +# Modified from https://github.com/open-mmlab/mmediting/tree/1.x/ +import warnings +import os.path as osp +from typing import Any, Callable, List, Optional, Union + +from .basic_image_dataset import BasicImageDataset + +from lqit.registry import DATASETS + + +@DATASETS.register_module() +class CityscapeFoggyImageDataset(BasicImageDataset): + """CityscapeFoggyImageDataset for pixel-level vision tasks that have aligned gts, + such as image dehaze using cityscape and cityscape foggy datasets. + + Args: + ann_file (str): Annotation file path. Defaults to ''. + metainfo (dict, optional): Meta information for dataset, such as class + information. Defaults to None. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to None. + data_prefix (dict, optional): Prefix for data. Defaults to + dict(img=''). + mapping_table (dict): Mapping table for data. + Defaults to dict(). + pipeline (list, optional): Processing pipeline. Defaults to []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Defaults to False. + search_key (str): The key used for searching the folder to get + data_list. Defaults to 'gt'. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmengine.fileio.FileClient` for details. + Defaults to dict(backend='disk'). + img_suffix (str or dict[str]): Image suffix that we are interested in. + Defaults to jpg. + recursive (bool): If set to True, recursively scan the + directory. Defaults to False. + split_str (str): split string that used to split image name to gt image name. + Defaults to '_foggy'. 
+ """ + + def __init__(self, + ann_file: str = '', + metainfo: Optional[dict] = None, + data_root: Optional[str] = None, + data_prefix: dict = dict(img=''), + mapping_table: dict = dict(), + pipeline: List[Union[dict, Callable]] = [], + test_mode: bool = False, + search_key: Optional[str] = None, + file_client_args: dict = dict(backend='disk'), + img_suffix: Union[str, dict] = 'jpg', + recursive: bool = False, + split_str: str = '_foggy', + **kwards): + + self.split_str = split_str + + super().__init__( + ann_file=ann_file, + metainfo=metainfo, + data_root=data_root, + data_prefix=data_prefix, + mapping_table=mapping_table, + pipeline=pipeline, + test_mode=test_mode, + search_key=search_key, + file_client_args=file_client_args, + img_suffix=img_suffix, + recursive=recursive, + **kwards) + + + def load_data_list(self) -> List[dict]: + """Load data list from folder or annotation file. + + Returns: + list[dict]: A list of annotation. + """ + + img_ids = self._get_img_list() + + data_list = [] + # deal with img and gt img path + for img_id in img_ids: + data = dict(key=img_id) + data['img_id'] = img_id + for key in self.data_prefix: + img_id = self.mapping_table[key].format(img_id) + + if key == 'gt_img': + img_id = img_id.split(self.split_str)[0] + + path = osp.join(self.data_prefix[key], + f'{img_id}.{self.img_suffix[key]}') + data[f'{key}_path'] = path + data_list.append(data) + return data_list + + diff --git a/lqit/edit/models/editors/aodnet/aodnet_generator.py b/lqit/edit/models/editors/aodnet/aodnet_generator.py index a7bf7d0..7eb023e 100644 --- a/lqit/edit/models/editors/aodnet/aodnet_generator.py +++ b/lqit/edit/models/editors/aodnet/aodnet_generator.py @@ -10,18 +10,15 @@ class AODNetGenerator(BaseGenerator): def __init__(self, - aodnet: ConfigType, + model: ConfigType, pixel_loss: ConfigType = dict( type='MSELoss', loss_weight=1.0), init_cfg: OptMultiConfig = None) -> None: super().__init__( + model=model, pixel_loss=pixel_loss, init_cfg=init_cfg) - # build network - self.model = MODELS.build(aodnet) - - def forward(self, x): """Forward function. 
From 3ede0b36dc62f57d9522931a96de3df9f7bd5f77 Mon Sep 17 00:00:00 2001 From: hewanru-bit <2558905977@qq.com> Date: Mon, 7 Nov 2022 19:28:17 +0800 Subject: [PATCH 06/10] [Feature] Add AODNet and support cityscape foggy dataset --- configs/detection/_base_/datasets/rtts_coco.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/configs/detection/_base_/datasets/rtts_coco.py b/configs/detection/_base_/datasets/rtts_coco.py index e6eab80..be1d319 100644 --- a/configs/detection/_base_/datasets/rtts_coco.py +++ b/configs/detection/_base_/datasets/rtts_coco.py @@ -1,18 +1,18 @@ # dataset settings dataset_type = 'RTTSCocoDataset' -data_root = '/home/tju531/hwr/Datasets/RESIDE/' +data_root = 'data/RESIDE/' file_client_args = dict(backend='disk') train_pipeline = [ dict(type='LoadImageFromFile', file_client_args=file_client_args), dict(type='LoadAnnotations', with_bbox=True), - dict(type='Resize', scale=(256, 256), keep_ratio=True), + dict(type='Resize', scale=(1333, 800), keep_ratio=True), dict(type='RandomFlip', prob=0.5), dict(type='PackDetInputs', ) ] test_pipeline = [ dict(type='LoadImageFromFile', file_client_args=file_client_args), - dict(type='Resize', scale=(256, 256), keep_ratio=True), + dict(type='Resize', scale=(1333, 800), keep_ratio=True), # avoid bboxes being resized dict(type='LoadAnnotations', with_bbox=True), dict( From 01b950cefe78e21d4fb22ee1b439309d93387308 Mon Sep 17 00:00:00 2001 From: hewanru-bit <2558905977@qq.com> Date: Mon, 7 Nov 2022 20:47:35 +0800 Subject: [PATCH 07/10] [Refactor] Simplify AODNet and clean up dataset configs --- configs/detection/edffnet/edffnet.py | 4 +- ...t_with_txt.py => cityscape_enhancement.py} | 12 +- .../cityscape_enhancement_with_anno.py | 75 ----------- configs/edit/aodnet/aodnet.py | 22 +-- lqit/common/__init__.py | 4 +- lqit/common/dataset_wrappers.py | 114 ---------------- lqit/edit/datasets/cityscape_foggy_dataset.py | 62 ++++----- lqit/edit/models/editors/__init__.py | 2 +- lqit/edit/models/editors/aodnet/__init__.py | 1 - lqit/edit/models/editors/aodnet/aodnet.py | 126 +++++------------- .../models/editors/aodnet/aodnet_generator.py | 6 +----- 11 files changed, 82 insertions(+), 346 deletions(-) rename configs/edit/_base_/datasets/{cityscape_enhancement_with_txt.py => cityscape_enhancement.py} (90%) delete mode 100644 configs/edit/_base_/datasets/cityscape_enhancement_with_anno.py diff --git a/configs/detection/edffnet/edffnet.py b/configs/detection/edffnet/edffnet.py index f648515..c2dc359 100644 --- a/configs/detection/edffnet/edffnet.py +++ b/configs/detection/edffnet/edffnet.py @@ -34,7 +34,7 @@ src_key='img', dst_key='gt_edge', transforms=[ - dict(type='Resize', scale=(256, 256), keep_ratio=True), + dict(type='Resize', scale=(1333, 800), keep_ratio=True), dict(type='RandomFlip', prob=0.5) ]), dict(type='lqit.PackInputs', ) @@ -56,4 +56,4 @@ optim_wrapper = dict( type='OptimWrapper', - optimizer=dict(type='SGD', lr=0.0025, momentum=0.9, weight_decay=0.0001)) \ No newline at end of file + optimizer=dict(type='SGD', lr=0.0025, momentum=0.9, weight_decay=0.0001)) diff --git a/configs/edit/_base_/datasets/cityscape_enhancement_with_txt.py b/configs/edit/_base_/datasets/cityscape_enhancement.py similarity index 90% rename from configs/edit/_base_/datasets/cityscape_enhancement_with_txt.py rename to configs/edit/_base_/datasets/cityscape_enhancement.py index fdab24c..b623cf0 100644 --- a/configs/edit/_base_/datasets/cityscape_enhancement_with_txt.py +++ b/configs/edit/_base_/datasets/cityscape_enhancement.py @@ -41,13 +41,13 @@ metainfo=dict(
dataset_type='cityscape_enhancement', task_name='enhancement'), ann_file='cityscape_foggy/train/train.txt', - data_prefix=dict(img='cityscape_foggy/train/', gt_img='cityscape/train/'), + data_prefix=dict( + img='cityscape_foggy/train/', gt_img='cityscape/train/'), search_key='img', img_suffix=dict(img='png', gt_img='png'), file_client_args=file_client_args, pipeline=train_pipeline, - split_str='_foggy' - )) + split_str='_foggy')) val_dataloader = dict( batch_size=1, num_workers=2, @@ -61,13 +61,13 @@ metainfo=dict( dataset_type='cityscape_enhancement', task_name='enhancement'), ann_file='cityscape_foggy/test/test.txt', - data_prefix=dict(img='cityscape_foggy/test/', gt_img='cityscape/test/'), + data_prefix=dict( + img='cityscape_foggy/test/', gt_img='cityscape/test/'), search_key='img', img_suffix=dict(img='png', gt_img='png'), file_client_args=file_client_args, pipeline=test_pipeline, - split_str='_foggy' - )) + split_str='_foggy')) test_dataloader = val_dataloader val_evaluator = [ diff --git a/configs/edit/_base_/datasets/cityscape_enhancement_with_anno.py b/configs/edit/_base_/datasets/cityscape_enhancement_with_anno.py deleted file mode 100644 index 48ca2f7..0000000 --- a/configs/edit/_base_/datasets/cityscape_enhancement_with_anno.py +++ /dev/null @@ -1,75 +0,0 @@ -# dataset settings -dataset_type = 'mmdet.CityscapesDataset' -data_root = 'data/Datasets/' - -file_client_args = dict(backend='disk') - -train_pipeline = [ - dict(type='LoadImageFromFile', file_client_args=file_client_args), - dict(type='LoadGTImageFromFile', file_client_args=file_client_args), - dict( - type='TransBroadcaster', - src_key='img', - dst_key='gt_img', - transforms=[ - dict(type='Resize', scale=(512, 512), keep_ratio=True), - dict(type='RandomFlip', prob=0.5), - ]), - dict(type='PackInputs') -] -test_pipeline = [ - dict(type='LoadImageFromFile', file_client_args=file_client_args), - dict(type='LoadGTImageFromFile', file_client_args=file_client_args), - dict( - type='TransBroadcaster', - src_key='img', - dst_key='gt_img', - transforms=[dict(type='Resize', scale=(512, 512), keep_ratio=True)]), - dict( - type='PackInputs', - meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', - 'scale_factor')) -] -train_dataloader = dict( - batch_size=2, - num_workers=2, - persistent_workers=True, - sampler=dict(type='DefaultSampler', shuffle=True), - dataset=dict( - type='lqit.DatasetWithClearImageWrapper', - dataset=dict( - type=dataset_type, - data_root=data_root, - ann_file='cityscape_foggy/annotations_json/instancesonly_filtered_gtFine_train.json', - data_prefix=dict(img='cityscape_foggy/train/', gt_img_path='cityscape/train/'), - filter_cfg=dict(filter_empty_gt=True, min_size=32), - pipeline=train_pipeline), - suffix='png' - )) - -val_dataloader = dict( - batch_size=1, - num_workers=2, - persistent_workers=True, - drop_last=False, - sampler=dict(type='DefaultSampler', shuffle=False), - dataset=dict( - type='lqit.DatasetWithClearImageWrapper', - dataset=dict( - type=dataset_type, - data_root=data_root, - test_mode=True, - indices=100, - ann_file='cityscape_foggy/annotations_json/instancesonly_filtered_gtFine_test.json', - data_prefix=dict(img='cityscape_foggy/test/', gt_img_path='cityscape/test/'), - filter_cfg=dict(filter_empty_gt=True, min_size=32), - pipeline=test_pipeline), - suffix='png' - )) - -test_dataloader = val_dataloader - -val_evaluator = [ - dict(type='MSE', gt_key='img', pred_key='pred_img'), -] -test_evaluator = val_evaluator diff --git a/configs/edit/aodnet/aodnet.py b/configs/edit/aodnet/aodnet.py 
index 3ed4ac0..f2ddc76 100644 --- a/configs/edit/aodnet/aodnet.py +++ b/configs/edit/aodnet/aodnet.py @@ -1,14 +1,13 @@ _base_ = [ - '../_base_/datasets/cityscape_enhancement_with_txt.py', - '../_base_/schedules/schedule_1x.py', - '../_base_/default_runtime.py' + '../_base_/datasets/cityscape_enhancement.py', + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' ] model = dict( type='lqit.BaseEditModel', data_preprocessor=dict( type='lqit.EditDataPreprocessor', - mean=[0.0, 0.0, 0.0], - std=[255.0, 255.0, 255.0], + mean=[0.5, 0.5, 0.5], + std=[0.5, 0.5, 0.5], bgr_to_rgb=True, pad_size_divisor=32, gt_name='img'), @@ -16,14 +15,15 @@ _scope_='lqit', type='AODNetGenerator', model=dict(type='AODNet'), - pixel_loss=dict(type='MSELoss', loss_weight=1.0) - )) - + pixel_loss=dict(type='MSELoss', loss_weight=1.0))) train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=10, val_interval=1) param_scheduler = [ dict( - type='LinearLR', start_factor=0.0001, by_epoch=False, begin=0, + type='LinearLR', + start_factor=0.0001, + by_epoch=False, + begin=0, end=1000), dict( type='MultiStepLR', @@ -31,9 +31,9 @@ end=10, by_epoch=True, milestones=[6, 9], - gamma=0.1) + gamma=0.5) ] optim_wrapper = dict( type='OptimWrapper', - optimizer=dict(type='SGD', lr=0.001, momentum=0.9, weight_decay=0.0001)) \ No newline at end of file + optimizer=dict(type='Adam', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.0001)) diff --git a/lqit/common/__init__.py b/lqit/common/__init__.py index c4a9ae8..7639b50 100644 --- a/lqit/common/__init__.py +++ b/lqit/common/__init__.py @@ -1,6 +1,6 @@ from .data_preprocessor import * # noqa: F401,F403 -from .dataset_wrappers import DatasetWithGTImageWrapper, DatasetWithClearImageWrapper +from .dataset_wrappers import DatasetWithGTImageWrapper from .structures import * # noqa: F401,F403 from .transforms import * # noqa: F401,F403 -__all__ = ['DatasetWithGTImageWrapper', 'DatasetWithClearImageWrapper'] +__all__ = ['DatasetWithGTImageWrapper'] diff --git a/lqit/common/dataset_wrappers.py b/lqit/common/dataset_wrappers.py index f14f06f..2bd276d 100644 --- a/lqit/common/dataset_wrappers.py +++ b/lqit/common/dataset_wrappers.py @@ -118,117 +118,3 @@ def parse_gt_img_info(self, data_info: dict) -> Union[dict, List[dict]]: f'.{self.suffix}' data_info['gt_img_path'] = osp.join(gt_img_root, img_name) return data_info - - - -@DATASETS.register_module() -class DatasetWithClearImageWrapper: - """Dataset wrapper for image dehazing task. Add `gt_image_path` simultaneously. - - Args: - dataset (BaseDataset or dict): The dataset - suffix (str): gt_image suffix. Defaults to 'jpg'. - lazy_init (bool, optional): whether to load annotation during - instantiation. Defaults to False - """ - - def __init__(self, - dataset: Union[BaseDataset, dict], - suffix: str = 'jpg', - lazy_init: bool = False) -> None: - self.suffix = suffix - if isinstance(dataset, dict): - self.dataset = DATASETS.build(dataset) - elif isinstance(dataset, BaseDataset): - self.dataset = dataset - else: - raise TypeError( - 'elements in datasets sequence should be config or ' - f'`BaseDataset` instance, but got {type(dataset)}') - self._metainfo = self.dataset.metainfo - - self._fully_initialized = False - if not lazy_init: - self.full_init() - - @property - def metainfo(self) -> dict: - """Get the meta information of the repeated dataset. - - Returns: - dict: The meta information of repeated dataset.
- """ - return self._metainfo - - def full_init(self): - self.dataset.full_init() - - def get_data_info(self, idx: int) -> dict: - return self.dataset.get_data_info(idx) - - def prepare_data(self, idx) -> Any: - """Get data processed by ``self.pipeline``. - - Args: - idx (int): The index of ``data_info``. - - Returns: - Any: Depends on ``self.pipeline``. - """ - data_info = self.get_data_info(idx) - data_info = self.parse_gt_img_info(data_info) - return self.dataset.pipeline(data_info) - - def __getitem__(self, idx): - if not self.dataset._fully_initialized: - warnings.warn( - 'Please call `full_init()` method manually to accelerate ' - 'the speed.') - self.dataset.full_init() - - if self.dataset.test_mode: - data = self.prepare_data(idx) - if data is None: - raise Exception('Test time pipline should not get `None` ' - 'data_sample') - return data - - for _ in range(self.dataset.max_refetch + 1): - data = self.prepare_data(idx) - # Broken images or random augmentations may cause the returned data - # to be None - if data is None: - idx = self.dataset._rand_another() - continue - return data - - raise Exception(f'Cannot find valid image after {self.max_refetch}! ' - 'Please check your image path and pipeline') - - def __len__(self): - return len(self.dataset) - - def parse_gt_img_info(self, data_info: dict) -> Union[dict, List[dict]]: - """Parse raw annotation to target format. - - Args: - raw_data_info (dict): Raw data information load from ``ann_file`` - - Returns: - Union[dict, List[dict]]: Parsed annotation. - """ - - gt_img_root = self.dataset.data_prefix.get('gt_img_path', None) - - if gt_img_root is None: - warnings.warn( - 'Cannot get gt_img_root, please set `gt_img_path` in ' - '`dataset.data_prefix`') - data_info['gt_img_path'] = data_info['img_path'] - else: - img_name = \ - f"{osp.split(data_info['img_path'])[0].split('/')[-1]}" + '/'\ - f"{osp.split(data_info['img_path'])[-1].split('_foggy_')[0]}" \ - f'.{self.suffix}' - data_info['gt_img_path'] = osp.join(gt_img_root, img_name) - return data_info diff --git a/lqit/edit/datasets/cityscape_foggy_dataset.py b/lqit/edit/datasets/cityscape_foggy_dataset.py index 9a56e01..df80fa9 100644 --- a/lqit/edit/datasets/cityscape_foggy_dataset.py +++ b/lqit/edit/datasets/cityscape_foggy_dataset.py @@ -1,17 +1,15 @@ # Modified from https://github.com/open-mmlab/mmediting/tree/1.x/ -import warnings import os.path as osp -from typing import Any, Callable, List, Optional, Union - -from .basic_image_dataset import BasicImageDataset +from typing import Callable, List, Optional, Union from lqit.registry import DATASETS +from .basic_image_dataset import BasicImageDataset @DATASETS.register_module() class CityscapeFoggyImageDataset(BasicImageDataset): - """CityscapeFoggyImageDataset for pixel-level vision tasks that have aligned gts, - such as image dehaze using cityscape and cityscape foggy datasets. + """CityscapeFoggyImageDataset for pixel-level vision tasks that have + aligned gts. Args: ann_file (str): Annotation file path. Defaults to ''. @@ -35,7 +33,7 @@ class CityscapeFoggyImageDataset(BasicImageDataset): Defaults to jpg. recursive (bool): If set to True, recursively scan the directory. Defaults to False. - split_str (str): split string that used to split image name to gt image name. + split_str (str): split image name to gt image name. Defaults to '_foggy'. """ @@ -70,31 +68,27 @@ def __init__(self, recursive=recursive, **kwards) - def load_data_list(self) -> List[dict]: - """Load data list from folder or annotation file. 
- - Returns: - list[dict]: A list of annotation. - """ - - img_ids = self._get_img_list() - - data_list = [] - # deal with img and gt img path - for img_id in img_ids: - data = dict(key=img_id) - data['img_id'] = img_id - for key in self.data_prefix: - img_id = self.mapping_table[key].format(img_id) - - if key == 'gt_img': - img_id = img_id.split(self.split_str)[0] - - path = osp.join(self.data_prefix[key], - f'{img_id}.{self.img_suffix[key]}') - data[f'{key}_path'] = path - data_list.append(data) - return data_list - - + """Load data list from folder or annotation file. + + Returns: + list[dict]: A list of annotation. + """ + img_ids = self._get_img_list() + + data_list = [] + # deal with img and gt img path + for img_id in img_ids: + data = dict(key=img_id) + data['img_id'] = img_id + for key in self.data_prefix: + img_id = self.mapping_table[key].format(img_id) + + if key == 'gt_img': + img_id = img_id.split(self.split_str)[0] + + path = osp.join(self.data_prefix[key], + f'{img_id}.{self.img_suffix[key]}') + data[f'{key}_path'] = path + data_list.append(data) + return data_list diff --git a/lqit/edit/models/editors/__init__.py b/lqit/edit/models/editors/__init__.py index 307d93e..aa93705 100644 --- a/lqit/edit/models/editors/__init__.py +++ b/lqit/edit/models/editors/__init__.py @@ -1,3 +1,3 @@ +from .aodnet import * # noqa: F401,F403 from .unet import * # noqa: F401,F403 from .zero_dce import * # noqa: F401,F403 -from .aodnet import * # noqa: F401,F403 \ No newline at end of file diff --git a/lqit/edit/models/editors/aodnet/__init__.py b/lqit/edit/models/editors/aodnet/__init__.py index 1cd3826..b5c2b65 100644 --- a/lqit/edit/models/editors/aodnet/__init__.py +++ b/lqit/edit/models/editors/aodnet/__init__.py @@ -1,5 +1,4 @@ from .aodnet import AODNet from .aodnet_generator import AODNetGenerator - __all__ = ['AODNet', 'AODNetGenerator'] diff --git a/lqit/edit/models/editors/aodnet/aodnet.py b/lqit/edit/models/editors/aodnet/aodnet.py index 86a731d..aa40a34 100644 --- a/lqit/edit/models/editors/aodnet/aodnet.py +++ b/lqit/edit/models/editors/aodnet/aodnet.py @@ -1,102 +1,38 @@ import torch import torch.nn as nn -import warnings -from mmcv.cnn import ConvModule -from mmcv.cnn.bricks.activation import build_activation_layer -from mmengine.model import BaseModule +import torch.nn.functional as F from lqit.registry import MODELS @MODELS.register_module() -class AODNet(BaseModule): - """AOD-Net: All-in-One Dehazing Network.""" - - def __init__(self, - in_channels=(1, 1, 2, 2, 4), - base_channels=3, - out_channels=(3, 3, 3, 3, 3), - num_stages=5, - kernel_size=(1, 3, 5, 7, 3), - padding=(0, 1, 2, 3, 1), - act_cfg=dict(type='ReLU'), - plugins=None, - pretrained=None, - norm_eval=False, - init_cfg=[ - dict(type='Normal', layer='Conv2d', mean=0, std=0.02), - dict( - type='Normal', - layer='BatchNorm2d', - mean=1.0, - std=0.02, - bias=0), - ]): - super().__init__(init_cfg=init_cfg) - - self.pretrained = pretrained - assert not (init_cfg and pretrained), \ - 'init_cfg and pretrained cannot be setting at the same time' - if isinstance(pretrained, str): - warnings.warn('DeprecationWarning: pretrained is a deprecated, ' - 'please use "init_cfg" instead') - self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) - elif pretrained is None: - if init_cfg is None: - self.init_cfg = [ - dict(type='Kaiming', layer='Conv2d'), - dict( - type='Constant', - val=1, - layer=['_BatchNorm', 'GroupNorm']) - ] - else: - raise TypeError('pretrained must be a str or None') - assert plugins is None, 'Not 
implemented yet.' - - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.num_stages = num_stages - self.base_channels = base_channels - self.padding = padding - self.with_activation = act_cfg is not None - self.norm_eval = norm_eval - self.act_cfg = act_cfg - # build activation layer - if self.with_activation: - act_cfg_ = act_cfg.copy() - # nn.Tanh has no 'inplace' argument - if act_cfg_['type'] not in [ - 'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish' - ]: - act_cfg_.setdefault('inplace', True) - self.activate = build_activation_layer(act_cfg_) - - self._init_layer() - - def _init_layer(self): - - self.CONVM = nn.ModuleList() - for i in range(self.num_stages): - conv_act = ConvModule( - in_channels=self.in_channels[i] * self.base_channels, out_channels=self.out_channels[i], - kernel_size=self.kernel_size[i], stride=1, padding=self.padding[i], bias=True, act_cfg=self.act_cfg) - self.CONVM.append(conv_act) - - - def forward(self, inputs): - outs = [] - x1 = inputs - for i in range(self.num_stages): - if i > 1 and i != (self.num_stages - 1): # from i=2 concat - x1 = torch.cat((outs[i - 2], outs[i - 1]), 1) - - if i == self.num_stages - 1: # last concat all - x1 = torch.cat([outs[j] for j in range(len(outs))], 1) - - x1 = self.CONVM[i](x1) - outs.append(x1) - result = self.activate((outs[-1] * inputs) - outs[-1] + 1) - - return result +class AODNet(nn.Module): + + def __init__(self): + super(AODNet, self).__init__() + self.conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1) + self.conv2 = nn.Conv2d( + in_channels=3, out_channels=3, kernel_size=3, padding=1) + self.conv3 = nn.Conv2d( + in_channels=6, out_channels=3, kernel_size=5, padding=2) + self.conv4 = nn.Conv2d( + in_channels=6, out_channels=3, kernel_size=7, padding=3) + self.conv5 = nn.Conv2d( + in_channels=12, out_channels=3, kernel_size=3, padding=1) + self.b = 1 + + def forward(self, x): + x1 = F.relu(self.conv1(x)) + x2 = F.relu(self.conv2(x1)) + cat1 = torch.cat((x1, x2), 1) + x3 = F.relu(self.conv3(cat1)) + cat2 = torch.cat((x2, x3), 1) + x4 = F.relu(self.conv4(cat2)) + cat3 = torch.cat((x1, x2, x3, x4), 1) + k = F.relu(self.conv5(cat3)) + + if k.size() != x.size(): + raise Exception('haze image are different size!') + + output = k * x - k + self.b + return F.relu(output) diff --git a/lqit/edit/models/editors/aodnet/aodnet_generator.py b/lqit/edit/models/editors/aodnet/aodnet_generator.py index 7eb023e..d23b558 100644 --- a/lqit/edit/models/editors/aodnet/aodnet_generator.py +++ b/lqit/edit/models/editors/aodnet/aodnet_generator.py @@ -14,10 +14,7 @@ def __init__(self, pixel_loss: ConfigType = dict( type='MSELoss', loss_weight=1.0), init_cfg: OptMultiConfig = None) -> None: - super().__init__( - model=model, - pixel_loss=pixel_loss, - init_cfg=init_cfg) + super().__init__(model=model, pixel_loss=pixel_loss, init_cfg=init_cfg) def forward(self, x): """Forward function. 
@@ -33,7 +30,6 @@ def forward(self, x): def loss(self, loss_input: BatchPixelData, batch_img_metas: List[dict]): """Calculate the loss based on the outputs of generator.""" batch_outputs = loss_input.output - batch_inputs = loss_input.input batch_gt = loss_input.gt pixel_loss = self.pixel_loss(batch_outputs, batch_gt) From e3b4a60ce632d4de55c8b13e0a6b663d461dc757 Mon Sep 17 00:00:00 2001 From: hewanru-bit <2558905977@qq.com> Date: Tue, 8 Nov 2022 10:04:01 +0800 Subject: [PATCH 08/10] [Refactor] Document AODNet and remove the redundant generator forward --- lqit/edit/datasets/cityscape_foggy_dataset.py | 5 +++-- lqit/edit/models/editors/aodnet/aodnet.py | 2 ++ .../models/editors/aodnet/aodnet_generator.py | 21 ------------------- 3 files changed, 5 insertions(+), 23 deletions(-) diff --git a/lqit/edit/datasets/cityscape_foggy_dataset.py b/lqit/edit/datasets/cityscape_foggy_dataset.py index df80fa9..1227cbf 100644 --- a/lqit/edit/datasets/cityscape_foggy_dataset.py +++ b/lqit/edit/datasets/cityscape_foggy_dataset.py @@ -50,7 +50,7 @@ def __init__(self, img_suffix: Union[str, dict] = 'jpg', recursive: bool = False, split_str: str = '_foggy', - **kwards): + **kwards) -> None: self.split_str = split_str @@ -83,7 +83,8 @@ def load_data_list(self) -> List[dict]: data['img_id'] = img_id for key in self.data_prefix: img_id = self.mapping_table[key].format(img_id) - + # The gt image name and the foggy image name do not match: + # one gt image corresponds to three foggy images. if key == 'gt_img': img_id = img_id.split(self.split_str)[0] diff --git a/lqit/edit/models/editors/aodnet/aodnet.py b/lqit/edit/models/editors/aodnet/aodnet.py index aa40a34..04b2f35 100644 --- a/lqit/edit/models/editors/aodnet/aodnet.py +++ b/lqit/edit/models/editors/aodnet/aodnet.py @@ -7,6 +7,8 @@ @MODELS.register_module() class AODNet(nn.Module): + """AOD-Net: All-in-One Dehazing Network. + https://ieeexplore.ieee.org/document/8237773""" def __init__(self): super(AODNet, self).__init__() diff --git a/lqit/edit/models/editors/aodnet/aodnet_generator.py b/lqit/edit/models/editors/aodnet/aodnet_generator.py index d23b558..6d9373f 100644 --- a/lqit/edit/models/editors/aodnet/aodnet_generator.py +++ b/lqit/edit/models/editors/aodnet/aodnet_generator.py @@ -16,17 +16,6 @@ def __init__(self, init_cfg: OptMultiConfig = None) -> None: super().__init__(model=model, pixel_loss=pixel_loss, init_cfg=init_cfg) - def forward(self, x): - """Forward function. - - Args: - x (Tensor): Input tensor with shape (n, c, h, w). - - Returns: - Tensor: Forward results. - """ - return self.model(x) - def loss(self, loss_input: BatchPixelData, batch_img_metas: List[dict]): """Calculate the loss based on the outputs of generator.""" batch_outputs = loss_input.output @@ -36,13 +25,3 @@ def loss(self, loss_input: BatchPixelData, batch_img_metas: List[dict]): losses = dict(pixel_loss=pixel_loss) return losses - - def post_precess(self, outputs): - # ZeroDCE return enhance loss and curve at the same time.
- assert outputs.dim() in [3, 4] - in_channels = self.model.in_channels - if outputs.dim() == 4: - enhance_img = outputs[:, :in_channels, :, :] - else: - enhance_img = outputs[:in_channels, :, :] - return enhance_img From 771831c02079ba4417edd60dfb3f52f3734a76ad Mon Sep 17 00:00:00 2001 From: hewanru-bit <2558905977@qq.com> Date: Tue, 8 Nov 2022 15:20:27 +0800 Subject: [PATCH 09/10] [Refactor] Replace the AODNet size check with an assert --- lqit/edit/models/editors/aodnet/aodnet.py | 3 +-- lqit/edit/models/editors/aodnet/aodnet_generator.py | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lqit/edit/models/editors/aodnet/aodnet.py b/lqit/edit/models/editors/aodnet/aodnet.py index 04b2f35..dc42378 100644 --- a/lqit/edit/models/editors/aodnet/aodnet.py +++ b/lqit/edit/models/editors/aodnet/aodnet.py @@ -33,8 +33,7 @@ def forward(self, x): cat3 = torch.cat((x1, x2, x3, x4), 1) k = F.relu(self.conv5(cat3)) - if k.size() != x.size(): - raise Exception('haze image are different size!') + assert k.size() == x.size(), 'K map and haze image must have the same size' output = k * x - k + self.b return F.relu(output) diff --git a/lqit/edit/models/editors/aodnet/aodnet_generator.py b/lqit/edit/models/editors/aodnet/aodnet_generator.py index 6d9373f..b106758 100644 --- a/lqit/edit/models/editors/aodnet/aodnet_generator.py +++ b/lqit/edit/models/editors/aodnet/aodnet_generator.py @@ -13,7 +13,8 @@ def __init__(self, model: ConfigType, pixel_loss: ConfigType = dict( type='MSELoss', loss_weight=1.0), - init_cfg: OptMultiConfig = None) -> None: + init_cfg: OptMultiConfig = None, + **kwargs) -> None: super().__init__(model=model, pixel_loss=pixel_loss, init_cfg=init_cfg) def loss(self, loss_input: BatchPixelData, batch_img_metas: List[dict]): From 067372a5a8485e5ab5c72240c1788fa0c402e877 Mon Sep 17 00:00:00 2001 From: hewanru-bit <2558905977@qq.com> Date: Tue, 8 Nov 2022 15:25:44 +0800 Subject: [PATCH 10/10] [Refactor] Delete the unused edffnet_new.py config --- configs/detection/edffnet/edffnet_new.py | 62 ------------------ 1 file changed, 62 deletions(-) delete mode 100644 configs/detection/edffnet/edffnet_new.py diff --git a/configs/detection/edffnet/edffnet_new.py b/configs/detection/edffnet/edffnet_new.py deleted file mode 100644 index d0603b1..0000000 --- a/configs/detection/edffnet/edffnet_new.py +++ /dev/null @@ -1,62 +0,0 @@ -_base_ = '../edffnet/atss_r50_fpn_1x.py' - -model = dict( - type='EDFFNet', - # backbone=dict(norm_eval=False), - neck=dict( - type='DFFPN', - in_channels=[256, 512, 1024, 2048], - out_channels=256, - start_level=1, - add_extra_convs='on_input', - shape_level=2, - num_outs=5), - enhance_head=dict( - type='lqit.EdgeHead', - in_channels=256, - feat_channels=256, - num_convs=5, - loss_enhance=dict(type='mmdet.L1Loss', loss_weight=0.7), - gt_preprocessor=dict( - type='lqit.GTPixelPreprocessor', - mean=[128], - std=[57.12], - pad_size_divisor=32, - element_name='edge')), -) - -# dataset settings -train_pipeline = [ - dict(type='LoadImageFromFile'), - dict(type='LoadAnnotations', with_bbox=True), - dict(type='lqit.GetEdgeGTFromImage', method='scharr'), - dict( - type='lqit.TransBroadcaster', - src_key='img', - dst_key='gt_edge', - transforms=[ - dict(type='Resize', scale=(1333, 800), keep_ratio=True), - dict(type='RandomFlip', prob=0.5), - ]), - dict(type='lqit.PackInputs', ) -] -train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) - -optim_wrapper = dict( - type='OptimWrapper', - optimizer=dict(type='SGD', lr=0.0025, momentum=0.9, weight_decay=0.0001)) -# learning rate -param_scheduler = [ - dict( - type='LinearLR', start_factor=0.001,
by_epoch=False, begin=0, - end=1000), - dict( - type='MultiStepLR', - begin=0, - end=12, - by_epoch=True, - milestones=[8, 11], - gamma=0.1) -] - -# load_from = '/home/test/data2/HWR/mmdet_works/edffnet_50.7.pth'
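The AODNet rewritten in patch 07 implements the K-estimation form of the atmospheric scattering model: five small convolutions with dense concatenations estimate a single map K(x), and the dehazed image is recovered as J(x) = K(x) * I(x) - K(x) + b with b = 1, followed by a final ReLU. A minimal smoke test of the module as it stands at the end of the series (untrained weights and a random input; the import path is an assumption based on the package layout in these patches):

    import torch

    # Assumed import path: lqit/edit/models/editors/aodnet/__init__.py
    # exports AODNet in this series.
    from lqit.edit.models.editors import AODNet

    model = AODNet().eval()
    haze = torch.rand(1, 3, 256, 256)  # fake hazy RGB batch in [0, 1]
    with torch.no_grad():
        dehazed = model(haze)  # J = K * I - K + b, then ReLU
    assert dehazed.shape == haze.shape  # K is estimated per pixel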