
Commit e83bab0

Author: Paulo Querido (committed)
Added argument --with_model, verbosity control, epoch elapsed time
1 parent 2721381 commit e83bab0

File tree: train.py

1 file changed: +21 additions, -2 deletions


train.py

Lines changed: 21 additions & 2 deletions
@@ -1,10 +1,15 @@
+import time
+start = time.perf_counter()
 import tensorflow as tf
 import argparse
 import pickle
 import os
 from model import Model
 from utils import build_dict, build_dataset, batch_iter
 
+# Uncomment next 2 lines to suppress error and Tensorflow info verbosity. Or change logging levels
+# tf.logging.set_verbosity(tf.logging.FATAL)
+# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
 
 def add_arguments(parser):
     parser.add_argument("--num_hidden", type=int, default=150, help="Network size.")
@@ -20,6 +25,9 @@ def add_arguments(parser):
 
     parser.add_argument("--toy", action="store_true", help="Use only 50K samples of data")
 
+    parser.add_argument("--with_model", action="store_true", help="Continue from previously saved model")
+
+
 
 parser = argparse.ArgumentParser()
 add_arguments(parser)
@@ -29,6 +37,11 @@ def add_arguments(parser):
 
 if not os.path.exists("saved_model"):
     os.mkdir("saved_model")
+else:
+    if args.with_model:
+        old_model_checkpoint_path = open('saved_model/checkpoint', 'r')
+        old_model_checkpoint_path = "".join(["saved_model/", old_model_checkpoint_path.read().splitlines()[0].split('"')[1]])
+
 
 print("Building dictionary...")
 word_dict, reversed_dict, article_max_len, summary_max_len = build_dict("train", args.toy)
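Aside (not what the commit does): in TF 1.x the newest checkpoint path can also be resolved without parsing saved_model/checkpoint by hand, e.g.:

    old_model_checkpoint_path = tf.train.latest_checkpoint("saved_model")  # e.g. "saved_model/model.ckpt-1234", or None if no checkpoint exists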
@@ -40,11 +53,14 @@ def add_arguments(parser):
     model = Model(reversed_dict, article_max_len, summary_max_len, args)
     sess.run(tf.global_variables_initializer())
     saver = tf.train.Saver(tf.global_variables())
+    if 'old_model_checkpoint_path' in globals():
+        print("Continuing from previous trained model:", old_model_checkpoint_path, "...")
+        saver.restore(sess, old_model_checkpoint_path)
 
     batches = batch_iter(train_x, train_y, args.batch_size, args.num_epochs)
     num_batches_per_epoch = (len(train_x) - 1) // args.batch_size + 1
 
-    print("Iteration starts.")
+    print("\nIteration starts.")
     print("Number of batches per epoch :", num_batches_per_epoch)
     for batch_x, batch_y in batches:
         batch_x_len = list(map(lambda x: len([y for y in x if y != 0]), batch_x))
@@ -72,5 +88,8 @@ def add_arguments(parser):
             print("step {0}: loss = {1}".format(step, loss))
 
         if step % num_batches_per_epoch == 0:
+            hours, rem = divmod(time.perf_counter() - start, 3600)
+            minutes, seconds = divmod(rem, 60)
             saver.save(sess, "./saved_model/model.ckpt", global_step=step)
-            print("Epoch {0}: Model is saved.".format(step // num_batches_per_epoch))
+            print(" Epoch {0}: Model is saved.".format(step // num_batches_per_epoch),
+                  "Elapsed: {:0>2}:{:0>2}:{:05.2f}".format(int(hours), int(minutes), seconds), "\n")
