Skip to content

Commit d85bc13

Browse files
Shashank TyagiShashank Tyagi
Shashank Tyagi
authored and
Shashank Tyagi
committed
add training method
1 parent b45a748 commit d85bc13

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

with SPN/main.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,17 @@
44

55
weights_path = '/Users/shashank/Tensorflow/SPN/weights/'
66
imgs_path = '/Users/shashank/Tensorflow/CSE252C-Hyperface/git/truth_data.npy'
7-
tf_record_file_path = 'aflw_train.tfrecords'
7+
tf_record_file_path = '../aflw_train.tfrecords'
88
if not os.path.exists('./logs'):
99
os.makedirs('./logs')
1010

1111
map(os.unlink, (os.path.join( './logs',f) for f in os.listdir('./logs')) )
1212

1313

14+
1415
with tf.Session() as sess:
1516
print 'Building Graph...'
16-
model = Network(sess,tf_record_file_path=tf_record_file_path)
17+
model = Network(sess,tf_record_file_path)
1718
print 'Graph Built!'
1819
sess.run(tf.global_variables_initializer())
1920
model.train()

with SPN/model.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,6 @@ def build_network(self):
4141
self.gender = tf.placeholder(tf.float32, [self.batch_size,2], name='gender')
4242

4343

44-
45-
self.X = self.load_from_tfRecord(self.filename_queue,resize_size=(self.img_width,self.img_height))
46-
4744
theta = self.localization_squeezenet(self.X)
4845
self.T_mat = tf.reshape(theta, [-1, 2,3])
4946
self.cropped = transformer(self.X, self.T_mat, [self.out_height, self.out_width])
@@ -87,6 +84,14 @@ def train(self):
8784

8885
print self.sess.run(self.T_mat, feed_dict={self.X: np.random.randn(self.batch_size, self.img_height, self.img_width, self.channel)})
8986

87+
images = self.load_from_tfRecord(self.filename_queue)
88+
89+
coord = tf.train.Coordinator()
90+
threads = tf.train.start_queue_runners(sess = self.sess, coord = coord)
91+
92+
for i in xrange(2):
93+
img_batch = self.sess.run(images)
94+
print img_batch.shape
9095

9196

9297
def hyperface(self,inputs, reuse = False):
@@ -238,7 +243,6 @@ def predict(self, imgs_path):
238243

239244
brk()
240245

241-
242246
def load_from_tfRecord(self,filename_queue,resize_size=None):
243247

244248
reader = tf.TFRecordReader()
@@ -258,13 +262,15 @@ def load_from_tfRecord(self,filename_queue,resize_size=None):
258262

259263
image_shape = tf.pack([orig_height,orig_width,3])
260264
image_tf = tf.reshape(image,image_shape)
261-
265+
print image_shape
262266
resized_image = tf.image.resize_image_with_crop_or_pad(image_tf,target_height=resize_size[1],target_width=resize_size[0])
263267

264268
images = tf.train.shuffle_batch([resized_image],batch_size=self.batch_size,num_threads=1,capacity=50,min_after_dequeue=10)
265269

266270
return images
267271

272+
273+
268274
def load_weights(self, path):
269275
variables = slim.get_model_variables()
270276
print 'Loading weights...'

0 commit comments

Comments
 (0)