5
5
6
6
class HyperFace (object ):
7
7
8
- def __init__ (self , sess ):
8
+ def __init__ (self , sess , tf_record_file_path = None ):
9
9
10
10
self .sess = sess
11
11
self .batch_size = 2
@@ -22,6 +22,9 @@ def __init__(self, sess):
22
22
self .weight_pose = 5
23
23
self .weight_gender = 2
24
24
25
+ #tf_Record Paramters
26
+ self .filename_queue = tf .train .string_input_producer ([tf_record_file_path ], num_epochs = self .num_epochs )
27
+
25
28
self .build_network ()
26
29
27
30
@@ -33,7 +36,9 @@ def build_network(self):
33
36
self .visibility = tf .placeholder (tf .float32 , [self .batch_size ,21 ], name = 'visibility' )
34
37
self .pose = tf .placeholder (tf .float32 , [self .batch_size ,3 ], name = 'pose' )
35
38
self .gender = tf .placeholder (tf .float32 , [self .batch_size ,2 ], name = 'gender' )
36
-
39
+
40
+ self .X = self .load_from_tfRecord (self .filename_queue )
41
+
37
42
net_output = self .network (self .X ) # (out_detection, out_landmarks, out_visibility, out_pose, out_gender)
38
43
39
44
loss_detection = tf .reduce_mean (tf .nn .sigmoid_cross_entropy_with_logits (net_output [0 ], self .detection ))
@@ -57,10 +62,6 @@ def train(self):
57
62
writer = tf .summary .FileWriter ('./logs' , self .sess .graph )
58
63
loss_summ = tf .summary .scalar ('loss' , self .loss )
59
64
60
-
61
-
62
-
63
-
64
65
def network (self ,inputs ):
65
66
66
67
with slim .arg_scope ([slim .conv2d , slim .fully_connected ],
@@ -102,6 +103,32 @@ def network(self,inputs):
102
103
103
104
return [out_detection , out_landmarks , out_visibility , out_pose , out_gender ]
104
105
106
+ def load_from_tfRecord (self ,filename_queue ):
107
+
108
+ reader = tf .TFRecordReader ()
109
+ _ , serialized_example = reader .read (filename_queue )
110
+
111
+ features = tf .parse_single_example (
112
+ serialized_example ,
113
+ features = {
114
+ 'image_raw' :tf .FixedLenFeature ([], tf .string ),
115
+ 'width' : tf .FixedLenFeature ([], tf .int64 ),
116
+ 'height' : tf .FixedLenFeature ([], tf .int64 )
117
+ })
118
+
119
+ image = tf .decode_raw (features ['image_raw' ], tf .float32 )
120
+ orig_height = tf .cast (features ['height' ], tf .int32 )
121
+ orig_width = tf .cast (features ['width' ], tf .int32 )
122
+
123
+ image_shape = tf .pack ([orig_height ,orig_width ,3 ])
124
+ image_tf = tf .reshape (image ,image_shape )
125
+
126
+ resized_image = tf .image .resize_image_with_crop_or_pad (image_tf ,target_height = self .img_height ,target_width = self .img_width )
127
+
128
+ images = tf .train .shuffle_batch ([resized_image ],batch_size = self .batch_size ,num_threads = 1 ,capacity = 50 ,min_after_dequeue = 10 )
129
+
130
+ return images
131
+
105
132
def print_variables (self ):
106
133
variables = slim .get_model_variables ()
107
134
print 'Model Variables:'
0 commit comments