Skip to content

Commit d1ef5fa

Browse files
committed
add graph and trt conversion
1 parent d9431b5 commit d1ef5fa

3 files changed

+29
-0
lines changed

config/convert_tf_to_trt.py

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Import TensorFlow and TensorRT
2+
import tensorflow as tf
3+
import tensorflow.contrib.tensorrt as trt
4+
# Inference with TF-TRT frozen graph workflow:
5+
6+
import sys, os
7+
8+
graph_name = sys.argv[1]
9+
10+
graph = tf.Graph()
11+
with graph.as_default():
12+
with tf.Session() as sess:
13+
# First deserialize your frozen graph:
14+
with tf.gfile.GFile(sys.argv[1], 'rb') as f:
15+
graph_def = tf.GraphDef()
16+
graph_def.ParseFromString(f.read())
17+
# Now you can create a TensorRT inference graph from your
18+
# frozen graph:
19+
trt_graph = trt.create_inference_graph(
20+
input_graph_def=graph_def,
21+
outputs=['network/output/Argmax', 'network/upscore_8s/upscore8/upscore8/BiasAdd'],
22+
max_batch_size=1,
23+
max_workspace_size_bytes=2500000000,
24+
precision_mode='FP16')
25+
# Import the TensorRT graph into a new graph and run:
26+
output_node = tf.import_graph_def(
27+
trt_graph,
28+
return_elements=['network/output/Argmax', 'network/upscore_8s/upscore8/upscore8/BiasAdd' ])
29+
sess.run(output_node)
13.2 MB
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)