-
Notifications
You must be signed in to change notification settings - Fork 28
/
tf_optimize.py
52 lines (38 loc) · 1.45 KB
/
tf_optimize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
'''
Use TensorRT+TensorFlow to build a new TF graph with optimized TRT-based subgraphs
'''
import os
import tensorflow as tf
import tensorflow.contrib.tensorrt as trt
from tensorflow.python.framework import graph_io
__author__ = "Dmitry Korobchenko ([email protected])"

### Settings -- every value marked ADJUST is meant to be edited per model.

# Serialized GraphDef locations: the frozen input and the TRT-optimized output.
FROZEN_GDEF_PATH = 'data/frozen.pb'       # ADJUST
TRT_GDEF_PATH = 'data/frozen_trt.pb'      # ADJUST

# Output node name, passed to create_inference_graph() as `outputs`.
OUTPUT_NODE = 'net/fc8/BiasAdd'           # ADJUST

# TF-TRT build limits: max_batch_size and max_workspace_size_bytes
# (2 ** 32 bytes == 4 GiB).
MAX_BATCH_SIZE = 1                        # ADJUST
MAX_WORKSPACE = 2 ** 32                   # ADJUST

# precision_mode for the TRT engines: 'FP16' or 'FP32'.
DATA_TYPE = 'FP16'                        # ADJUST

# When True, also dump the optimized graph for viewing in TensorBoard.
EXPORT_FOR_TENSORBOARD = False            # ADJUST
### Load frozen TF graph
# Pull the raw protobuf bytes off disk (binary mode), then deserialize them
# into a fresh GraphDef message.
with tf.gfile.GFile(FROZEN_GDEF_PATH, "rb") as frozen_file:
    serialized_gdef = frozen_file.read()
graphdef_frozen = tf.GraphDef()
graphdef_frozen.ParseFromString(serialized_gdef)
### Build new graph with optimized TensorRT nodes
# Conversion options for TF-TRT; every value comes from the Settings section.
trt_conversion_params = dict(
    input_graph_def=graphdef_frozen,
    outputs=[OUTPUT_NODE],
    max_batch_size=MAX_BATCH_SIZE,
    max_workspace_size_bytes=MAX_WORKSPACE,
    precision_mode=DATA_TYPE,
)
# Returns a new GraphDef with TRT-based subgraphs substituted in.
graphdef_trt = trt.create_inference_graph(**trt_conversion_params)
### Save new TensorRT graph
# os.path.dirname() returns '' for a bare filename (e.g. 'frozen_trt.pb'),
# and os.makedirs('') raises FileNotFoundError -- fall back to the cwd.
trt_out_dir = os.path.dirname(TRT_GDEF_PATH) or '.'
os.makedirs(trt_out_dir, exist_ok=True)
# write_graph() joins logdir and name; os.path.join('./', ...) leaves both
# relative and absolute TRT_GDEF_PATH values pointing at the intended file.
graph_io.write_graph(graphdef_trt, './', TRT_GDEF_PATH, as_text=False)
### List the nodes of the TensorRT-optimized graph
print([node.name for node in graphdef_trt.node])
### Export new graph for visualization in Tensorboard
if EXPORT_FOR_TENSORBOARD:
    # Materialize the optimized GraphDef as a tf.Graph so FileWriter can log it.
    graph_trt = tf.Graph()
    with graph_trt.as_default():
        tf.import_graph_def(graphdef_trt)
    # FileWriter records the graph at construction, but it updates the event
    # file asynchronously; close() flushes pending events to disk and releases
    # the file so the graph is guaranteed to be visible in TensorBoard.
    tb_writer = tf.summary.FileWriter('data/checkpoints/vggA_BN_frozen_trt/', graph_trt)
    tb_writer.close()