Skip to content

Commit

Permalink
feat: add CondNet
Browse files Browse the repository at this point in the history
  • Loading branch information
sithu31296 committed Sep 24, 2021
1 parent 6e50371 commit 9798e02
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ Supported Heads/Methods:
* [UPerNet][upernet]
* [SFNet][sfnet]
* [SegFormer][segformer]
* [CondNet][condnet]

Supported Standalone Models:
* [DDRNet][ddrnet]
Expand Down
3 changes: 2 additions & 1 deletion models/heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
from .fpn import FPNHead
from .fapn import FaPNHead
from .fcn import FCNHead
from .condnet import CondHead

__all__ = ['UPerHead', 'SegFormerHead', 'SFHead', 'FPNHead', 'FaPNHead', 'FCNHead']
__all__ = ['UPerHead', 'SegFormerHead', 'SFHead', 'FPNHead', 'FaPNHead', 'FCNHead', 'CondHead']
67 changes: 67 additions & 0 deletions models/heads/condnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import torch
from torch import nn, Tensor
from torch.nn import functional as F


class ConvModule(nn.Sequential):
def __init__(self, c1, c2, k, s=1, p=0, d=1, g=1):
super().__init__(
nn.Conv2d(c1, c2, k, s, p, d, g, bias=False),
nn.BatchNorm2d(c2),
nn.ReLU(True)
)


class CondHead(nn.Module):
def __init__(self, in_channel: int = 2048, channel: int = 512, num_classes: int = 19):
super().__init__()
self.num_classes = num_classes
self.weight_num = channel * num_classes
self.bias_num = num_classes

self.conv = ConvModule(in_channel, channel, 1)
self.dropout = nn.Dropout2d(0.1)

self.guidance_project = nn.Conv2d(channel, num_classes, 1)
self.filter_project = nn.Conv2d(channel*num_classes, self.weight_num + self.bias_num, 1, groups=num_classes)

def forward(self, features) -> Tensor:
x = self.dropout(self.conv(features[-1]))
B, C, H, W = x.shape
guidance_mask = self.guidance_project(x)
cond_logit = guidance_mask

key = x
value = x
guidance_mask = guidance_mask.softmax(dim=1).view(*guidance_mask.shape[:2], -1)
key = key.view(B, C, -1).permute(0, 2, 1)

cond_filters = torch.matmul(guidance_mask, key)
cond_filters /= H * W
cond_filters = cond_filters.view(B, -1, 1, 1)
cond_filters = self.filter_project(cond_filters)
cond_filters = cond_filters.view(B, -1)

weight, bias = torch.split(cond_filters, [self.weight_num, self.bias_num], dim=1)
weight = weight.reshape(B * self.num_classes, -1, 1, 1)
bias = bias.reshape(B * self.num_classes)

value = value.view(-1, H, W).unsqueeze(0)
seg_logit = F.conv2d(value, weight, bias, 1, 0, groups=B).view(B, self.num_classes, H, W)

if self.training:
return cond_logit, seg_logit
return seg_logit


if __name__ == '__main__':
import sys
sys.path.insert(0, '.')
from models.backbones.resnetd import ResNetD
backbone = ResNetD('50')
head = CondHead()
x = torch.randn(2, 3, 224, 224)
features = backbone(x)
outs = head(features)
for out in outs:
print(out.shape)

0 comments on commit 9798e02

Please sign in to comment.