Skip to content

Commit 2760af1

Browse files
authored
Merge pull request #57 from arcelien/master
Add freezing functionality
2 parents f3aed74 + 843d811 commit 2760af1

File tree

5 files changed

+150
-5
lines changed

5 files changed

+150
-5
lines changed

freeze.py

+143
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
#!/usr/bin/env python3
2+
from argparse import ArgumentParser
3+
from importlib import import_module
4+
import time
5+
6+
import numpy as np
7+
import tensorflow as tf
8+
from tensorflow.python.framework import graph_io
9+
10+
import common
11+
from nets import NET_CHOICES
12+
from heads import HEAD_CHOICES
13+
14+
15+
parser = ArgumentParser(description='Train a ReID network.')
16+
17+
parser.add_argument(
18+
'--checkpoint_name', default='market1501_weights/checkpoint-25000', type=common.readable_directory,
19+
help='Location of checkpoint to freeze.')
20+
21+
parser.add_argument(
22+
'--frozen_model_path', default='./encoder_trinet.pb', type=common.writeable_directory,
23+
help='Location to save or load frozen model.')
24+
25+
parser.add_argument(
26+
'--model_name', default='resnet_v1_50', choices=NET_CHOICES,
27+
help='Name of the model to use.')
28+
29+
parser.add_argument(
30+
'--head_name', default='fc1024_normalize', choices=HEAD_CHOICES,
31+
help='Name of the head to use.')
32+
33+
parser.add_argument(
34+
'--embedding_dim', default=128, type=common.positive_int,
35+
help='Dimensionality of the embedding space.')
36+
37+
parser.add_argument(
38+
'--net_input_height', default=256, type=common.positive_int,
39+
help='Height of the input directly fed into the network.')
40+
41+
parser.add_argument(
42+
'--net_input_width', default=128, type=common.positive_int,
43+
help='Width of the input directly fed into the network.')
44+
45+
parser.add_argument(
46+
'--save_graph', action='store_true', default=False,
47+
help='Whether to save frozen graph for visualization.')
48+
49+
parser.add_argument(
50+
'--load', action='store_true', default=False,
51+
help='Whether to load frozen model after saving and benchmark.')
52+
53+
parser.add_argument(
54+
'--batch_size', default=16, type=common.positive_int,
55+
help='Batch size of dummy data input.')
56+
57+
parser.add_argument(
58+
'--runs', default=100, type=common.positive_int,
59+
help='Number of passes through the network to check speed.')
60+
61+
62+
def save(args):
63+
"""
64+
Freezes a model checkpoint into a tensorflow pb file.
65+
Default parameters assume using provided tensorflow checkpoint extracted in root directory.
66+
Input node name: "input"
67+
Output node name: "head/out_emb"
68+
"""
69+
images = tf.placeholder(tf.float32, shape=(
70+
None, args.net_input_height, args.net_input_width, 3), name='input')
71+
72+
model = import_module('nets.' + args.model_name)
73+
head = import_module('heads.' + args.head_name)
74+
75+
endpoints, body_prefix = model.endpoints(images, is_training=False)
76+
with tf.name_scope('head'):
77+
endpoints = head.head(endpoints, args.embedding_dim, is_training=False)
78+
79+
with tf.Session() as sess:
80+
tf.train.Saver().restore(sess, args.checkpoint_name)
81+
output_node_names = ['head/out_emb']
82+
83+
if args.save_graph:
84+
summary_writer = tf.summary.FileWriter(logdir='./logs/')
85+
summary_writer.add_graph(graph=sess.graph)
86+
print('saved graph')
87+
88+
output_graph_def = tf.graph_util.convert_variables_to_constants(
89+
sess,
90+
tf.get_default_graph().as_graph_def(),
91+
output_node_names
92+
)
93+
with tf.gfile.GFile(args.frozen_model_path, 'wb') as f:
94+
f.write(output_graph_def.SerializeToString())
95+
print('{} ops in the frozen graph.'.format(len(output_graph_def.node)))
96+
97+
98+
def load(args):
99+
"""
100+
Check that a frozen model can be loaded correctly.
101+
Runs speed and memory benchmark.
102+
"""
103+
# check memory usage of model with session config
104+
config = tf.ConfigProto()
105+
# config.gpu_options.per_process_gpu_memory_fraction = 0.1
106+
config.gpu_options.allow_growth = True
107+
108+
with tf.Session(graph=tf.Graph(), config=config) as sess:
109+
output_graph_def = tf.GraphDef()
110+
with open(args.frozen_model_path, "rb") as f:
111+
output_graph_def.ParseFromString(f.read())
112+
tf.import_graph_def(output_graph_def, name='')
113+
print('{} ops in the frozen graph.'.format(len(output_graph_def.node)))
114+
115+
in_img = sess.graph.get_tensor_by_name('input:0')
116+
emb = sess.graph.get_tensor_by_name('head/out_emb:0')
117+
118+
# benchmark speed with given batch_size
119+
img_data = np.zeros(
120+
(args.batch_size, args.net_input_height, args.net_input_width, 3))
121+
t = time.time()
122+
total_time = 0
123+
for i in range(args.runs):
124+
_ = sess.run(emb, feed_dict={in_img: img_data})
125+
took = time.time() - t
126+
total_time += took
127+
print('runs per second: {:.2f}, time per run: {:.5f}'.format(
128+
1/took, took))
129+
t = time.time()
130+
print('averaged runs per second: {:.2f}, averaged time per run: {:.5f}'.format(
131+
args.runs/total_time, total_time/args.runs))
132+
133+
134+
def main():
135+
args = parser.parse_args()
136+
if not args.load:
137+
save(args)
138+
else:
139+
load(args)
140+
141+
142+
if __name__ == '__main__':
143+
main()

heads/direct.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
from tensorflow.contrib import slim
33

44
def head(endpoints, embedding_dim, is_training):
5-
endpoints['emb'] = endpoints['emb_raw'] = slim.fully_connected(
5+
endpoints['emb_raw'] = slim.fully_connected(
66
endpoints['model_output'], embedding_dim, activation_fn=None,
77
weights_initializer=tf.orthogonal_initializer(), scope='emb')
8+
endpoints['emb'] = tf.identity(endpoints['emb_raw'], name="out_emb")
89

910
return endpoints

heads/direct_normalize.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@ def head(endpoints, embedding_dim, is_training):
55
endpoints['emb_raw'] = slim.fully_connected(
66
endpoints['model_output'], embedding_dim, activation_fn=None,
77
weights_initializer=tf.orthogonal_initializer(), scope='emb')
8-
endpoints['emb'] = tf.nn.l2_normalize(endpoints['emb_raw'], -1)
8+
endpoints['emb'] = tf.nn.l2_normalize(endpoints['emb_raw'], -1, name="out_emb")
99

1010
return endpoints

heads/fc1024.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@ def head(endpoints, embedding_dim, is_training):
1212
'updates_collections': tf.GraphKeys.UPDATE_OPS,
1313
})
1414

15-
endpoints['emb'] = endpoints['emb_raw'] = slim.fully_connected(
15+
endpoints['emb_raw'] = slim.fully_connected(
1616
endpoints['head_output'], embedding_dim, activation_fn=None,
1717
weights_initializer=tf.orthogonal_initializer(), scope='emb')
18-
18+
endpoints['emb'] = tf.identity(endpoints['emb_raw'], name="out_emb")
19+
1920
return endpoints

heads/fc1024_normalize.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,6 @@ def head(endpoints, embedding_dim, is_training):
1515
endpoints['emb_raw'] = slim.fully_connected(
1616
endpoints['head_output'], embedding_dim, activation_fn=None,
1717
weights_initializer=tf.orthogonal_initializer(), scope='emb')
18-
endpoints['emb'] = tf.nn.l2_normalize(endpoints['emb_raw'], -1)
18+
endpoints['emb'] = tf.nn.l2_normalize(endpoints['emb_raw'], -1, name="out_emb")
1919

2020
return endpoints

0 commit comments

Comments
 (0)