Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] add AODNet and support cityscape foggy dataset #15

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
40 changes: 12 additions & 28 deletions configs/detection/edffnet/edffnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

model = dict(
type='EDFFNet',
backbone=dict(norm_eval=True),
backbone=dict(norm_eval=False),
neck=dict(
type='DFFPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_input',
add_extra_convs='on_output',
shape_level=2,
num_outs=5),
enhance_head=dict(
Expand All @@ -19,37 +19,21 @@
loss_enhance=dict(type='mmdet.L1Loss', loss_weight=0.7),
gt_preprocessor=dict(
type='lqit.GTPixelPreprocessor',
mean=[128],
std=[57.12],
mean=[123.675],
std=[58.395],
pad_size_divisor=32,
element_name='edge')))
element_name='edge')),
)

# dataset settings
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='LoadImageFromFile',
file_client_args={{_base_.file_client_args}}),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', prob=0.5),
dict(type='lqit.GetEdgeGTFromImage', method='scharr'),
dict(
type='lqit.TransBroadcaster',
src_key='img',
dst_key='gt_edge',
transforms=[
dict(type='Resize', scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', prob=0.5)
]),
dict(type='lqit.PackInputs', )
dict(type='lqit.PackInputs')
]
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))

param_scheduler = [
dict(
type='LinearLR', start_factor=0.001, by_epoch=False, begin=0,
end=1000),
dict(
type='MultiStepLR',
begin=0,
end=12,
by_epoch=True,
milestones=[8, 11],
gamma=0.1)
]
76 changes: 76 additions & 0 deletions configs/edit/_base_/datasets/cityscape_enhancement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# dataset settings
dataset_type = 'CityscapeFoggyImageDataset'
data_root = 'data/Datasets/'

file_client_args = dict(backend='disk')

train_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='LoadGTImageFromFile', file_client_args=file_client_args),
dict(
type='TransBroadcaster',
src_key='img',
dst_key='gt_img',
transforms=[
dict(type='Resize', scale=(512, 512), keep_ratio=True),
dict(type='RandomFlip', prob=0.5),
]),
dict(type='PackInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='LoadGTImageFromFile', file_client_args=file_client_args),
dict(
type='TransBroadcaster',
src_key='img',
dst_key='gt_img',
transforms=[dict(type='Resize', scale=(512, 512), keep_ratio=True)]),
dict(
type='PackInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor'))
]
train_dataloader = dict(
batch_size=2,
num_workers=2,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
metainfo=dict(
dataset_type='cityscape_enhancement', task_name='enhancement'),
ann_file='cityscape_foggy/train/train.txt',
data_prefix=dict(
img='cityscape_foggy/train/', gt_img='cityscape/train/'),
search_key='img',
img_suffix=dict(img='png', gt_img='png'),
file_client_args=file_client_args,
pipeline=train_pipeline,
split_str='_foggy'))
val_dataloader = dict(
batch_size=1,
num_workers=2,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
test_mode=True,
metainfo=dict(
dataset_type='cityscape_enhancement', task_name='enhancement'),
ann_file='cityscape_foggy/test/test.txt',
data_prefix=dict(
img='cityscape_foggy/test/', gt_img='cityscape/test/'),
search_key='img',
img_suffix=dict(img='png', gt_img='png'),
file_client_args=file_client_args,
pipeline=test_pipeline,
split_str='_foggy'))
test_dataloader = val_dataloader

val_evaluator = [
dict(type='MSE', gt_key='img', pred_key='pred_img'),
]
test_evaluator = val_evaluator
39 changes: 39 additions & 0 deletions configs/edit/aodnet/aodnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
_base_ = [
'../_base_/datasets/cityscape_enhancement.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
type='lqit.BaseEditModel',
data_preprocessor=dict(
type='lqit.EditDataPreprocessor',
mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5],
bgr_to_rgb=True,
pad_size_divisor=32,
gt_name='img'),
generator=dict(
_scope_='lqit',
hewanru-bit marked this conversation as resolved.
Show resolved Hide resolved
type='AODNetGenerator',
model=dict(type='AODNet'),
pixel_loss=dict(type='MSELoss', loss_weight=1.0)))

train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=10, val_interval=1)
param_scheduler = [
dict(
type='LinearLR',
start_factor=0.0001,
by_epoch=False,
begin=0,
end=1000),
dict(
type='MultiStepLR',
begin=0,
end=10,
by_epoch=True,
milestones=[6, 9],
gamma=0.5)
]

optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='Adam', lr=0.0001, momentum=0.9, weight_decay=0.0001))
7 changes: 1 addition & 6 deletions lqit/detection/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
from .class_names import * # noqa: F401,F403
from .rtts import RTTSCocoDataset
from .urpc import URPCCocoDataset, URPCXMLDataset
from .xml_dataset import XMLDatasetWithMetaFile

__all__ = [
'XMLDatasetWithMetaFile', 'URPCCocoDataset', 'URPCXMLDataset',
'RTTSCocoDataset'
]
__all__ = ['URPCCocoDataset', 'URPCXMLDataset', 'RTTSCocoDataset']
4 changes: 2 additions & 2 deletions lqit/detection/models/detectors/edffnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from mmdet.registry import MODELS
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig

from .single_stage_enhance_head import SingleStageWithEnhanceHead
from .single_stage_enhance_head import SingleStageDetector


@MODELS.register_module()
class EDFFNet(SingleStageWithEnhanceHead):
class EDFFNet(SingleStageDetector):

def __init__(self,
backbone: ConfigType,
Expand Down
1 change: 1 addition & 0 deletions lqit/detection/models/necks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dffpn import DFFPN

__all__ = ['DFFPN']
3 changes: 2 additions & 1 deletion lqit/edit/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .basic_image_dataset import BasicImageDataset
from .cityscape_foggy_dataset import CityscapeFoggyImageDataset

__all__ = ['BasicImageDataset']
__all__ = ['BasicImageDataset', 'CityscapeFoggyImageDataset']
95 changes: 95 additions & 0 deletions lqit/edit/datasets/cityscape_foggy_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Modified from https://github.com/open-mmlab/mmediting/tree/1.x/
import os.path as osp
from typing import Callable, List, Optional, Union

from lqit.registry import DATASETS
from .basic_image_dataset import BasicImageDataset


@DATASETS.register_module()
class CityscapeFoggyImageDataset(BasicImageDataset):
"""CityscapeFoggyImageDataset for pixel-level vision tasks that have
aligned gts.

Args:
ann_file (str): Annotation file path. Defaults to ''.
metainfo (dict, optional): Meta information for dataset, such as class
information. Defaults to None.
data_root (str, optional): The root directory for ``data_prefix`` and
``ann_file``. Defaults to None.
data_prefix (dict, optional): Prefix for data. Defaults to
dict(img='').
mapping_table (dict): Mapping table for data.
Defaults to dict().
pipeline (list, optional): Processing pipeline. Defaults to [].
test_mode (bool, optional): ``test_mode=True`` means in test phase.
Defaults to False.
search_key (str): The key used for searching the folder to get
data_list. Defaults to 'gt'.
file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmengine.fileio.FileClient` for details.
Defaults to dict(backend='disk').
img_suffix (str or dict[str]): Image suffix that we are interested in.
Defaults to jpg.
recursive (bool): If set to True, recursively scan the
directory. Defaults to False.
split_str (str): split image name to gt image name.
Defaults to '_foggy'.
"""

def __init__(self,
ann_file: str = '',
metainfo: Optional[dict] = None,
data_root: Optional[str] = None,
data_prefix: dict = dict(img=''),
mapping_table: dict = dict(),
pipeline: List[Union[dict, Callable]] = [],
test_mode: bool = False,
search_key: Optional[str] = None,
file_client_args: dict = dict(backend='disk'),
img_suffix: Union[str, dict] = 'jpg',
recursive: bool = False,
split_str: str = '_foggy',
**kwards) -> None:

self.split_str = split_str

super().__init__(
ann_file=ann_file,
metainfo=metainfo,
data_root=data_root,
data_prefix=data_prefix,
mapping_table=mapping_table,
pipeline=pipeline,
test_mode=test_mode,
search_key=search_key,
file_client_args=file_client_args,
img_suffix=img_suffix,
recursive=recursive,
**kwards)

def load_data_list(self) -> List[dict]:
"""Load data list from folder or annotation file.

Returns:
list[dict]: A list of annotation.
"""
img_ids = self._get_img_list()

data_list = []
# deal with img and gt img path
for img_id in img_ids:
data = dict(key=img_id)
data['img_id'] = img_id
for key in self.data_prefix:
img_id = self.mapping_table[key].format(img_id)
# The gt img name and img name do not match.
# one gt img corresponds to three imgs
if key == 'gt_img':
img_id = img_id.split(self.split_str)[0]

path = osp.join(self.data_prefix[key],
f'{img_id}.{self.img_suffix[key]}')
data[f'{key}_path'] = path
data_list.append(data)
return data_list
1 change: 1 addition & 0 deletions lqit/edit/models/editors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .aodnet import * # noqa: F401,F403
from .unet import * # noqa: F401,F403
from .zero_dce import * # noqa: F401,F403
4 changes: 4 additions & 0 deletions lqit/edit/models/editors/aodnet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .aodnet import AODNet
from .aodnet_generator import AODNetGenerator

__all__ = ['AODNet', 'AODNetGenerator']
39 changes: 39 additions & 0 deletions lqit/edit/models/editors/aodnet/aodnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from lqit.registry import MODELS


@MODELS.register_module()
class AODNet(nn.Module):
"""AOD-Net: All-in-One Dehazing Network.
https://ieeexplore.ieee.org/document/8237773"""

def __init__(self):
super(AODNet, self).__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1)
self.conv2 = nn.Conv2d(
in_channels=3, out_channels=3, kernel_size=3, padding=1)
self.conv3 = nn.Conv2d(
in_channels=6, out_channels=3, kernel_size=5, padding=2)
self.conv4 = nn.Conv2d(
in_channels=6, out_channels=3, kernel_size=7, padding=3)
self.conv5 = nn.Conv2d(
in_channels=12, out_channels=3, kernel_size=3, padding=1)
self.b = 1

def forward(self, x):
x1 = F.relu(self.conv1(x))
x2 = F.relu(self.conv2(x1))
cat1 = torch.cat((x1, x2), 1)
x3 = F.relu(self.conv3(cat1))
cat2 = torch.cat((x2, x3), 1)
x4 = F.relu(self.conv4(cat2))
cat3 = torch.cat((x1, x2, x3, x4), 1)
k = F.relu(self.conv5(cat3))

assert k.size() == x.size(), 'haze image are different size'

output = k * x - k + self.b
return F.relu(output)
28 changes: 28 additions & 0 deletions lqit/edit/models/editors/aodnet/aodnet_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import List

from lqit.edit.models.base_models import BaseGenerator
from lqit.edit.structures import BatchPixelData
from lqit.registry import MODELS
from lqit.utils.typing import ConfigType, OptMultiConfig


@MODELS.register_module()
class AODNetGenerator(BaseGenerator):

def __init__(self,
model: ConfigType,
pixel_loss: ConfigType = dict(
type='MSELoss', loss_weight=1.0),
init_cfg: OptMultiConfig = None,
**kwargs) -> None:
super().__init__(model=model, pixel_loss=pixel_loss, init_cfg=init_cfg)

def loss(self, loss_input: BatchPixelData, batch_img_metas: List[dict]):
"""Calculate the loss based on the outputs of generator."""
batch_outputs = loss_input.output
batch_gt = loss_input.gt

pixel_loss = self.pixel_loss(batch_outputs, batch_gt)

losses = dict(pixel_loss=pixel_loss)
return losses
Loading