Skip to content

Commit fc00896

Browse files
committed
Added Preprocessing Network for Making Batches
1 parent 0272d6c commit fc00896

File tree

2 files changed

+38
-7
lines changed

2 files changed

+38
-7
lines changed

main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
with tf.Session() as sess:
1212
print 'Building Graph...'
13-
net = HyperFace(sess)
13+
net = HyperFace(sess,tf_record_file_path='aflw_train.tfrecords')
1414
print 'Graph Built!'
1515
sess.run(tf.global_variables_initializer())
1616
net.print_variables()

model.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
class HyperFace(object):
77

8-
def __init__(self, sess):
8+
def __init__(self, sess,tf_record_file_path=None):
99

1010
self.sess = sess
1111
self.batch_size = 2
@@ -22,6 +22,13 @@ def __init__(self, sess):
2222
self.weight_pose = 5
2323
self.weight_gender = 2
2424

25+
#tf_Record Parameters
26+
self.filename_queue = tf.train.string_input_producer([tf_record_file_path], num_epochs=self.num_epochs)
27+
28+
#Spatial Transformer Input
29+
self.sp_input_width = 500
30+
self.sp_input_height = 500
31+
2532
self.build_network()
2633

2734

@@ -33,7 +40,9 @@ def build_network(self):
3340
self.visibility = tf.placeholder(tf.float32, [self.batch_size,21], name='visibility')
3441
self.pose = tf.placeholder(tf.float32, [self.batch_size,3], name='pose')
3542
self.gender = tf.placeholder(tf.float32, [self.batch_size,2], name='gender')
36-
43+
44+
self.X = self.load_from_tfRecord(self.filename_queue)
45+
3746
net_output = self.network(self.X) # (out_detection, out_landmarks, out_visibility, out_pose, out_gender)
3847

3948
loss_detection = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(net_output[0], self.detection))
@@ -57,10 +66,6 @@ def train(self):
5766
writer = tf.summary.FileWriter('./logs', self.sess.graph)
5867
loss_summ = tf.summary.scalar('loss', self.loss)
5968

60-
61-
62-
63-
6469
def network(self,inputs):
6570

6671
with slim.arg_scope([slim.conv2d, slim.fully_connected],
@@ -102,6 +107,32 @@ def network(self,inputs):
102107

103108
return [out_detection, out_landmarks, out_visibility, out_pose, out_gender]
104109

110+
def load_from_tfRecord(self, filename_queue):
    """Read, decode, and batch images from a TFRecord filename queue.

    Args:
        filename_queue: queue of TFRecord file names, as produced by
            tf.train.string_input_producer (see __init__).

    Returns:
        A float32 tensor of shape [batch_size, img_height, img_width, 3]
        emitted by tf.train.shuffle_batch.
    """
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    # Each serialized record holds the raw image bytes plus the
    # image's original dimensions, written as int64 scalars.
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'width': tf.FixedLenFeature([], tf.int64),
            'height': tf.FixedLenFeature([], tf.int64)
        })

    # NOTE(review): decoding as float32 assumes the record writer
    # serialized float32 pixels (not uint8) -- confirm against the
    # TFRecord-generation script.
    image = tf.decode_raw(features['image_raw'], tf.float32)
    orig_height = tf.cast(features['height'], tf.int32)
    orig_width = tf.cast(features['width'], tf.int32)

    # BUG FIX: tf.pack was renamed to tf.stack in TensorFlow 1.0 and
    # later removed; this file already uses 1.0-era APIs
    # (tf.summary.FileWriter, tf.global_variables_initializer), so
    # tf.pack would raise AttributeError. Assumes 3-channel images.
    image_shape = tf.stack([orig_height, orig_width, 3])
    image_tf = tf.reshape(image, image_shape)

    # Crop or zero-pad every variably-sized image to the fixed network
    # input size expected by build_network.
    resized_image = tf.image.resize_image_with_crop_or_pad(
        image_tf, target_height=self.img_height, target_width=self.img_width)

    # Shuffle examples into mini-batches; the caller is responsible for
    # starting the queue runners (tf.train.start_queue_runners).
    images = tf.train.shuffle_batch(
        [resized_image], batch_size=self.batch_size,
        num_threads=1, capacity=50, min_after_dequeue=10)

    return images
105136
def print_variables(self):
106137
variables = slim.get_model_variables()
107138
print 'Model Variables:'

0 commit comments

Comments
 (0)