Skip to content

Commit b45a748

Browse files
Shashank TyagiShashank Tyagi
Shashank Tyagi
authored and
Shashank Tyagi
committed
remove conflicts
2 parents cc7991d + 104193e commit b45a748

File tree

4 files changed

+66
-12
lines changed

4 files changed

+66
-12
lines changed

main.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
with tf.Session() as sess:
1212
print 'Building Graph...'
13-
net = HyperFace(sess)
13+
net = HyperFace(sess,tf_record_file_path='aflw_train.tfrecords')
1414
print 'Graph Built!'
1515
sess.run(tf.global_variables_initializer())
1616
net.print_variables()

model.py

+33-6
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
class HyperFace(object):
77

8-
def __init__(self, sess):
8+
def __init__(self, sess,tf_record_file_path=None):
99

1010
self.sess = sess
1111
self.batch_size = 2
@@ -22,6 +22,9 @@ def __init__(self, sess):
2222
self.weight_pose = 5
2323
self.weight_gender = 2
2424

25+
#tf_Record Parameters
26+
self.filename_queue = tf.train.string_input_producer([tf_record_file_path], num_epochs=self.num_epochs)
27+
2528
self.build_network()
2629

2730

@@ -33,7 +36,9 @@ def build_network(self):
3336
self.visibility = tf.placeholder(tf.float32, [self.batch_size,21], name='visibility')
3437
self.pose = tf.placeholder(tf.float32, [self.batch_size,3], name='pose')
3538
self.gender = tf.placeholder(tf.float32, [self.batch_size,2], name='gender')
36-
39+
40+
self.X = self.load_from_tfRecord(self.filename_queue)
41+
3742
net_output = self.network(self.X) # (out_detection, out_landmarks, out_visibility, out_pose, out_gender)
3843

3944
loss_detection = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(net_output[0], self.detection))
@@ -57,10 +62,6 @@ def train(self):
5762
writer = tf.summary.FileWriter('./logs', self.sess.graph)
5863
loss_summ = tf.summary.scalar('loss', self.loss)
5964

60-
61-
62-
63-
6465
def network(self,inputs):
6566

6667
with slim.arg_scope([slim.conv2d, slim.fully_connected],
@@ -102,6 +103,32 @@ def network(self,inputs):
102103

103104
return [out_detection, out_landmarks, out_visibility, out_pose, out_gender]
104105

106+
def load_from_tfRecord(self,filename_queue):
107+
108+
reader = tf.TFRecordReader()
109+
_, serialized_example = reader.read(filename_queue)
110+
111+
features = tf.parse_single_example(
112+
serialized_example,
113+
features={
114+
'image_raw':tf.FixedLenFeature([], tf.string),
115+
'width': tf.FixedLenFeature([], tf.int64),
116+
'height': tf.FixedLenFeature([], tf.int64)
117+
})
118+
119+
image = tf.decode_raw(features['image_raw'], tf.float32)
120+
orig_height = tf.cast(features['height'], tf.int32)
121+
orig_width = tf.cast(features['width'], tf.int32)
122+
123+
image_shape = tf.pack([orig_height,orig_width,3])
124+
image_tf = tf.reshape(image,image_shape)
125+
126+
resized_image = tf.image.resize_image_with_crop_or_pad(image_tf,target_height=self.img_height,target_width=self.img_width)
127+
128+
images = tf.train.shuffle_batch([resized_image],batch_size=self.batch_size,num_threads=1,capacity=50,min_after_dequeue=10)
129+
130+
return images
131+
105132
def print_variables(self):
106133
variables = slim.get_model_variables()
107134
print 'Model Variables:'

with SPN/main.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
weights_path = '/Users/shashank/Tensorflow/SPN/weights/'
66
imgs_path = '/Users/shashank/Tensorflow/CSE252C-Hyperface/git/truth_data.npy'
7-
7+
tf_record_file_path = 'aflw_train.tfrecords'
88
if not os.path.exists('./logs'):
99
os.makedirs('./logs')
1010

@@ -13,7 +13,7 @@
1313

1414
with tf.Session() as sess:
1515
print 'Building Graph...'
16-
model = Network(sess)
16+
model = Network(sess,tf_record_file_path=tf_record_file_path)
1717
print 'Graph Built!'
1818
sess.run(tf.global_variables_initializer())
1919
model.train()

with SPN/model.py

+30-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
class Network(object):
99

10-
def __init__(self, sess):
10+
def __init__(self, sess,tf_record_file_path=None):
1111

1212
self.sess = sess
1313
self.batch_size = 2
@@ -26,6 +26,8 @@ def __init__(self, sess):
2626
self.weight_pose = 5
2727
self.weight_gender = 2
2828

29+
#tf_Record Parameters
30+
self.filename_queue = tf.train.string_input_producer([tf_record_file_path], num_epochs=self.num_epochs)
2931
self.build_network()
3032

3133

@@ -39,10 +41,11 @@ def build_network(self):
3941
self.gender = tf.placeholder(tf.float32, [self.batch_size,2], name='gender')
4042

4143

42-
theta = self.localization_squeezenet(self.X)
44+
45+
self.X = self.load_from_tfRecord(self.filename_queue,resize_size=(self.img_width,self.img_height))
4346

47+
theta = self.localization_squeezenet(self.X)
4448
self.T_mat = tf.reshape(theta, [-1, 2,3])
45-
4649
self.cropped = transformer(self.X, self.T_mat, [self.out_height, self.out_width])
4750

4851
net_output = self.hyperface(self.cropped) # (out_detection, out_landmarks, out_visibility, out_pose, out_gender)
@@ -236,7 +239,31 @@ def predict(self, imgs_path):
236239
brk()
237240

238241

242+
def load_from_tfRecord(self,filename_queue,resize_size=None):
243+
244+
reader = tf.TFRecordReader()
245+
_, serialized_example = reader.read(filename_queue)
246+
247+
features = tf.parse_single_example(
248+
serialized_example,
249+
features={
250+
'image_raw':tf.FixedLenFeature([], tf.string),
251+
'width': tf.FixedLenFeature([], tf.int64),
252+
'height': tf.FixedLenFeature([], tf.int64)
253+
})
254+
255+
image = tf.decode_raw(features['image_raw'], tf.float32)
256+
orig_height = tf.cast(features['height'], tf.int32)
257+
orig_width = tf.cast(features['width'], tf.int32)
258+
259+
image_shape = tf.pack([orig_height,orig_width,3])
260+
image_tf = tf.reshape(image,image_shape)
239261

262+
resized_image = tf.image.resize_image_with_crop_or_pad(image_tf,target_height=resize_size[1],target_width=resize_size[0])
263+
264+
images = tf.train.shuffle_batch([resized_image],batch_size=self.batch_size,num_threads=1,capacity=50,min_after_dequeue=10)
265+
266+
return images
240267

241268
def load_weights(self, path):
242269
variables = slim.get_model_variables()

0 commit comments

Comments
 (0)