
Commit 2721381

Change data_utils function
1 parent 1f91203 commit 2721381

2 files changed: 10 additions & 9 deletions

test.py

Lines changed: 3 additions & 3 deletions
@@ -10,7 +10,7 @@
 print("Loading dictionary...")
 word_dict, reversed_dict, article_max_len, summary_max_len = build_dict("valid", args.toy)
 print("Loading validation dataset...")
-valid_x, valid_y = build_dataset("valid", word_dict, article_max_len, summary_max_len, args.toy)
+valid_x = build_dataset("valid", word_dict, article_max_len, summary_max_len, args.toy)
 valid_x_len = list(map(lambda x: len([y for y in x if y != 0]), valid_x))

 with tf.Session() as sess:
@@ -20,10 +20,10 @@
     ckpt = tf.train.get_checkpoint_state("./saved_model/")
     saver.restore(sess, ckpt.model_checkpoint_path)

-    batches = batch_iter(valid_x, valid_y, args.batch_size, 1)
+    batches = batch_iter(valid_x, [0] * len(valid_x), args.batch_size, 1)

     print("Writing summaries to 'result.txt'...")
-    for batch_x, batch_y in batches:
+    for batch_x, _ in batches:
         batch_x_len = list(map(lambda x: len([y for y in x if y != 0]), batch_x))

         valid_feed_dict = {
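
Context for the [0] * len(valid_x) placeholder: batch_iter (its signature, batch_iter(inputs, outputs, batch_size, num_epochs), appears in the utils.py diff below) presumably yields inputs and outputs in matched slices, so any list of the right length can stand in for the removed valid_y. A minimal sketch of that assumed pairing behaviour, not the repository's actual implementation:

# Hypothetical sketch of the (inputs, outputs) pairing the change relies on;
# the real batch_iter in utils.py may shuffle or convert to numpy arrays instead.
def batch_iter_sketch(inputs, outputs, batch_size, num_epochs):
    num_batches_per_epoch = (len(inputs) - 1) // batch_size + 1
    for _ in range(num_epochs):
        for i in range(num_batches_per_epoch):
            start = i * batch_size
            end = min(start + batch_size, len(inputs))
            # inputs and outputs are sliced in lockstep, so a placeholder such as
            # [0] * len(valid_x) can substitute for the removed valid_y.
            yield inputs[start:end], outputs[start:end]

# The caller in test.py then discards the placeholder half of each pair:
#     for batch_x, _ in batch_iter_sketch(valid_x, [0] * len(valid_x), args.batch_size, 1):
#         ...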

utils.py

Lines changed: 7 additions & 6 deletions
@@ -67,7 +67,6 @@ def build_dataset(step, word_dict, article_max_len, summary_max_len, toy=False):
         title_list = get_text_list(train_title_path, toy)
     elif step == "valid":
         article_list = get_text_list(valid_article_path, toy)
-        title_list = get_text_list(valid_title_path, toy)
     else:
         raise NotImplementedError

@@ -76,11 +75,13 @@ def build_dataset(step, word_dict, article_max_len, summary_max_len, toy=False):
     x = list(map(lambda d: d[:article_max_len], x))
     x = list(map(lambda d: d + (article_max_len - len(d)) * [word_dict["<padding>"]], x))

-    y = list(map(lambda d: word_tokenize(d), title_list))
-    y = list(map(lambda d: list(map(lambda w: word_dict.get(w, word_dict["<unk>"]), d)), y))
-    y = list(map(lambda d: d[:(summary_max_len-1)], y))
-
-    return x, y
+    if step == "valid":
+        return x
+    else:
+        y = list(map(lambda d: word_tokenize(d), title_list))
+        y = list(map(lambda d: list(map(lambda w: word_dict.get(w, word_dict["<unk>"]), d)), y))
+        y = list(map(lambda d: d[:(summary_max_len-1)], y))
+        return x, y


 def batch_iter(inputs, outputs, batch_size, num_epochs):
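
The net effect in utils.py is a conditional return contract: build_dataset now returns only the padded article ids for the "valid" step and keeps the (x, y) pair otherwise. A hedged usage sketch of the call sites, where the "train" step name is inferred from the train_title_path context above and args is the scripts' existing argparse namespace:

word_dict, reversed_dict, article_max_len, summary_max_len = build_dict("valid", args.toy)

# "valid" now yields only the padded article id sequences ...
valid_x = build_dataset("valid", word_dict, article_max_len, summary_max_len, args.toy)

# ... while the training step (assumed to be "train") still yields articles plus
# titles tokenized, mapped through word_dict, and truncated to summary_max_len - 1.
train_x, train_y = build_dataset("train", word_dict, article_max_len, summary_max_len, args.toy)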
