1
+ import time
2
+ start = time .perf_counter ()
1
3
import tensorflow as tf
2
4
import argparse
3
5
import pickle
4
6
import os
5
7
from model import Model
6
8
from utils import build_dict , build_dataset , batch_iter
7
9
10
+ # Uncomment the next two lines to suppress TensorFlow error/info verbosity, or adjust the logging levels as needed.
11
+ # tf.logging.set_verbosity(tf.logging.FATAL)
12
+ # os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
8
13
9
14
def add_arguments (parser ):
10
15
parser .add_argument ("--num_hidden" , type = int , default = 150 , help = "Network size." )
@@ -20,6 +25,9 @@ def add_arguments(parser):
20
25
21
26
parser .add_argument ("--toy" , action = "store_true" , help = "Use only 50K samples of data" )
22
27
28
+ parser .add_argument ("--with_model" , action = "store_true" , help = "Continue from previously saved model" )
29
+
30
+
23
31
24
32
parser = argparse .ArgumentParser ()
25
33
add_arguments (parser )
@@ -29,6 +37,11 @@ def add_arguments(parser):
29
37
30
38
if not os .path .exists ("saved_model" ):
31
39
os .mkdir ("saved_model" )
40
+ else :
41
+ if args .with_model :
42
+ old_model_checkpoint_path = open ('saved_model/checkpoint' , 'r' )
43
+ old_model_checkpoint_path = "" .join (["saved_model/" ,old_model_checkpoint_path .read ().splitlines ()[0 ].split ('"' )[1 ] ])
44
+
32
45
33
46
print ("Building dictionary..." )
34
47
word_dict , reversed_dict , article_max_len , summary_max_len = build_dict ("train" , args .toy )
@@ -40,11 +53,14 @@ def add_arguments(parser):
40
53
model = Model (reversed_dict , article_max_len , summary_max_len , args )
41
54
sess .run (tf .global_variables_initializer ())
42
55
saver = tf .train .Saver (tf .global_variables ())
56
+ if 'old_model_checkpoint_path' in globals ():
57
+ print ("Continuing from previous trained model:" , old_model_checkpoint_path , "..." )
58
+ saver .restore (sess , old_model_checkpoint_path )
43
59
44
60
batches = batch_iter (train_x , train_y , args .batch_size , args .num_epochs )
45
61
num_batches_per_epoch = (len (train_x ) - 1 ) // args .batch_size + 1
46
62
47
- print ("Iteration starts." )
63
+ print ("\n Iteration starts." )
48
64
print ("Number of batches per epoch :" , num_batches_per_epoch )
49
65
for batch_x , batch_y in batches :
50
66
batch_x_len = list (map (lambda x : len ([y for y in x if y != 0 ]), batch_x ))
@@ -72,5 +88,8 @@ def add_arguments(parser):
72
88
print ("step {0}: loss = {1}" .format (step , loss ))
73
89
74
90
if step % num_batches_per_epoch == 0 :
91
+ hours , rem = divmod (time .perf_counter () - start , 3600 )
92
+ minutes , seconds = divmod (rem , 60 )
75
93
saver .save (sess , "./saved_model/model.ckpt" , global_step = step )
76
- print ("Epoch {0}: Model is saved." .format (step // num_batches_per_epoch ))
94
+ print (" Epoch {0}: Model is saved." .format (step // num_batches_per_epoch ),
95
+ "Elapsed: {:0>2}:{:0>2}:{:05.2f}" .format (int (hours ),int (minutes ),seconds ) , "\n " )
0 commit comments