7
7
8
8
class Network (object ):
9
9
10
- def __init__ (self , sess ):
10
+ def __init__ (self , sess , tf_record_file_path = None ):
11
11
12
12
self .sess = sess
13
13
self .batch_size = 2
@@ -26,6 +26,8 @@ def __init__(self, sess):
26
26
self .weight_pose = 5
27
27
self .weight_gender = 2
28
28
29
+ #tf_Record Paramters
30
+ self .filename_queue = tf .train .string_input_producer ([tf_record_file_path ], num_epochs = self .num_epochs )
29
31
self .build_network ()
30
32
31
33
@@ -38,9 +40,9 @@ def build_network(self):
38
40
self .pose = tf .placeholder (tf .float32 , [self .batch_size ,3 ], name = 'pose' )
39
41
self .gender = tf .placeholder (tf .float32 , [self .batch_size ,2 ], name = 'gender' )
40
42
41
-
43
+ self .X = self .load_from_tfRecord (self .filename_queue ,resize_size = (self .img_width ,self .img_height ))
44
+
42
45
theta = self .localization_network (self .X )
43
-
44
46
T_mat = self .get_transformation_matrix (theta )
45
47
46
48
cropped = transformer (self .X , T_mat , [self .out_height , self .out_width ])
@@ -175,7 +177,31 @@ def predict(self, imgs_path):
175
177
brk ()
176
178
177
179
180
+ def load_from_tfRecord (self ,filename_queue ,resize_size = None ):
181
+
182
+ reader = tf .TFRecordReader ()
183
+ _ , serialized_example = reader .read (filename_queue )
184
+
185
+ features = tf .parse_single_example (
186
+ serialized_example ,
187
+ features = {
188
+ 'image_raw' :tf .FixedLenFeature ([], tf .string ),
189
+ 'width' : tf .FixedLenFeature ([], tf .int64 ),
190
+ 'height' : tf .FixedLenFeature ([], tf .int64 )
191
+ })
192
+
193
+ image = tf .decode_raw (features ['image_raw' ], tf .float32 )
194
+ orig_height = tf .cast (features ['height' ], tf .int32 )
195
+ orig_width = tf .cast (features ['width' ], tf .int32 )
196
+
197
+ image_shape = tf .pack ([orig_height ,orig_width ,3 ])
198
+ image_tf = tf .reshape (image ,image_shape )
178
199
200
+ resized_image = tf .image .resize_image_with_crop_or_pad (image_tf ,target_height = resize_size [1 ],target_width = resize_size [0 ])
201
+
202
+ images = tf .train .shuffle_batch ([resized_image ],batch_size = self .batch_size ,num_threads = 1 ,capacity = 50 ,min_after_dequeue = 10 )
203
+
204
+ return images
179
205
180
206
def load_weights (self , path ):
181
207
variables = slim .get_model_variables ()
0 commit comments