5
5
6
6
class HyperFace (object ):
7
7
8
- def __init__ (self , sess ):
8
+ def __init__ (self , sess , tf_record_file_path = None ):
9
9
10
10
self .sess = sess
11
11
self .batch_size = 2
@@ -22,6 +22,13 @@ def __init__(self, sess):
22
22
self .weight_pose = 5
23
23
self .weight_gender = 2
24
24
25
+ #tf_Record Paramters
26
+ self .filename_queue = tf .train .string_input_producer ([tf_record_file_path ], num_epochs = self .num_epochs )
27
+
28
+ #Spatial Transformer Input
29
+ self .sp_input_width = 500
30
+ self .sp_input_height = 500
31
+
25
32
self .build_network ()
26
33
27
34
@@ -33,7 +40,9 @@ def build_network(self):
33
40
self .visibility = tf .placeholder (tf .float32 , [self .batch_size ,21 ], name = 'visibility' )
34
41
self .pose = tf .placeholder (tf .float32 , [self .batch_size ,3 ], name = 'pose' )
35
42
self .gender = tf .placeholder (tf .float32 , [self .batch_size ,2 ], name = 'gender' )
36
-
43
+
44
+ self .X = self .load_from_tfRecord (self .filename_queue )
45
+
37
46
net_output = self .network (self .X ) # (out_detection, out_landmarks, out_visibility, out_pose, out_gender)
38
47
39
48
loss_detection = tf .reduce_mean (tf .nn .sigmoid_cross_entropy_with_logits (net_output [0 ], self .detection ))
@@ -57,10 +66,6 @@ def train(self):
57
66
writer = tf .summary .FileWriter ('./logs' , self .sess .graph )
58
67
loss_summ = tf .summary .scalar ('loss' , self .loss )
59
68
60
-
61
-
62
-
63
-
64
69
def network (self ,inputs ):
65
70
66
71
with slim .arg_scope ([slim .conv2d , slim .fully_connected ],
@@ -102,6 +107,32 @@ def network(self,inputs):
102
107
103
108
return [out_detection , out_landmarks , out_visibility , out_pose , out_gender ]
104
109
110
+ def load_from_tfRecord (self ,filename_queue ):
111
+
112
+ reader = tf .TFRecordReader ()
113
+ _ , serialized_example = reader .read (filename_queue )
114
+
115
+ features = tf .parse_single_example (
116
+ serialized_example ,
117
+ features = {
118
+ 'image_raw' :tf .FixedLenFeature ([], tf .string ),
119
+ 'width' : tf .FixedLenFeature ([], tf .int64 ),
120
+ 'height' : tf .FixedLenFeature ([], tf .int64 )
121
+ })
122
+
123
+ image = tf .decode_raw (features ['image_raw' ], tf .float32 )
124
+ orig_height = tf .cast (features ['height' ], tf .int32 )
125
+ orig_width = tf .cast (features ['width' ], tf .int32 )
126
+
127
+ image_shape = tf .pack ([orig_height ,orig_width ,3 ])
128
+ image_tf = tf .reshape (image ,image_shape )
129
+
130
+ resized_image = tf .image .resize_image_with_crop_or_pad (image_tf ,target_height = self .img_height ,target_width = self .img_width )
131
+
132
+ images = tf .train .shuffle_batch ([resized_image ],batch_size = self .batch_size ,num_threads = 1 ,capacity = 50 ,min_after_dequeue = 10 )
133
+
134
+ return images
135
+
105
136
def print_variables (self ):
106
137
variables = slim .get_model_variables ()
107
138
print 'Model Variables:'
0 commit comments