Skip to content

Commit cc7991d

Browse files
Shashank TyagiShashank Tyagi
Shashank Tyagi
authored and
Shashank Tyagi
committed
add more models
1 parent 0272d6c commit cc7991d

File tree

1 file changed

+72
-11
lines changed

1 file changed

+72
-11
lines changed

with SPN/model.py

+72-11
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ def __init__(self, sess):
1111

1212
self.sess = sess
1313
self.batch_size = 2
14-
self.img_height = 227
15-
self.img_width = 227
16-
self.out_height = 200
17-
self.out_width = 200
14+
self.img_height = 500
15+
self.img_width = 500
16+
self.out_height = 227
17+
self.out_width = 227
1818
self.channel = 3
1919

2020
self.num_epochs = 10
@@ -39,13 +39,13 @@ def build_network(self):
3939
self.gender = tf.placeholder(tf.float32, [self.batch_size,2], name='gender')
4040

4141

42-
theta = self.localization_network(self.X)
42+
theta = self.localization_squeezenet(self.X)
4343

44-
T_mat = self.get_transformation_matrix(theta)
45-
46-
cropped = transformer(self.X, T_mat, [self.out_height, self.out_width])
44+
self.T_mat = tf.reshape(theta, [-1, 2,3])
45+
46+
self.cropped = transformer(self.X, self.T_mat, [self.out_height, self.out_width])
4747

48-
net_output = self.hyperface(cropped) # (out_detection, out_landmarks, out_visibility, out_pose, out_gender)
48+
net_output = self.hyperface(self.cropped) # (out_detection, out_landmarks, out_visibility, out_pose, out_gender)
4949

5050

5151
loss_detection = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(net_output[0], self.detection))
@@ -80,7 +80,9 @@ def train(self):
8080

8181
writer = tf.summary.FileWriter('./logs', self.sess.graph)
8282
loss_summ = tf.summary.scalar('loss', self.loss)
83+
img_summ = tf.summary.image('cropped_image', self.cropped)
8384

85+
print self.sess.run(self.T_mat, feed_dict={self.X: np.random.randn(self.batch_size, self.img_height, self.img_width, self.channel)})
8486

8587

8688

@@ -131,7 +133,7 @@ def hyperface(self,inputs, reuse = False):
131133

132134

133135

134-
def localization_network(self,inputs): #VGG16
136+
def localization_VGG16(self,inputs):
135137

136138
with tf.variable_scope('localization_network'):
137139
with slim.arg_scope([slim.conv2d, slim.fully_connected],
@@ -152,11 +154,70 @@ def localization_network(self,inputs): #VGG16
152154

153155
net = slim.fully_connected(tf.reshape(net, [-1, shape]), 4096, scope='fc6')
154156
net = slim.fully_connected(net, 1024, scope='fc7')
155-
net = slim.fully_connected(net, 3, biases_initializer = tf.constant_initializer(1.0) , scope='fc8')
157+
identity = np.array([[1., 0., 0.],
158+
[0., 1., 0.]])
159+
identity = identity.flatten()
160+
net = slim.fully_connected(net, 6, biases_initializer = tf.constant_initializer(identity) , scope='fc8')
156161

157162
return net
158163

159164

165+
def localization_squeezenet(self, inputs):
    """Regress spatial-transformer affine parameters with a SqueezeNet-style network.

    Args:
        inputs: image batch tensor fed to the first convolution.

    Returns:
        A tensor with 6 values per example (a flattened 2x3 affine matrix).
        At initialisation every example maps to the identity transform.
    """
    # Row-major 2x3 identity affine transform, flattened so it lines up with
    # the 6 output units of fc11 (same form as localization_VGG16 uses).
    identity = np.array([[1., 0., 0.],
                         [0., 1., 0.]]).flatten()

    with tf.variable_scope('localization_network'):
        # The convolutions need non-zero initial weights: with all-zero
        # weights (and zero biases) every ReLU output is exactly 0, no
        # gradient flows back, and the network never sees its input.
        with slim.arg_scope([slim.conv2d], activation_fn = tf.nn.relu,
                            padding = 'SAME',
                            weights_initializer = tf.truncated_normal_initializer(0.0, 0.01)):

            conv1 = slim.conv2d(inputs, 64, [3,3], 2, padding = 'VALID', scope='conv1')
            pool1 = slim.max_pool2d(conv1, [2,2], 2, scope='pool1')
            fire2 = self.fire_module(pool1, 16, 64, scope = 'fire2')
            fire3 = self.fire_module(fire2, 16, 64, scope = 'fire3', res_connection=True)
            fire4 = self.fire_module(fire3, 32, 128, scope = 'fire4')
            pool4 = slim.max_pool2d(fire4, [2,2], 2, scope='pool4')
            fire5 = self.fire_module(pool4, 32, 128, scope = 'fire5', res_connection=True)
            fire6 = self.fire_module(fire5, 48, 192, scope = 'fire6')
            fire7 = self.fire_module(fire6, 48, 192, scope = 'fire7', res_connection=True)
            fire8 = self.fire_module(fire7, 64, 256, scope = 'fire8')
            pool8 = slim.max_pool2d(fire8, [2,2], 2, scope='pool8')
            fire9 = self.fire_module(pool8, 64, 256, scope = 'fire9', res_connection=True)
            conv10 = slim.conv2d(fire9, 128, [1,1], 1, scope='conv10')

            # Flatten everything except the batch dimension for the regressor.
            shape = int(np.prod(conv10.get_shape()[1:]))
            flat = tf.reshape(conv10, [-1, shape])

            # The regression head is linear (slim.fully_connected defaults to
            # ReLU, which would forbid negative affine parameters). Zero
            # weights plus the identity bias make the initial output exactly
            # the identity transform while still letting the weights train.
            fc11 = slim.fully_connected(flat, 6,
                                        activation_fn = None,
                                        weights_initializer = tf.constant_initializer(0.0),
                                        biases_initializer = tf.constant_initializer(identity),
                                        scope='fc11')
    return fc11
189+
190+
191+
def fire_module(self, inputs, s_channels, e_channels, scope, res_connection = False):
    """Build one fire module: a squeeze layer followed by an expand layer, then ReLU.

    Args:
        inputs: input feature-map tensor.
        s_channels: number of output channels of the squeeze convolution.
        e_channels: number of output channels of each expand branch.
        scope: variable scope name for the module.
        res_connection: when True, `inputs` is added to the expand output
            before the ReLU.

    Returns:
        The ReLU-activated output tensor of the module.
    """
    with tf.variable_scope(scope):
        squeezed = self.squeeze(inputs, s_channels, 'squeeze')
        expanded = self.expand(squeezed, e_channels, 'expand')
        # Optional bypass: sum the module input with the expand output.
        pre_activation = tf.add(inputs, expanded) if res_connection else expanded
        return tf.nn.relu(pre_activation)
200+
201+
202+
def squeeze(self, inputs, channels, scope):
    """Apply the 1x1 squeeze convolution of a fire module.

    Args:
        inputs: input feature-map tensor.
        channels: number of output channels.
        scope: scope name given to the convolution.

    Returns:
        The convolution output; no activation is applied here.
    """
    # Explicit keyword arguments override any enclosing slim.arg_scope,
    # exactly as a nested arg_scope would.
    return slim.conv2d(inputs, channels, [1,1], 1,
                       padding = 'SAME',
                       activation_fn = None,
                       weights_initializer = tf.truncated_normal_initializer(0.0, 0.01),
                       scope = scope)
208+
209+
def expand(self, inputs, channels, scope):
    """Apply the expand stage of a fire module.

    Runs a 1x1 and a 3x3 convolution on `inputs` and concatenates the two
    results along dimension 3.

    Args:
        inputs: input feature-map tensor (the squeeze output).
        channels: number of output channels of each of the two branches.
        scope: variable scope name for the expand stage.

    Returns:
        The concatenated branch outputs; no activation is applied here.
    """
    # Settings shared by both branches; passing them explicitly overrides
    # any enclosing slim.arg_scope.
    conv_settings = dict(activation_fn = None,
                         padding = 'SAME',
                         weights_initializer = tf.truncated_normal_initializer(0.0, 0.01))
    with tf.variable_scope(scope):
        branch_1x1 = slim.conv2d(inputs, channels, [1,1], 1, scope='e1x1', **conv_settings)
        branch_3x3 = slim.conv2d(inputs, channels, [3,3], 1, scope='e3x3', **conv_settings)
        # Argument order (dimension first, then values) is the one this file already uses.
        return tf.concat(3, [branch_1x1, branch_3x3])
219+
220+
160221

161222
def predict(self, imgs_path):
162223
print 'Running inference...'

0 commit comments

Comments
 (0)