1
1
import tensorflow as tf
2
2
import tensorflow .contrib .slim as slim
3
3
import numpy as np
4
+ from tqdm import tqdm
5
+ from pdb import set_trace as brk
4
6
5
7
6
8
class HyperFace (object ):
@@ -62,7 +64,10 @@ def train(self):
62
64
writer = tf .summary .FileWriter ('./logs' , self .sess .graph )
63
65
loss_summ = tf .summary .scalar ('loss' , self .loss )
64
66
65
- def network (self ,inputs ):
67
+ def network (self ,inputs ,reuse = False ):
68
+
69
+ if reuse :
70
+ tf .get_variable_scope ().reuse_variables ()
66
71
67
72
with slim .arg_scope ([slim .conv2d , slim .fully_connected ],
68
73
activation_fn = tf .nn .relu ,
@@ -87,21 +92,40 @@ def network(self,inputs):
87
92
conv_all = slim .conv2d (concat_feat , 192 , [1 ,1 ], 1 , padding = 'VALID' , scope = 'conv_all' )
88
93
89
94
shape = int (np .prod (conv_all .get_shape ()[1 :]))
90
- fc_full = slim .fully_connected (tf .reshape (conv_all , [- 1 , shape ]), 3072 , scope = 'fc_full' )
95
+ fc_full = slim .fully_connected (tf .reshape (tf .transpose (conv_all , [0 ,3 ,1 ,2 ]), [- 1 , shape ]), 3072 , scope = 'fc_full' )
96
+
97
+ fc_detection = slim .fully_connected (fc_full , 512 , scope = 'fc_detection1' )
98
+ fc_landmarks = slim .fully_connected (fc_full , 512 , scope = 'fc_landmarks1' )
99
+ fc_visibility = slim .fully_connected (fc_full , 512 , scope = 'fc_visibility1' )
100
+ fc_pose = slim .fully_connected (fc_full , 512 , scope = 'fc_pose1' )
101
+ fc_gender = slim .fully_connected (fc_full , 512 , scope = 'fc_gender1' )
102
+
103
+ out_detection = slim .fully_connected (fc_detection , 2 , scope = 'fc_detection2' , activation_fn = None )
104
+ out_landmarks = slim .fully_connected (fc_landmarks , 42 , scope = 'fc_landmarks2' , activation_fn = None )
105
+ out_visibility = slim .fully_connected (fc_visibility , 21 , scope = 'fc_visibility2' , activation_fn = None )
106
+ out_pose = slim .fully_connected (fc_pose , 3 , scope = 'fc_pose2' , activation_fn = None )
107
+ out_gender = slim .fully_connected (fc_gender , 2 , scope = 'fc_gender2' , activation_fn = None )
108
+
109
+ return [tf .nn .softmax (out_detection ), out_landmarks , out_visibility , out_pose , tf .nn .softmax (out_gender )]
110
+
91
111
92
- fc_detection = slim .fully_connected (fc_full , 512 , scope = 'fc_detection' )
93
- fc_landmarks = slim .fully_connected (fc_full , 512 , scope = 'fc_landmarks' )
94
- fc_visibility = slim .fully_connected (fc_full , 512 , scope = 'fc_visibility' )
95
- fc_pose = slim .fully_connected (fc_full , 512 , scope = 'fc_pose' )
96
- fc_gender = slim .fully_connected (fc_full , 512 , scope = 'fc_gender' )
97
112
98
- out_detection = slim .fully_connected (fc_detection , 2 , scope = 'out_detection' )
99
- out_landmarks = slim .fully_connected (fc_landmarks , 42 , scope = 'out_landmarks' )
100
- out_visibility = slim .fully_connected (fc_visibility , 21 , scope = 'out_visibility' )
101
- out_pose = slim .fully_connected (fc_pose , 3 , scope = 'out_pose' )
102
- out_gender = slim .fully_connected (fc_gender , 2 , scope = 'out_gender' )
113
+ def predict (self , imgs_path ):
114
+ print 'Running inference...'
115
+ np .set_printoptions (suppress = True )
116
+ imgs = (np .load (imgs_path ) - 127.5 )/ 128.0
117
+ shape = imgs .shape
118
+ self .X = tf .placeholder (tf .float32 , [shape [0 ], self .img_height , self .img_width , self .channel ], name = 'images' )
119
+ pred = self .network (self .X , reuse = True )
120
+
121
+ net_preds = self .sess .run (pred , feed_dict = {self .X : imgs })
122
+
123
+ print 'gender: \n ' , net_preds [- 1 ]
124
+ import matplotlib .pyplot as plt
125
+ plt .imshow (imgs [- 1 ]);plt .show ()
126
+
127
+ brk ()
103
128
104
- return [out_detection , out_landmarks , out_visibility , out_pose , out_gender ]
105
129
106
130
def load_from_tfRecord (self ,filename_queue ):
107
131
@@ -129,6 +153,18 @@ def load_from_tfRecord(self,filename_queue):
129
153
130
154
return images
131
155
156
+ def load_weights (self , path ):
157
+ variables = slim .get_model_variables ()
158
+ print 'Loading weights...'
159
+ for var in tqdm (variables ):
160
+ if ('conv' in var .name ) and ('weights' in var .name ):
161
+ self .sess .run (var .assign (np .load (path + var .name .split ('/' )[0 ]+ '/W.npy' ).transpose ((2 ,3 ,1 ,0 ))))
162
+ elif ('fc' in var .name ) and ('weights' in var .name ):
163
+ self .sess .run (var .assign (np .load (path + var .name .split ('/' )[0 ]+ '/W.npy' ).T ))
164
+ elif 'biases' in var .name :
165
+ self .sess .run (var .assign (np .load (path + var .name .split ('/' )[0 ]+ '/b.npy' )))
166
+ print 'Weights loaded!!'
167
+
132
168
def print_variables (self ):
133
169
variables = slim .get_model_variables ()
134
170
print 'Model Variables:'
0 commit comments