forked from dhlab-epfl/dhSegment
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
132 lines (114 loc) · 6.38 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
import tensorflow as tf
# Tensorflow logging level
from logging import WARNING # import DEBUG, INFO, ERROR for more/less verbosity
tf.logging.set_verbosity(WARNING)
from dh_segment import estimator_fn, utils
from dh_segment.io import input
import json
from glob import glob
import numpy as np
# Optional dependency: if importable, better_exceptions is loaded for nicer tracebacks;
# otherwise a warning is printed and the script continues normally.
try:
    import better_exceptions  # noqa: F401  (imported only for its side effect)
except ImportError:
    # The backslash is escaped explicitly: '\ ' is an invalid escape sequence
    # (DeprecationWarning, SyntaxWarning since Python 3.12). The printed text is unchanged.
    print('/!\\ W -- Not able to import package better_exceptions')
from tqdm import trange
from sacred import Experiment
import pandas as pd
# Sacred experiment object. The configuration function and the script's entry point
# are registered on it below via the @ex.config and @ex.automain decorators.
ex = Experiment('dhSegment_experiment')
# Default Sacred configuration scope.
# NOTE: per the Sacred documentation, every local variable assigned in this function
# becomes a (command-line overridable) config entry, and inline comments document them.
# Do not introduce helper variables here -- they would leak into the experiment config.
# Validation uses explicit raises instead of `assert` so it still happens under `python -O`.
@ex.config
def default_config():
    train_data = None  # Directory with training data
    eval_data = None  # Directory with validation data
    model_output_dir = None  # Directory to output tf model
    restore_model = False  # Set to true to continue training
    classes_file = None  # txt file with classes values (unused for REGRESSION)
    gpu = ''  # GPU to be used for training
    prediction_type = utils.PredictionType.CLASSIFICATION  # One of CLASSIFICATION, REGRESSION or MULTILABEL
    pretrained_model_name = 'resnet50'
    model_params = utils.ModelParams(pretrained_model_name=pretrained_model_name).to_dict()  # Model parameters
    training_params = utils.TrainingParams().to_dict()  # Training parameters
    # The number of output classes depends on the prediction type.
    if prediction_type == utils.PredictionType.CLASSIFICATION:
        if classes_file is None:
            raise ValueError('"classes_file" must be set when prediction_type is CLASSIFICATION')
        model_params['n_classes'] = utils.get_n_classes_from_file(classes_file)
    elif prediction_type == utils.PredictionType.REGRESSION:
        model_params['n_classes'] = 1
    elif prediction_type == utils.PredictionType.MULTILABEL:
        if classes_file is None:
            raise ValueError('"classes_file" must be set when prediction_type is MULTILABEL')
        model_params['n_classes'] = utils.get_n_classes_from_file_multilabel(classes_file)
    else:
        # Previously an unknown value silently left 'n_classes' unset.
        raise ValueError('Unknown prediction_type {!r}; expected CLASSIFICATION, '
                         'REGRESSION or MULTILABEL'.format(prediction_type))
def _get_dirs_or_files(input_data):
    """Resolve a data location into the image input and label input used by ``input.input_fn``.

    Defined at module level *before* the ``@ex.automain`` function: Sacred's
    ``automain`` starts the experiment as soon as the decorator is applied when the
    script is executed directly, so anything defined after it would not exist yet.

    :param input_data: either a directory containing ``images`` and ``labels``
        sub-directories, or a path to a ``.csv`` file
    :return: ``(image_input, labels_input)`` -- the two sub-directory paths for a
        directory, or ``(input_data, None)`` for a csv file
    :raises NotADirectoryError: if ``input_data`` is a directory but its ``images``
        or ``labels`` sub-directory is missing
    :raises TypeError: if ``input_data`` is neither a directory nor a ``.csv`` file
    """
    if os.path.isdir(input_data):
        image_input = os.path.join(input_data, 'images')
        labels_input = os.path.join(input_data, 'labels')
        # Both sub-directories are required for directory-based input
        for required_dir in (image_input, labels_input):
            if not os.path.isdir(required_dir):
                raise NotADirectoryError("{} is not a directory".format(required_dir))
    elif os.path.isfile(input_data) and input_data.endswith('.csv'):
        image_input = input_data
        labels_input = None
    else:
        raise TypeError('input_data {} is neither a directory nor a csv file'.format(input_data))
    return image_input, labels_input


@ex.automain
def run(train_data, eval_data, model_output_dir, gpu, training_params, _config):
    """Train a dhSegment model, evaluating and exporting it periodically.

    All arguments are injected by Sacred from the experiment configuration.

    :param train_data: training data directory (with ``images``/``labels``) or csv file
    :param eval_data: validation data directory or csv file, or ``None`` to skip evaluation
    :param model_output_dir: directory for checkpoints, ``config.json`` and the
        ``export`` sub-directory holding saved models
    :param gpu: value written to the TF session's ``visible_device_list``
    :param training_params: dict form of ``utils.TrainingParams``
    :param _config: the full Sacred configuration (saved to ``config.json`` and passed
        to the estimator and input functions as ``params``)
    :raises ValueError: if ``train_data`` or ``model_output_dir`` is not set
    :raises FileExistsError: if ``model_output_dir`` exists and ``restore_model`` is not set
    """
    if train_data is None or model_output_dir is None:
        raise ValueError('Both "train_data" and "model_output_dir" must be set in the configuration')

    # Create output directory; refuse to reuse an existing one unless restoring from it
    if not os.path.isdir(model_output_dir):
        os.makedirs(model_output_dir)
    elif not _config.get('restore_model'):
        raise FileExistsError('{0} already exists, you cannot use it as output directory. '
                              'Set "restore_model=True" to continue training, or delete dir "rm -r {0}"'
                              .format(model_output_dir))

    # Save config
    with open(os.path.join(model_output_dir, 'config.json'), 'w') as f:
        json.dump(_config, f, indent=4, sort_keys=True)

    # Create export directory for saved models
    saved_model_dir = os.path.join(model_output_dir, 'export')
    os.makedirs(saved_model_dir, exist_ok=True)

    training_params = utils.TrainingParams.from_dict(training_params)

    session_config = tf.ConfigProto()
    session_config.gpu_options.visible_device_list = str(gpu)
    session_config.gpu_options.per_process_gpu_memory_fraction = 0.9
    estimator_config = tf.estimator.RunConfig().replace(session_config=session_config,
                                                        save_summary_steps=10,
                                                        keep_checkpoint_max=1)
    estimator = tf.estimator.Estimator(estimator_fn.model_fn, model_dir=model_output_dir,
                                       params=_config, config=estimator_config)

    train_input, train_labels_input = _get_dirs_or_files(train_data)
    if eval_data is not None:
        eval_input, eval_labels_input = _get_dirs_or_files(eval_data)

    # Configure exporter: keep the best models when evaluation is available, else the latest ones
    serving_input_fn = input.serving_input_filename(training_params.input_resized_size)
    if eval_data is not None:
        exporter = tf.estimator.BestExporter(serving_input_receiver_fn=serving_input_fn, exports_to_keep=2)
    else:
        exporter = tf.estimator.LatestExporter(name='SimpleExporter', serving_input_receiver_fn=serving_input_fn,
                                               exports_to_keep=5)

    n_epochs = training_params.n_epochs
    evaluate_every_epoch = training_params.evaluate_every_epoch
    # Bound before the loop so the final export below cannot hit a NameError when
    # the loop body never runs (e.g. n_epochs == 0 while restoring a model).
    eval_result = None
    for start_epoch in trange(0, n_epochs, evaluate_every_epoch, desc='Evaluated epochs'):
        # The last round may be shorter, so the per-round num_epochs sum to exactly
        # n_epochs even when it is not a multiple of evaluate_every_epoch.
        epochs_this_round = min(evaluate_every_epoch, n_epochs - start_epoch)
        estimator.train(input.input_fn(train_input,
                                       input_label_dir=train_labels_input,
                                       num_epochs=epochs_this_round,
                                       batch_size=training_params.batch_size,
                                       data_augmentation=training_params.data_augmentation,
                                       make_patches=training_params.make_patches,
                                       image_summaries=True,
                                       params=_config,
                                       num_threads=32))
        if eval_data is not None:
            eval_result = estimator.evaluate(input.input_fn(eval_input,
                                                            input_label_dir=eval_labels_input,
                                                            batch_size=1,
                                                            data_augmentation=False,
                                                            make_patches=False,
                                                            image_summaries=False,
                                                            params=_config,
                                                            num_threads=32))
        exporter.export(estimator, saved_model_dir, checkpoint_path=None, eval_result=eval_result,
                        is_the_final_export=False)

    # If export directory is empty, export a model anyway
    if not os.listdir(saved_model_dir):
        final_exporter = tf.estimator.FinalExporter(name='FinalExporter', serving_input_receiver_fn=serving_input_fn)
        final_exporter.export(estimator, saved_model_dir, checkpoint_path=None, eval_result=eval_result,
                              is_the_final_export=True)