import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np
from spatial_transformer import transformer
from tqdm import tqdm
from pdb import set_trace as brk


class Network(object):

	def __init__(self, sess):

		self.sess = sess
		self.batch_size = 2
		self.img_height = 227
		self.img_width = 227
		self.out_height = 200
		self.out_width = 200
		self.channel = 3

		self.num_epochs = 10

		# Hyperparameters
		self.weight_detect = 1
		self.weight_landmarks = 5
		self.weight_visibility = 0.5
		self.weight_pose = 5
		self.weight_gender = 2

		self.build_network()

	def build_network(self):

		self.X = tf.placeholder(tf.float32, [self.batch_size, self.img_height, self.img_width, self.channel], name='images')
		self.detection = tf.placeholder(tf.float32, [self.batch_size, 2], name='detection')
		self.landmarks = tf.placeholder(tf.float32, [self.batch_size, 42], name='landmarks')
		self.visibility = tf.placeholder(tf.float32, [self.batch_size, 21], name='visibility')
		self.pose = tf.placeholder(tf.float32, [self.batch_size, 3], name='pose')
		self.gender = tf.placeholder(tf.float32, [self.batch_size, 2], name='gender')

		# Predict a 3-parameter transform from the full image, expand it to the
		# 6-parameter affine matrix and crop the face region with a spatial transformer.
		theta = self.localization_network(self.X)
		T_mat = self.get_transformation_matrix(theta)
		cropped = transformer(self.X, T_mat, [self.out_height, self.out_width])

		net_output = self.hyperface(cropped)  # (out_detection, out_landmarks, out_visibility, out_pose, out_gender)

		loss_detection = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=net_output[0], labels=self.detection))

		# Repeat each visibility flag for the (x, y) pair of its landmark so that
		# only visible landmarks contribute to the landmark loss.
		visibility_mask = tf.reshape(tf.tile(tf.expand_dims(self.visibility, axis=2), [1, 1, 2]), [self.batch_size, -1])
		loss_landmarks = tf.reduce_mean(tf.square(visibility_mask * (net_output[1] - self.landmarks)))

		loss_visibility = tf.reduce_mean(tf.square(net_output[2] - self.visibility))
		loss_pose = tf.reduce_mean(tf.square(net_output[3] - self.pose))
		loss_gender = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=net_output[4], labels=self.gender))

		# Weighted multi-task loss over the five outputs.
		self.loss = self.weight_detect*loss_detection + self.weight_landmarks*loss_landmarks \
			+ self.weight_visibility*loss_visibility + self.weight_pose*loss_pose \
			+ self.weight_gender*loss_gender

	def get_transformation_matrix(self, theta):
		# Map the 3 predicted parameters [s, tx, ty] to the 6-parameter affine
		# transform [s, 0, tx, 0, s, ty] expected by the spatial transformer,
		# e.g. theta = [0.8, 0.1, -0.2] -> [0.8, 0, 0.1, 0, 0.8, -0.2].
		with tf.name_scope('T_matrix'):
			theta = tf.expand_dims(theta, 2)
			mat = tf.constant(np.repeat(np.array([[[1,0,0],[0,0,0],[0,1,0],[0,0,0],[1,0,0],[0,0,1]]]),
				self.batch_size, axis=0), dtype=tf.float32)
			tr_matrix = tf.squeeze(tf.matmul(mat, theta))

		return tr_matrix

	def train(self):

		optimizer = tf.train.AdamOptimizer().minimize(self.loss)

		writer = tf.summary.FileWriter('./logs', self.sess.graph)
		loss_summ = tf.summary.scalar('loss', self.loss)

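		# A possible continuation of train() (sketch only): the data pipeline is not
		# part of this file, so `next_batch()` below is a hypothetical helper that
		# would return arrays matching the placeholders defined in build_network.
		#
		# self.sess.run(tf.global_variables_initializer())
		# for epoch in range(self.num_epochs):
		# 	imgs, det, lmk, vis, pose, gen = next_batch(self.batch_size)
		# 	_, summ = self.sess.run([optimizer, loss_summ],
		# 		feed_dict={self.X: imgs, self.detection: det, self.landmarks: lmk,
		# 			self.visibility: vis, self.pose: pose, self.gender: gen})
		# 	writer.add_summary(summ, epoch)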

	def hyperface(self, inputs, reuse=False):

		if reuse:
			tf.get_variable_scope().reuse_variables()

		with slim.arg_scope([slim.conv2d, slim.fully_connected],
							activation_fn=tf.nn.relu,
							weights_initializer=tf.truncated_normal_initializer(0.0, 0.01)):

			conv1 = slim.conv2d(inputs, 96, [11,11], 4, padding='VALID', scope='conv1')
			max1 = slim.max_pool2d(conv1, [3,3], 2, padding='VALID', scope='max1')

			conv1a = slim.conv2d(max1, 256, [4,4], 4, padding='VALID', scope='conv1a')

			conv2 = slim.conv2d(max1, 256, [5,5], 1, scope='conv2')
			max2 = slim.max_pool2d(conv2, [3,3], 2, padding='VALID', scope='max2')
			conv3 = slim.conv2d(max2, 384, [3,3], 1, scope='conv3')

			conv3a = slim.conv2d(conv3, 256, [2,2], 2, padding='VALID', scope='conv3a')

			conv4 = slim.conv2d(conv3, 384, [3,3], 1, scope='conv4')
			conv5 = slim.conv2d(conv4, 256, [3,3], 1, scope='conv5')
			pool5 = slim.max_pool2d(conv5, [3,3], 2, padding='VALID', scope='pool5')

			# Fuse low-, mid- and high-level features before the shared fully connected layer.
			concat_feat = tf.concat([conv1a, conv3a, pool5], 3)
			conv_all = slim.conv2d(concat_feat, 192, [1,1], 1, padding='VALID', scope='conv_all')

			shape = int(np.prod(conv_all.get_shape()[1:]))
			# Transposed to NCHW before flattening so that weights converted from the Chainer model load correctly.
			fc_full = slim.fully_connected(tf.reshape(tf.transpose(conv_all, [0,3,1,2]), [-1, shape]), 3072, scope='fc_full')

			fc_detection = slim.fully_connected(fc_full, 512, scope='fc_detection1')
			fc_landmarks = slim.fully_connected(fc_full, 512, scope='fc_landmarks1')
			fc_visibility = slim.fully_connected(fc_full, 512, scope='fc_visibility1')
			fc_pose = slim.fully_connected(fc_full, 512, scope='fc_pose1')
			fc_gender = slim.fully_connected(fc_full, 512, scope='fc_gender1')

			out_detection = slim.fully_connected(fc_detection, 2, scope='fc_detection2', activation_fn=None)
			out_landmarks = slim.fully_connected(fc_landmarks, 42, scope='fc_landmarks2', activation_fn=None)
			out_visibility = slim.fully_connected(fc_visibility, 21, scope='fc_visibility2', activation_fn=None)
			out_pose = slim.fully_connected(fc_pose, 3, scope='fc_pose2', activation_fn=None)
			out_gender = slim.fully_connected(fc_gender, 2, scope='fc_gender2', activation_fn=None)

		# Detection and gender are returned as raw logits: the cross-entropy losses in
		# build_network expect logits, and tf.nn.softmax is applied downstream where
		# class probabilities are needed.
		return [out_detection, out_landmarks, out_visibility, out_pose, out_gender]


	def localization_network(self, inputs):
		# VGG-16 style localization network that regresses the 3 transform parameters.
		# All weights are zero-initialised, so the initial transform is determined
		# entirely by the fc8 biases.
		with tf.variable_scope('localization_network'):
			with slim.arg_scope([slim.conv2d, slim.fully_connected],
								activation_fn=tf.nn.relu,
								weights_initializer=tf.constant_initializer(0.0)):

				net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
				net = slim.max_pool2d(net, [2, 2], scope='pool1')
				net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2')
				net = slim.max_pool2d(net, [2, 2], scope='pool2')
				net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3')
				net = slim.max_pool2d(net, [2, 2], scope='pool3')
				net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4')
				net = slim.max_pool2d(net, [2, 2], scope='pool4')
				net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5')
				net = slim.max_pool2d(net, [2, 2], scope='pool5')

				shape = int(np.prod(net.get_shape()[1:]))
				net = slim.fully_connected(tf.reshape(net, [-1, shape]), 4096, scope='fc6')
				net = slim.fully_connected(net, 1024, scope='fc7')
				net = slim.fully_connected(net, 3, biases_initializer=tf.constant_initializer(1.0), scope='fc8')

		return net

	def predict(self, imgs_path):
		print('Running inference...')
		np.set_printoptions(suppress=True)
		imgs = (np.load(imgs_path) - 127.5) / 128.0
		shape = imgs.shape
		self.X = tf.placeholder(tf.float32, [shape[0], self.img_height, self.img_width, self.channel], name='images')
		pred = self.hyperface(self.X, reuse=True)
		# Convert the detection and gender logits to probabilities for display.
		pred[0] = tf.nn.softmax(pred[0])
		pred[-1] = tf.nn.softmax(pred[-1])

		net_preds = self.sess.run(pred, feed_dict={self.X: imgs})

		print(net_preds[-1])
		import matplotlib.pyplot as plt
		plt.imshow(imgs[-1])
		plt.show()

		brk()  # drop into pdb for inspection

	def load_weights(self, path):
		# Load weights converted from the Chainer model. `path` is assumed to contain
		# one directory per layer scope (e.g. conv1/, fc_full/), each holding W.npy
		# and b.npy; conv kernels are transposed from Chainer's (out, in, h, w)
		# layout to TensorFlow's (h, w, in, out).
		variables = slim.get_model_variables()
		print('Loading weights...')
		for var in tqdm(variables):
			if ('conv' in var.name) and ('weights' in var.name):
				self.sess.run(var.assign(np.load(path + var.name.split('/')[0] + '/W.npy').transpose((2,3,1,0))))
			elif ('fc' in var.name) and ('weights' in var.name):
				self.sess.run(var.assign(np.load(path + var.name.split('/')[0] + '/W.npy').T))
			elif 'biases' in var.name:
				self.sess.run(var.assign(np.load(path + var.name.split('/')[0] + '/b.npy')))
		print('Weights loaded!!')

	def print_variables(self):
		variables = slim.get_model_variables()
		print('Model Variables:\n')
		for var in variables:
			print(var.name, var.get_shape())

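
# Example usage (a minimal sketch, not part of the original file). The weight
# directory and image file below are placeholders: weights are expected as
# <scope>/W.npy and <scope>/b.npy per layer, images as an (N, 227, 227, 3) array.
#
# with tf.Session() as sess:
# 	net = Network(sess)
# 	sess.run(tf.global_variables_initializer())
# 	net.print_variables()
# 	net.load_weights('./weights/')
# 	net.predict('./test_images.npy')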