train6.py
import argparse
import torch
import torchvision
import torchvision.utils as vutils
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.autograd import Variable
from PIL import Image
from random import randint
import cv2
import cfg
import model6
from model6 import NetD
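# Project modules: cfg supplies the shared argument parser (latent_dim is read from it below),
# and model6 defines the Generator and the NetD discriminator trained in this script.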
args = cfg.parse_args()
parser = argparse.ArgumentParser()
parser.add_argument('--batchSize', type=int, default=4)  # number of samples per training batch
parser.add_argument('--imageSize', type=int, default=128)  # image size; every image in this training set has the same size
parser.add_argument('--nz', type=int, default=args.latent_dim, help='size of the latent z vector')  # taken from cfg
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=128)
parser.add_argument('--epoch', type=int, default=95, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate, default=0.0001')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam, default=0.5')
parser.add_argument('--data_path', default='data1/', help='folder of training data')
parser.add_argument('--data_path2', default='data2/', help='folder of the second training data set')
parser.add_argument('--outf', default='myImgs6/', help='folder to output images and model checkpoints')
opt = parser.parse_args()
# Image loading and preprocessing
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Grayscale(num_output_channels=3),  # convert color images to grayscale; num_output_channels defaults to 1
    torchvision.transforms.Resize(opt.imageSize),
    torchvision.transforms.CenterCrop(opt.imageSize),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])  # normalize each channel with mean and std to remove scale differences
dataset = torchvision.datasets.ImageFolder(opt.data_path, transform=transforms)
dataset2 = torchvision.datasets.ImageFolder(opt.data_path2, transform=transforms)
# Use the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def my_collate(batch):  # custom collate function (not passed to the DataLoaders below)
    data = [item[0] for item in batch]
    target = [item[1] for item in batch]
    return [data, target]
dataloader = torch.utils.data.DataLoader(  # data loader for the first image folder
    dataset=dataset,
    batch_size=opt.batchSize,
    shuffle=False,
    drop_last=True,
)
dataloader2 = torch.utils.data.DataLoader(  # data loader for the second image folder
    dataset=dataset2,
    batch_size=opt.batchSize,
    shuffle=False,
    drop_last=True,
)
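# Both loaders use shuffle=False and drop_last=True so that, when iterated together with zip()
# in the training loop below, each batch from data1/ stays paired with the matching batch from data2/.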
netG = model6.Generator().to(device)
netD = NetD(opt.ndf).to(device)
criterion = nn.BCELoss().to(device)
L1 = nn.L1Loss().to(device)  # the Pix2Pix paper adds an L1 term to the standard GAN objective
optimizerG = torch.optim.Adam(netG.parameters(), lr=0.0005, betas=(opt.beta1, 0.999))  # note: hard-coded lr, not opt.lr
optimizerD = torch.optim.Adam(netD.parameters(), lr=0.0005, betas=(opt.beta1, 0.999))  # note: hard-coded lr, not opt.lr
label = torch.FloatTensor(opt.batchSize)
real_label = 1
fake_label = 0
for epoch in range(1, opt.epoch + 1):
    for i, img in enumerate(zip(dataloader, dataloader2)):
        # img[0] is a batch of original images (from data1/)
        # img[1] is a batch of line-art images (from data2/)
        # print("size:")
        # print(imgs.shape)
        # print(type(imgs))
        # imgs = torch.squeeze(imgs)  # without a dim argument, squeeze removes every size-1 dimension
        # plt.imshow(imgs.reshape(opt.imageSize, opt.imageSize, 3))
        # plt.imshow(imgs.numpy().reshape(opt.imageSize, opt.imageSize, 3))
        # q = torch.squeeze(img[1][0][0])  # without a dim argument, squeeze removes every size-1 dimension
        # plt.imshow(q.numpy().reshape(opt.imageSize, opt.imageSize, 3))
        # Fix the generator G and train the discriminator D
        optimizerD.zero_grad()  # reset the gradients of D's parameters to zero
        ## Make D classify real images as 1 (these are the real images)
        imgs = img[1][0].to(device)
        output = netD(imgs)
        output = torch.squeeze(output)  # squeeze removes size-1 dimensions; unsqueeze adds one
        label.data.fill_(real_label)
        label = label.to(device)
        errD_real = criterion(output, label)
        errD_real.backward()
        ## Make D classify fake images as 0
        label.data.fill_(fake_label)
        # noise = torch.randn(opt.batchSize, opt.nz)
        noise = img[0][0].to(device)
        # GPU version
        with torch.no_grad():  # keep gradients from flowing into G, since G is not updated here
            fake = netG(noise)  # generate fake images
        output = netD(fake)
        # CPU version
        # test process
        # output = netD(fake.detach())  # keep gradients from flowing into G, since G is not updated here
        output = torch.squeeze(output)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        errD = errD_fake + errD_real
        optimizerD.step()
        # Fix the discriminator D and train the generator G
        optimizerG.zero_grad()
        # The earlier fake cannot propagate gradients, so it must be regenerated here
        fake = netG(noise)  # generate fake images
        # Make D classify the fakes produced by G as 1
        label.data.fill_(real_label)
        label = label.to(device)
        output = netD(fake)
        output = torch.squeeze(output)
        errG = criterion(output, label)
        G_L1_Loss = L1(output, label)  # L1 between D's output and the label (Pix2Pix applies L1 between generated and target images)
        lamb = 100  # weight of the L1 term
        G_loss = errG + lamb * G_L1_Loss
        G_loss.backward()
        optimizerG.step()
        print('[%d/%d][%d/%d] Loss_D: %.3f Loss_G: %.3f'
              % (epoch, opt.epoch, i, len(dataloader), errD.item(), errG.item()))
        if i % (40 + epoch) == 0:
            vutils.save_image(fake.data,
                              '%s/%03d_%03d.png' % (opt.outf, epoch, i + 1),
                              normalize=True)
    vutils.save_image(fake.data,
                      '%s/fake_samples_epoch_%03d.png' % (opt.outf, epoch),
                      normalize=True)
    torch.save(netG.state_dict(), '%s/netG_%03d.pth' % (opt.outf, epoch))
    torch.save(netD.state_dict(), '%s/netD_%03d.pth' % (opt.outf, epoch))
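# Example invocation (a sketch; assumes cfg.parse_args() runs with the default command line and
# that the data1/, data2/ and myImgs6/ folders exist next to this script):
#   python train6.py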