-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
executable file
·128 lines (99 loc) · 3.45 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#!python
import os
import torch
import logging
from time import time
from tqdm import tqdm
from logging import info
from torch.nn import functional as F
from matplotlib import pyplot as plt
from upycli.decorator import command
from src.dataloaders import ImageDataset, FacesDataset, CarsDataset
from src.noise import NoiseScheduler
from src.models.basic import BasicDenoisingDiffusion
# Configure the root logger at INFO level so the `info(...)` progress
# messages emitted by the commands below are not filtered out.
logging.basicConfig(level=logging.INFO)
@command
def train(device = "cpu",
          dataset = "faces",
          debug = False,
          dryrun = False,
          # Noise Scheduler parameters
          schedule = "linear",
          timesteps = 300,
          start = 0.0001,
          end = 0.020,
          # Training parameters
          lr = 0.001,
          epochs = 50,
          batch_size = 16,
          # Model shape
          shape = (
              # Downsampling
              [64, 128, 256, 512, 1024],
              # Upsampling
              [1024, 512, 256, 128, 64])):
    """ Train a new model from scratch.

    Builds the dataset, the noise scheduler and a `BasicDenoisingDiffusion`
    model, then optimises the model with Adam on an L1 loss between the
    noise returned by `ns.forward_diffusion` and the model's prediction.

    Args:
        device: Device string handed to the faces dataset, the noise
            scheduler, `model.to(...)` and the sampled timestep tensor.
        dataset: Either "faces" or "cars".
        debug: When true, print `repr(model)` after construction.
        dryrun: When true, leave each epoch's batch loop after the first
            batch. All epochs and per-epoch saves still run.
        schedule: Forwarded to `NoiseScheduler` as `ntype`.
        timesteps: Forwarded to `NoiseScheduler` as `steps`.
        start: Forwarded to `NoiseScheduler` as `start`.
        end: Forwarded to `NoiseScheduler` as `end`.
        lr: Learning rate for the Adam optimizer.
        epochs: Number of passes over the data loader.
        batch_size: Batch size requested from `ds.loader(...)`.
        shape: Passed unchanged to `BasicDenoisingDiffusion`.

    Raises:
        ValueError: If `dataset` is not "faces" or "cars".

    Side effects:
        Writes `.checkpoints/scheduler.json` and `.checkpoints/epoch_<E>.pt`,
        `results/losslog.png`, `results/training/epoch_<E>.png`,
        `results/model.pt` and `results/scheduler.json`. The directories
        are created when missing.
    """
    info("Start Training")

    if dataset == "faces":
        ds = FacesDataset(device)
    elif dataset == "cars":
        # NOTE(review): unlike FacesDataset, no device is passed here --
        # confirm against CarsDataset's constructor.
        ds = CarsDataset()
    else:
        raise ValueError(f"Unknown dataset `{dataset}`")
    info(f"Built {len(ds.images) // batch_size} batches of {batch_size} samples")

    ns = NoiseScheduler(
        ntype=schedule,
        steps=timesteps,
        start=start,
        end=end,
        device=device)

    # Build model
    model = BasicDenoisingDiffusion(shape)
    param_size = sum(p.numel() for p in model.parameters())
    info(f"DenoisingDiffusion Model :: {param_size} parameters")
    if debug:
        print(repr(model))
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Create every output directory up front: neither `plt.savefig` nor
    # `torch.save` creates missing parent directories.
    os.makedirs(".checkpoints", exist_ok=True)
    os.makedirs("results/training", exist_ok=True)
    ns.save(".checkpoints/scheduler.json")

    started = time()
    losslog = []
    for epoch in range(1, epochs + 1):
        # `losslog[-1]` is the loss of the most recent batch.
        print(f"Epoch {epoch}/{epochs}", f"Epoch Loss {losslog[-1]}" if losslog else "")
        dl = ds.loader(batch_size)
        for batch in tqdm(dl):
            optimizer.zero_grad()
            # One timestep per sample actually present in this batch; the
            # last batch may hold fewer than `batch_size` samples.
            # Assumes `batch` is a tensor with a leading batch dimension,
            # as it is passed straight to `forward_diffusion` -- TODO confirm.
            timestep = torch.randint(0, ns.steps,
                                     size=(len(batch),),
                                     device=device,
                                     dtype=torch.long)
            image_, noise = ns.forward_diffusion(batch, timestep)
            noise_ = model(image_, timestep)
            loss = F.l1_loss(noise, noise_)
            loss.backward()
            optimizer.step()
            losslog.append(loss.item())
            if dryrun:
                break

        # Save checkpoint for this epoch
        torch.save(model, f".checkpoints/epoch_{epoch}.pt")

        # Refresh the loss curve (log-scaled y axis) after every epoch.
        plt.figure(figsize=(12, 4), dpi=150)
        plt.semilogy(losslog)
        plt.savefig("results/losslog.png")
        plt.close()

        ImageDataset.plot(model.sample(ns, 8), save=f"results/training/epoch_{epoch}.png")

    finished = time()
    info(f"Training time {round((finished - started)/60, 3)} minutes.")

    torch.save(model, "results/model.pt")
    ns.save("results/scheduler.json")
@command
def test(model = "results/model.pt", ns_path = "results/scheduler.json", device = "cpu"):
    """ Run a model from a given path.

    Loads the object stored at `model` with `torch.load` (mapped onto
    `device`) and the noise scheduler stored at `ns_path`, calls
    `sample(ns, 16)` on the loaded model and passes the result to
    `ImageDataset.plot`, saving to `results/generated.png`.
    """
    info("Start Testing")

    # NOTE: `torch.load` unpickles the file, so only point `model` at
    # checkpoints from a trusted source.
    network = torch.load(model, map_location=device)
    scheduler = NoiseScheduler.load(ns_path, device=device)

    generated = network.sample(scheduler, 16)
    ImageDataset.plot(generated, save=f"results/generated.png")