'''
Optimize a frozen TensorFlow graph and prepare a stand-alone TensorRT
inference engine (via the UFF intermediate format).
'''
__author__ = "Dmitry Korobchenko ([email protected])"

import tensorrt as trt
import uff

### Settings
FROZEN_GDEF_PATH = 'data/frozen.pb'  # ADJUST: path to the frozen TF GraphDef
ENGINE_PATH = 'data/engine.plan'     # ADJUST: where to write the serialized engine
INPUT_NODE = 'net/input'             # ADJUST: name of the input tensor
OUTPUT_NODE = 'net/fc8/BiasAdd'      # ADJUST: name of the output tensor
INPUT_SIZE = [3, 224, 224]           # ADJUST: CHW input shape (batch dim excluded)
MAX_BATCH_SIZE = 1                   # ADJUST: max batch size the engine will serve
MAX_WORKSPACE = 1 << 32              # ADJUST: builder workspace in bytes (4 GiB)
DATA_TYPE = trt.float16              # ADJUST: trt.float16 | trt.float32

### Convert TF frozen graph to UFF graph
uff_model = uff.from_tensorflow_frozen_model(FROZEN_GDEF_PATH, [OUTPUT_NODE])
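# Note (assumption, not in the original script): from_tensorflow_frozen_model()
# also accepts an output_filename keyword if you want to dump the UFF graph
# to disk for inspection, e.g.:
#   uff.from_tensorflow_frozen_model(FROZEN_GDEF_PATH, [OUTPUT_NODE],
#                                    output_filename='data/model.uff')
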
### Create TRT model builder
trt_logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(trt_logger)
builder.max_batch_size = MAX_BATCH_SIZE
builder.max_workspace_size = MAX_WORKSPACE
builder.fp16_mode = (DATA_TYPE == trt.float16)
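# Note: max_batch_size / max_workspace_size / fp16_mode on the builder are the
# legacy (pre-TensorRT 8) implicit-batch API. A rough sketch of the equivalent
# setup on newer releases, using a builder config:
#   config = builder.create_builder_config()
#   config.max_workspace_size = MAX_WORKSPACE
#   if DATA_TYPE == trt.float16:
#       config.set_flag(trt.BuilderFlag.FP16)
#   engine = builder.build_engine(network, config)
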
### Create UFF parser
parser = trt.UffParser()
parser.register_input(INPUT_NODE, INPUT_SIZE)
parser.register_output(OUTPUT_NODE)
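# register_input() assumes CHW ordering (trt.UffInputOrder.NCHW) by default;
# pass trt.UffInputOrder.NHWC as a third argument for HWC graphs.
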
### Parse UFF graph
network = builder.create_network()
if not parser.parse_buffer(uff_model, network):
    raise RuntimeError('Failed to parse the UFF graph; check the node names above')

### Build optimized inference engine
engine = builder.build_cuda_engine(network)
if engine is None:
    raise RuntimeError('Failed to build the TensorRT engine')

### Save inference engine
with open(ENGINE_PATH, "wb") as f:
    f.write(engine.serialize())
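
### Usage sketch (assumption, not part of the original script): the saved
### plan can later be reloaded for inference along these lines:
#   with open(ENGINE_PATH, 'rb') as f:
#       engine = trt.Runtime(trt_logger).deserialize_cuda_engine(f.read())
#   context = engine.create_execution_context()
#   # ... allocate input/output device buffers (e.g. with pycuda) and call
#   # context.execute(MAX_BATCH_SIZE, bindings) to run inference ...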