-
Notifications
You must be signed in to change notification settings - Fork 0
/
module.py
36 lines (32 loc) · 1.26 KB
/
module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import os
from collections.abc import Mapping

import pytorch_lightning as pl
import torch
import torch.nn as nn

from cifar10_models.densenet import densenet121, densenet161, densenet169
from cifar10_models.googlenet import googlenet
from cifar10_models.inception import inception_v3
from cifar10_models.mobilenetv2 import mobilenet_v2
from cifar10_models.resnet import resnet18, resnet34, resnet50
from cifar10_models.vgg import vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
# Map each supported architecture name to its constructor. Constructors are
# stored instead of instances so that importing this module does not build
# every network (and run every ``pretrained=True`` load) up front.
_CLASSIFIER_FACTORIES = {
    "vgg11_bn": vgg11_bn,
    "vgg13_bn": vgg13_bn,
    "vgg16_bn": vgg16_bn,
    "vgg19_bn": vgg19_bn,
    "resnet18": resnet18,
    "resnet34": resnet34,
    "resnet50": resnet50,
    "densenet121": densenet121,
    "densenet161": densenet161,
    "densenet169": densenet169,
    "mobilenet_v2": mobilenet_v2,
    "googlenet": googlenet,
    "inception_v3": inception_v3,
}


class _LazyClassifiers(Mapping):
    """Read-only mapping of classifier name -> model built with ``pretrained=True``.

    Keeps the historical ``all_classifiers[name]`` lookup (and iteration over
    the names) working, but defers calling a constructor until that name is
    first looked up. Each model is built at most once; later lookups return
    the same cached instance.
    """

    def __init__(self, factories):
        self._factories = factories
        self._cache = {}

    def __getitem__(self, name):
        if name not in self._cache:
            # Indexing the factories raises KeyError for unknown names,
            # matching the behaviour of the original plain dict.
            self._cache[name] = self._factories[name](pretrained=True)
        return self._cache[name]

    def __contains__(self, name):
        # Override Mapping's default (which calls __getitem__) so a membership
        # test never triggers model construction.
        return name in self._factories

    def __iter__(self):
        return iter(self._factories)

    def __len__(self):
        return len(self._factories)


# Backward-compatible module-level registry; models are now built on demand.
all_classifiers = _LazyClassifiers(_CLASSIFIER_FACTORIES)


class CIFAR10Model(nn.Module):
    """``nn.Module`` wrapper around one of the supported CIFAR-10 classifiers.

    Args:
        classifier_name: Architecture name, one of the keys of
            ``_CLASSIFIER_FACTORIES`` (e.g. ``"resnet18"``).
        pretrained: Forwarded to the architecture constructor. Defaults to
            ``True``, matching the original behaviour.

    Raises:
        KeyError: If ``classifier_name`` is not a supported architecture.
    """

    def __init__(self, classifier_name, pretrained=True):
        super().__init__()
        try:
            factory = _CLASSIFIER_FACTORIES[classifier_name]
        except KeyError:
            raise KeyError(
                f"Unknown classifier {classifier_name!r}; "
                f"expected one of {sorted(_CLASSIFIER_FACTORIES)}"
            ) from None
        # Build a fresh network for every wrapper so that separate
        # CIFAR10Model instances never share (and co-train) one set of
        # parameters, as they did when all wrappers reused a single
        # module-level instance.
        self.model = factory(pretrained=pretrained)

    def forward(self, x):
        """Run ``x`` through the wrapped classifier and return its output."""
        return self.model(x)