train.py

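# Training entry point: builds a LigandPocketDDPM Lightning module from a YAML
# config and trains it with PyTorch Lightning, optionally resuming training
# from a checkpoint whose stored hyperparameters are merged back into the config.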
import argparse
from argparse import Namespace
from pathlib import Path
import warnings

import torch
import pytorch_lightning as pl
import yaml
import numpy as np

from lightning_modules import LigandPocketDDPM


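# Copy every entry of the YAML config onto the argparse Namespace; nested
# dictionaries become Namespaces, and a warning is issued whenever a command
# line argument is overwritten by the config file.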
def merge_args_and_yaml(args, config_dict):
    arg_dict = args.__dict__
    for key, value in config_dict.items():
        if key in arg_dict:
            warnings.warn(f"Command line argument '{key}' (value: "
                          f"{arg_dict[key]}) will be overwritten with value "
                          f"{value} provided in the config file.")
        if isinstance(value, dict):
            arg_dict[key] = Namespace(**value)
        else:
            arg_dict[key] = value

    return args


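# When resuming, hyperparameters stored in the checkpoint take precedence over
# the values in the YAML config; every conflicting key is reported with a warning.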
def merge_configs(config, resume_config):
    for key, value in resume_config.items():
        if isinstance(value, Namespace):
            value = value.__dict__
        if key in config and config[key] != value:
            warnings.warn(f"Config parameter '{key}' (value: "
                          f"{config[key]}) will be overwritten with value "
                          f"{value} from the checkpoint.")
        config[key] = value
    return config


# ------------------------------------------------------------------------------
# Training
# ______________________________________________________________________________
if __name__ == "__main__":
    p = argparse.ArgumentParser()
    p.add_argument('--config', type=str, required=True)
    p.add_argument('--resume', type=str, default=None)
    args = p.parse_args()

    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)

    assert 'resume' not in config

    # Get main config
    ckpt_path = None if args.resume is None else Path(args.resume)
    if args.resume is not None:
        resume_config = torch.load(
            ckpt_path, map_location=torch.device('cpu'))['hyper_parameters']
        config = merge_configs(config, resume_config)

    args = merge_args_and_yaml(args, config)

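    # Output directory and the dataset's molecule-size histogram
    # ('size_distribution.npy'), passed to the model as node_histogram.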
    out_dir = Path(args.logdir, args.run_name)
    histogram_file = Path(args.datadir, 'size_distribution.npy')
    histogram = np.load(histogram_file).tolist()

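    # Build the Lightning module; all remaining hyperparameters come from the
    # merged config / checkpoint.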
    pl_module = LigandPocketDDPM(
        outdir=out_dir,
        dataset=args.dataset,
        datadir=args.datadir,
        batch_size=args.batch_size,
        lr=args.lr,
        egnn_params=args.egnn_params,
        diffusion_params=args.diffusion_params,
        num_workers=args.num_workers,
        augment_noise=args.augment_noise,
        augment_rotation=args.augment_rotation,
        clip_grad=args.clip_grad,
        eval_epochs=args.eval_epochs,
        eval_params=args.eval_params,
        visualize_sample_epoch=args.visualize_sample_epoch,
        visualize_chain_epoch=args.visualize_chain_epoch,
        auxiliary_loss=args.auxiliary_loss,
        loss_params=args.loss_params,
        mode=args.mode,
        node_histogram=histogram,
        pocket_representation=args.pocket_representation,
        virtual_nodes=args.virtual_nodes
    )

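    # Weights & Biases logger: the run name doubles as the run id so that a
    # resumed run (resume='must') continues logging to the same W&B run.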
    logger = pl.loggers.WandbLogger(
        save_dir=args.logdir,
        project='ligand-pocket-ddpm',
        group=args.wandb_params.group,
        name=args.run_name,
        id=args.run_name,
        resume='must' if args.resume is not None else False,
        entity=args.wandb_params.entity,
        mode=args.wandb_params.mode,
    )

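    # Keep the single best checkpoint by validation loss and also the most
    # recent one (last.ckpt).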
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=Path(out_dir, 'checkpoints'),
        filename="best-model-epoch={epoch:02d}",
        monitor="loss/val",
        save_top_k=1,
        save_last=True,
        mode="min",
    )

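    # GPU training; distributed data parallel (DDP) is only enabled when more
    # than one GPU is requested.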
    trainer = pl.Trainer(
        max_epochs=args.n_epochs,
        logger=logger,
        callbacks=[checkpoint_callback],
        enable_progress_bar=args.enable_progress_bar,
        num_sanity_val_steps=args.num_sanity_val_steps,
        accelerator='gpu', devices=args.gpus,
        strategy=('ddp' if args.gpus > 1 else None)
    )

    trainer.fit(model=pl_module, ckpt_path=ckpt_path)
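
# Example invocation (a sketch; the config path below is illustrative, any YAML
# file providing the fields accessed above, e.g. run_name, logdir, datadir,
# batch_size, lr, n_epochs, gpus, wandb_params, will work):
#   python train.py --config configs/my_config.yml
# Resume from a previous checkpoint of the same run:
#   python train.py --config configs/my_config.yml \
#       --resume logs/my_run/checkpoints/last.ckpt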