Skip to content

Commit d369632

Browse files
author
Shashank Tyagi
committed
prediction code added
1 parent c0260bf commit d369632

File tree

2 files changed

+425
-0
lines changed

2 files changed

+425
-0
lines changed

main_prediction.py

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import tensorflow as tf
2+
import os
3+
from model import *
4+
5+
6+
7+
if not os.path.exists('../logs'):
8+
os.makedirs('../logs')
9+
10+
if not os.path.exists('../checkpoint'):
11+
os.makedirs('../checkpoint')
12+
13+
if not os.path.exists('../best_checkpoint'):
14+
os.makedirs('../best_checkpoint')
15+
16+
map(os.unlink, (os.path.join( '../logs',f) for f in os.listdir('../logs')) )
17+
18+
net = HyperFace(True, tf_record_file_path='./aflw_train_small_check.tfrecords',model_save_path='../checkpoint/',best_model_save_path='../best_checkpoint/',
19+
restore_model_path='../full_best_checkpoint/')
20+
21+
with tf.Session() as sess:
22+
print 'Building Graph...'
23+
net.build_network(sess)
24+
print 'Graph Built!'
25+
# net.print_variables()
26+
# net.load_weights('/Users/shashank/TensorFlow/SPN/weights/')
27+
net.predict()
28+
# net.train()
29+

0 commit comments

Comments
 (0)