run.py
# Copyright (c) 2017 NVIDIA Corporation
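"""Run script for OpenSeq2Seq: builds the model described by a Python config
file and runs training, evaluation, or inference depending on --mode."""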
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
import sys

import tensorflow as tf
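
# OpenSeq2Seq builds TF1-style graphs (see tf.Graph().as_default() in main()),
# so eager execution must be disabled when this script runs under TF 2.x,
# where tf.compat.v1 exists.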
if hasattr(tf.compat, 'v1'):
  tf.compat.v1.disable_eager_execution()

from open_seq2seq.utils.utils import deco_print, get_base_config, create_model, \
    create_logdir, check_logdir, check_base_model_logdir
from open_seq2seq.utils import train, infer, evaluate


def main():
  # Parse args and create config
  args, base_config, base_model, config_module = get_base_config(sys.argv[1:])
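
  # Typical invocation (the config path below is illustrative):
  #   python run.py --config_file=example_configs/<task>/<config>.py --mode=train_eval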

  if args.mode == "interactive_infer":
    raise ValueError(
        "Interactive infer is meant to be run from an IPython "
        "notebook, not from run.py."
    )
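
  # 'load_model' optionally names a previously trained base model;
  # check_base_model_logdir validates that directory and resolves the
  # checkpoint used to warm-start training (see "Starting training from
  # the base model" below).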
  load_model = base_config.get('load_model', None)
  restore_best_checkpoint = base_config.get('restore_best_checkpoint', False)
  base_ckpt_dir = check_base_model_logdir(load_model, args,
                                          restore_best_checkpoint)
  base_config['load_model'] = base_ckpt_dir

  # Check logdir and create it if necessary
  checkpoint = check_logdir(args, base_config, restore_best_checkpoint)

  # Initialize Horovod
  if base_config['use_horovod']:
    import horovod.tensorflow as hvd
    hvd.init()
    if hvd.rank() == 0:
      deco_print("Using horovod")
    # Wait for all MPI ranks to finish Horovod initialization
    from mpi4py import MPI
    MPI.COMM_WORLD.Barrier()
  else:
    hvd = None
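
  # From here on, "hvd is None or hvd.rank() == 0" singles out the one
  # process that should print messages and write logs.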

  if args.enable_logs:
    if hvd is None or hvd.rank() == 0:
      old_stdout, old_stderr, stdout_log, stderr_log = create_logdir(
          args,
          base_config
      )
    base_config['logdir'] = os.path.join(base_config['logdir'], 'logs')
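
  # create_logdir returns the original stdout/stderr handles (the streams are
  # redirected into log files); they are restored at the bottom of main().
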
  if args.mode == 'train' or args.mode == 'train_eval' or args.benchmark:
    if hvd is None or hvd.rank() == 0:
      if checkpoint is None or args.benchmark:
        if base_ckpt_dir:
          deco_print("Starting training from the base model")
        else:
          deco_print("Starting training from scratch")
      else:
        deco_print(
            "Restored checkpoint from {}. Resuming training".format(checkpoint)
        )
  elif args.mode == 'eval' or args.mode == 'infer':
    if hvd is None or hvd.rank() == 0:
      deco_print("Loading model from {}".format(checkpoint))

  # Create model and train/eval/infer
  with tf.Graph().as_default():
    model = create_model(
        args, base_config, config_module, base_model, hvd, checkpoint)
    hooks = None
    if ('train_params' in config_module and
        'hooks' in config_module['train_params']):
      hooks = config_module['train_params']['hooks']
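    # In "train_eval" mode create_model returns a (train, eval) model pair,
    # hence the model[0]/model[1] indexing below; custom hooks from the
    # config are forwarded to train().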
if args.mode == "train_eval":
train(
model[0], eval_model=model[1], debug_port=args.debug_port,
custom_hooks=hooks)
elif args.mode == "train":
train(
model, eval_model=None, debug_port=args.debug_port, custom_hooks=hooks)
elif args.mode == "eval":
evaluate(model, checkpoint)
elif args.mode == "infer":
infer(model, checkpoint, args.infer_output_file)
  if args.enable_logs and (hvd is None or hvd.rank() == 0):
    sys.stdout = old_stdout
    sys.stderr = old_stderr
    stdout_log.close()
    stderr_log.close()


if __name__ == '__main__':
  main()