
Commit 0272d6c

Author: Shashank Tyagi
Commit message: add SPN
1 parent df52965 commit 0272d6c

File tree

6 files changed: 440 additions, 0 deletions

Binary file not shown.

SPN/main.py

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
import tensorflow as tf
import os
from model import *

weights_path = '/Users/shashank/Tensorflow/SPN/weights/'
imgs_path = '/Users/shashank/Tensorflow/CSE252C-Hyperface/git/truth_data.npy'

# Create the log directory if needed, then clear event files left over from earlier runs.
if not os.path.exists('./logs'):
    os.makedirs('./logs')

map(os.unlink, (os.path.join('./logs', f) for f in os.listdir('./logs')))

with tf.Session() as sess:
    print 'Building Graph...'
    model = Network(sess)
    print 'Graph Built!'
    sess.run(tf.global_variables_initializer())
    model.train()
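weights_path and imgs_path are defined above but never used in this commit. Presumably they are meant to feed the load_weights and predict methods that model.py already provides; a minimal sketch of that wiring (the call order is an assumption, not part of the commit):

    with tf.Session() as sess:
        model = Network(sess)
        sess.run(tf.global_variables_initializer())
        model.load_weights(weights_path)   # expects per-layer folders of W.npy / b.npy under weights_path
        model.predict(imgs_path)           # expects an (N, 227, 227, 3) image array saved as .npy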

SPN/model.py

Lines changed: 200 additions & 0 deletions
@@ -0,0 +1,200 @@
import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np
from spatial_transformer import transformer
from tqdm import tqdm
from pdb import set_trace as brk


class Network(object):

    def __init__(self, sess):

        self.sess = sess
        self.batch_size = 2
        self.img_height = 227
        self.img_width = 227
        self.out_height = 200
        self.out_width = 200
        self.channel = 3

        self.num_epochs = 10

        # Hyperparameters: weights for the five task losses
        self.weight_detect = 1
        self.weight_landmarks = 5
        self.weight_visibility = 0.5
        self.weight_pose = 5
        self.weight_gender = 2

        self.build_network()

    def build_network(self):

        # Inputs: images plus the five HyperFace targets
        self.X = tf.placeholder(tf.float32, [self.batch_size, self.img_height, self.img_width, self.channel], name='images')
        self.detection = tf.placeholder(tf.float32, [self.batch_size, 2], name='detection')
        self.landmarks = tf.placeholder(tf.float32, [self.batch_size, 42], name='landmarks')
        self.visibility = tf.placeholder(tf.float32, [self.batch_size, 21], name='visibility')
        self.pose = tf.placeholder(tf.float32, [self.batch_size, 3], name='pose')
        self.gender = tf.placeholder(tf.float32, [self.batch_size, 2], name='gender')

        # The localization network predicts theta; the spatial transformer then warps
        # the 227x227 input down to a 200x200 crop before the HyperFace branch.
        theta = self.localization_network(self.X)
        T_mat = self.get_transformation_matrix(theta)
        cropped = transformer(self.X, T_mat, [self.out_height, self.out_width])

        net_output = self.hyperface(cropped)  # (out_detection, out_landmarks, out_visibility, out_pose, out_gender)

        loss_detection = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(net_output[0], self.detection))

        # Each visibility value is repeated twice so consecutive landmark coordinates
        # share a mask entry; occluded landmarks are excluded from the landmark loss.
        visibility_mask = tf.reshape(tf.tile(tf.expand_dims(self.visibility, axis=2), [1, 1, 2]), [self.batch_size, -1])
        loss_landmarks = tf.reduce_mean(tf.square(visibility_mask * (net_output[1] - self.landmarks)))

        loss_visibility = tf.reduce_mean(tf.square(net_output[2] - self.visibility))
        loss_pose = tf.reduce_mean(tf.square(net_output[3] - self.pose))
        loss_gender = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(net_output[4], self.gender))

        self.loss = self.weight_detect*loss_detection + self.weight_landmarks*loss_landmarks \
                    + self.weight_visibility*loss_visibility + self.weight_pose*loss_pose \
                    + self.weight_gender*loss_gender
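The landmark term above is masked so only visible landmarks are penalized. A small numpy sketch of what the expand_dims / tile / reshape chain produces (illustrative values, 3 landmarks instead of 21 for brevity):

    import numpy as np
    visibility = np.array([[1., 0., 1.]])                             # (batch, num_landmarks)
    mask = np.tile(visibility[:, :, None], [1, 1, 2]).reshape(1, -1)
    print(mask)                                                       # [[1. 1. 0. 0. 1. 1.]]

Each visibility entry is duplicated so it gates both coordinates of its landmark; landmarks marked invisible contribute nothing to loss_landmarks.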

    def get_transformation_matrix(self, theta):
        with tf.name_scope('T_matrix'):
            theta = tf.expand_dims(theta, 2)
            # Fixed (6 x 3) matrix that scatters the 3 predicted parameters into the
            # 6-element affine parameter vector consumed by the spatial transformer.
            mat = tf.constant(np.repeat(np.array([[[1,0,0],[0,0,0],[0,1,0],[0,0,0],[0,1,0],[0,0,1]]]),
                                        self.batch_size, axis=0), dtype=tf.float32)
            tr_matrix = tf.squeeze(tf.matmul(mat, theta))

        return tr_matrix
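With theta = (t0, t1, t2), the constant matrix above yields the 6-vector [t0, 0, t1, 0, t1, t2], which the spatial transformer (in the standard TensorFlow STN implementation) reshapes into the 2 x 3 affine matrix

    T = [[t0, 0,  t1],
         [0,  t1, t2]]

so t0 acts as the x-scale, t2 as the y-translation, and t1 ends up as both the x-translation and the y-scale. A quick numpy check of the multiplication (illustrative only):

    import numpy as np
    mat = np.array([[1,0,0],[0,0,0],[0,1,0],[0,0,0],[0,1,0],[0,0,1]], dtype=np.float32)
    theta = np.array([0.5, 0.1, -0.2], dtype=np.float32)
    print(mat.dot(theta))    # [ 0.5  0.   0.1  0.   0.1 -0.2]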

    def train(self):

        # So far only graph construction: the optimizer, the TensorBoard writer and a loss summary.
        optimizer = tf.train.AdamOptimizer().minimize(self.loss)

        writer = tf.summary.FileWriter('./logs', self.sess.graph)
        loss_summ = tf.summary.scalar('loss', self.loss)
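train() does not yet feed the placeholders or run any iterations. A minimal sketch of the loop it appears to be building toward, written as a hypothetical continuation inside train() (get_batch and the feed ordering are assumptions, not part of the commit):

    for epoch in range(self.num_epochs):
        images, det, lm, vis, pose, gender = get_batch(self.batch_size)   # hypothetical data pipeline
        _, loss_val, summ = self.sess.run(
            [optimizer, self.loss, loss_summ],
            feed_dict={self.X: images, self.detection: det, self.landmarks: lm,
                       self.visibility: vis, self.pose: pose, self.gender: gender})
        writer.add_summary(summ, epoch)
        print 'epoch %d, loss %f' % (epoch, loss_val)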

    def hyperface(self, inputs, reuse=False):

        if reuse:
            tf.get_variable_scope().reuse_variables()

        with slim.arg_scope([slim.conv2d, slim.fully_connected],
                            activation_fn=tf.nn.relu,
                            weights_initializer=tf.truncated_normal_initializer(0.0, 0.01)):

            conv1 = slim.conv2d(inputs, 96, [11,11], 4, padding='VALID', scope='conv1')
            max1 = slim.max_pool2d(conv1, [3,3], 2, padding='VALID', scope='max1')

            conv1a = slim.conv2d(max1, 256, [4,4], 4, padding='VALID', scope='conv1a')

            conv2 = slim.conv2d(max1, 256, [5,5], 1, scope='conv2')
            max2 = slim.max_pool2d(conv2, [3,3], 2, padding='VALID', scope='max2')
            conv3 = slim.conv2d(max2, 384, [3,3], 1, scope='conv3')

            conv3a = slim.conv2d(conv3, 256, [2,2], 2, padding='VALID', scope='conv3a')

            conv4 = slim.conv2d(conv3, 384, [3,3], 1, scope='conv4')
            conv5 = slim.conv2d(conv4, 256, [3,3], 1, scope='conv5')
            pool5 = slim.max_pool2d(conv5, [3,3], 2, padding='VALID', scope='pool5')

            # Multi-scale fusion: concatenate conv1a, conv3a and pool5 along the channel axis.
            concat_feat = tf.concat(3, [conv1a, conv3a, pool5])
            conv_all = slim.conv2d(concat_feat, 192, [1,1], 1, padding='VALID', scope='conv_all')

            shape = int(np.prod(conv_all.get_shape()[1:]))
            # transposed for weight loading from chainer model
            fc_full = slim.fully_connected(tf.reshape(tf.transpose(conv_all, [0,3,1,2]), [-1, shape]), 3072, scope='fc_full')

            fc_detection = slim.fully_connected(fc_full, 512, scope='fc_detection1')
            fc_landmarks = slim.fully_connected(fc_full, 512, scope='fc_landmarks1')
            fc_visibility = slim.fully_connected(fc_full, 512, scope='fc_visibility1')
            fc_pose = slim.fully_connected(fc_full, 512, scope='fc_pose1')
            fc_gender = slim.fully_connected(fc_full, 512, scope='fc_gender1')

            out_detection = slim.fully_connected(fc_detection, 2, scope='fc_detection2', activation_fn=None)
            out_landmarks = slim.fully_connected(fc_landmarks, 42, scope='fc_landmarks2', activation_fn=None)
            out_visibility = slim.fully_connected(fc_visibility, 21, scope='fc_visibility2', activation_fn=None)
            out_pose = slim.fully_connected(fc_pose, 3, scope='fc_pose2', activation_fn=None)
            out_gender = slim.fully_connected(fc_gender, 2, scope='fc_gender2', activation_fn=None)

        return [tf.nn.softmax(out_detection), out_landmarks, out_visibility, out_pose, tf.nn.softmax(out_gender)]
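The five heads mirror the placeholders declared in build_network: 2-way detection and 2-way gender (both softmaxed on return), 42 landmark values (21 points times x and y), 21 visibilities, and 3 pose values. A purely illustrative summary of those widths (the roll/pitch/yaw ordering is an assumption taken from the HyperFace paper, not from this code):

    head_dims = {
        'detection': 2,    # face / non-face, softmaxed
        'landmarks': 42,   # 21 points x (x, y)
        'visibility': 21,
        'pose': 3,         # roll, pitch, yaw (assumed ordering)
        'gender': 2,       # softmaxed
    }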

    def localization_network(self, inputs):  # VGG16-style backbone

        with tf.variable_scope('localization_network'):
            with slim.arg_scope([slim.conv2d, slim.fully_connected],
                                activation_fn=tf.nn.relu,
                                weights_initializer=tf.constant_initializer(0.0)):

                net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
                net = slim.max_pool2d(net, [2, 2], scope='pool1')
                net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2')
                net = slim.max_pool2d(net, [2, 2], scope='pool2')
                net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3')
                net = slim.max_pool2d(net, [2, 2], scope='pool3')
                net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4')
                net = slim.max_pool2d(net, [2, 2], scope='pool4')
                net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5')
                net = slim.max_pool2d(net, [2, 2], scope='pool5')
                shape = int(np.prod(net.get_shape()[1:]))

                net = slim.fully_connected(tf.reshape(net, [-1, shape]), 4096, scope='fc6')
                net = slim.fully_connected(net, 1024, scope='fc7')
                # Final layer predicts the 3 transformation parameters (theta).
                net = slim.fully_connected(net, 3, biases_initializer=tf.constant_initializer(1.0), scope='fc8')

        return net
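Note on initialization: every conv and fully connected layer here has its weights initialized to zero and fc8 has a bias initializer of 1.0, so with slim's default zero biases and ReLU activations all hidden units start at zero and the network initially predicts theta = (1, 1, 1) for every image. Via get_transformation_matrix that is the starting transform [[1, 0, 1], [0, 1, 1]] (unit scale, unit translation in the transformer's normalized coordinates); presumably these initializers are placeholders until pretrained VGG16 weights are loaded.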

    def predict(self, imgs_path):
        print 'Running inference...'
        np.set_printoptions(suppress=True)
        # Scale pixel values to roughly [-1, 1]
        imgs = (np.load(imgs_path) - 127.5) / 128.0
        shape = imgs.shape
        self.X = tf.placeholder(tf.float32, [shape[0], self.img_height, self.img_width, self.channel], name='images')
        pred = self.hyperface(self.X, reuse=True)

        net_preds = self.sess.run(pred, feed_dict={self.X: imgs})

        print net_preds[-1]
        import matplotlib.pyplot as plt
        plt.imshow(imgs[-1]); plt.show()

        brk()

    def load_weights(self, path):
        variables = slim.get_model_variables()
        print 'Loading weights...'
        for var in tqdm(variables):
            if ('conv' in var.name) and ('weights' in var.name):
                self.sess.run(var.assign(np.load(path+var.name.split('/')[0]+'/W.npy').transpose((2,3,1,0))))
            elif ('fc' in var.name) and ('weights' in var.name):
                self.sess.run(var.assign(np.load(path+var.name.split('/')[0]+'/W.npy').T))
            elif 'biases' in var.name:
                self.sess.run(var.assign(np.load(path+var.name.split('/')[0]+'/b.npy')))
        print 'Weights loaded!!'
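load_weights expects one folder per variable scope under path (e.g. conv1/W.npy and conv1/b.npy). The transposes convert from the Chainer layout referenced in hyperface: Chainer stores convolution kernels as (out_channels, in_channels, kh, kw) and linear weights as (out, in), while slim's variables are (kh, kw, in_channels, out_channels) and (in, out). A shape-only illustration with dummy values:

    import numpy as np
    chainer_conv1_W = np.zeros((96, 3, 11, 11))        # (out, in, kh, kw)
    tf_conv1_W = chainer_conv1_W.transpose((2, 3, 1, 0))
    print(tf_conv1_W.shape)                            # (11, 11, 3, 96)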
191+
192+
def print_variables(self):
193+
variables = slim.get_model_variables()
194+
print 'Model Variables:\n'
195+
for var in variables:
196+
print var.name, ' ', var.get_shape()
197+
198+
199+
200+

SPN/model.pyc

8.51 KB
Binary file not shown.
