Skip to content

Commit 6d95c5a

Browse files
committed
Added Preprocessing Network for Making Batches
1 parent fc00896 commit 6d95c5a

File tree

2 files changed

+31
-5
lines changed

2 files changed

+31
-5
lines changed

with SPN/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
weights_path = '/Users/shashank/Tensorflow/SPN/weights/'
66
imgs_path = '/Users/shashank/Tensorflow/CSE252C-Hyperface/git/truth_data.npy'
7-
7+
tf_record_file_path = 'aflw_train.tfrecords'
88
if not os.path.exists('./logs'):
99
os.makedirs('./logs')
1010

@@ -13,7 +13,7 @@
1313

1414
with tf.Session() as sess:
1515
print 'Building Graph...'
16-
model = Network(sess)
16+
model = Network(sess,tf_record_file_path=tf_record_file_path)
1717
print 'Graph Built!'
1818
sess.run(tf.global_variables_initializer())
1919
model.train()

with SPN/model.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
class Network(object):
99

10-
def __init__(self, sess):
10+
def __init__(self, sess,tf_record_file_path=None):
1111

1212
self.sess = sess
1313
self.batch_size = 2
@@ -26,6 +26,8 @@ def __init__(self, sess):
2626
self.weight_pose = 5
2727
self.weight_gender = 2
2828

29+
#tf_Record Paramters
30+
self.filename_queue = tf.train.string_input_producer([tf_record_file_path], num_epochs=self.num_epochs)
2931
self.build_network()
3032

3133

@@ -38,9 +40,9 @@ def build_network(self):
3840
self.pose = tf.placeholder(tf.float32, [self.batch_size,3], name='pose')
3941
self.gender = tf.placeholder(tf.float32, [self.batch_size,2], name='gender')
4042

41-
43+
self.X = self.load_from_tfRecord(self.filename_queue,resize_size=(self.img_width,self.img_height))
44+
4245
theta = self.localization_network(self.X)
43-
4446
T_mat = self.get_transformation_matrix(theta)
4547

4648
cropped = transformer(self.X, T_mat, [self.out_height, self.out_width])
@@ -175,7 +177,31 @@ def predict(self, imgs_path):
175177
brk()
176178

177179

180+
def load_from_tfRecord(self,filename_queue,resize_size=None):
181+
182+
reader = tf.TFRecordReader()
183+
_, serialized_example = reader.read(filename_queue)
184+
185+
features = tf.parse_single_example(
186+
serialized_example,
187+
features={
188+
'image_raw':tf.FixedLenFeature([], tf.string),
189+
'width': tf.FixedLenFeature([], tf.int64),
190+
'height': tf.FixedLenFeature([], tf.int64)
191+
})
192+
193+
image = tf.decode_raw(features['image_raw'], tf.float32)
194+
orig_height = tf.cast(features['height'], tf.int32)
195+
orig_width = tf.cast(features['width'], tf.int32)
196+
197+
image_shape = tf.pack([orig_height,orig_width,3])
198+
image_tf = tf.reshape(image,image_shape)
178199

200+
resized_image = tf.image.resize_image_with_crop_or_pad(image_tf,target_height=resize_size[1],target_width=resize_size[0])
201+
202+
images = tf.train.shuffle_batch([resized_image],batch_size=self.batch_size,num_threads=1,capacity=50,min_after_dequeue=10)
203+
204+
return images
179205

180206
def load_weights(self, path):
181207
variables = slim.get_model_variables()

0 commit comments

Comments
 (0)