"""
This file runs the main training/val loop, etc... using Lightning Trainer
"""
import os
import random
from argparse import ArgumentParser
from os.path import join

import torch

from core import utils


def main(args):
    from pytorch_lightning import Trainer
    from pytorch_lightning import loggers as pl_loggers
    from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

    from core.models import TheEye

    model = TheEye(args)

    seed = 1
    # Don't seed numpy: we want random dataloader init, but distributed training
    # requires the same seed for model inits.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
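    # Note: setting PYTHONHASHSEED at runtime only affects child processes
    # (e.g. dataloader workers); hash randomization in this interpreter was
    # already fixed at startup.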

    # most basic trainer, uses good defaults
    # os.environ["WANDB_API_KEY"] = args.wandb_api_key
    logger = pl_loggers.TensorBoardLogger('~/tensorboard_logs')
    checkpoint_callback = ModelCheckpoint(dirpath='/media/heka/TERA/Data/openimages_models/',  # TODO: replace when VM
                                          filename=join(args.experiment, args.run_name),
                                          monitor='loss_val')
    lr_monitor = LearningRateMonitor(logging_interval='step')
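
    # NOTE: the Trainer flags below follow the pytorch_lightning 1.x API;
    # several of them (e.g. checkpoint_callback, weights_summary,
    # terminate_on_nan) were renamed or removed in later releases.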
    trainer = Trainer(logger=logger,
                      checkpoint_callback=checkpoint_callback,
                      callbacks=[lr_monitor],
                      default_root_dir=None,
                      gradient_clip_val=0,
                      gpus=args.gpus,
                      auto_select_gpus=False,  # True will assume gpus present...
                      log_gpu_memory=None,
                      progress_bar_refresh_rate=1,
                      overfit_batches=0.,
                      fast_dev_run=False,
                      accumulate_grad_batches=1,
                      max_epochs=args.max_epochs,
                      limit_train_batches=vars(args).get('limit_train_batches', 1.),
                      val_check_interval=args.val_check_interval,
                      limit_val_batches=args.limit_val_batches,
                      accelerator='ddp',
                      sync_batchnorm=False,
                      precision=args.precision,
                      weights_summary='top',
                      weights_save_path=None,
                      num_sanity_val_steps=args.num_sanity_val_steps,
                      resume_from_checkpoint=args.resume_from,
                      benchmark=False,
                      deterministic=False,
                      reload_dataloaders_every_epoch=False,
                      terminate_on_nan=False,  # do NOT use on TPUs, very slow!
                      prepare_data_per_node=True,
                      amp_backend='native')

    trainer.logger.log_hyperparams(args)
    trainer.fit(model)


if __name__ == '__main__':
    parser = ArgumentParser(add_help=False)

    # add CLI args:
    parser.add_argument(
        "--config-module",
        default='configs.openimages',
        metavar="FILE",
        help="path to config module (usually under ./configs)",
        type=str,
    )

    # parse params, then merge in values from the config module
    args = parser.parse_args()
    args_config = utils.load_args_module(args)
    vars(args).update(args_config)

    main(args)
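
# Example invocation (assuming the default config module configs.openimages exists):
#   python train.py --config-module configs.openimages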