Skip to content

Commit c0859d7

Browse files
albertwujj authored and WuTheFWasThat committed
Fix TODO in sample.sample_sequence - Avoid 'leaving last token calculation to while loop' (#119)
* do initial run on full context
* decrement while loop iterations
* add context to output
* remove first param
* removing first param: change shape invariant
1 parent e5c5054 commit c0859d7

File tree

1 file changed

+11
-14
lines changed

1 file changed

+11
-14
lines changed

src/sample.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -41,36 +41,33 @@ def step(hparams, tokens, past=None):
4141
}
4242

4343
with tf.name_scope('sample_sequence'):
44-
# Don't feed the last context token -- leave that to the loop below
45-
# TODO: Would be slightly faster if we called step on the entire context,
46-
# rather than leaving the last token transformer calculation to the while loop.
47-
context_output = step(hparams, context[:, :-1])
48-
4944
def body(past, prev, output):
50-
next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
45+
next_outputs = step(hparams, prev, past=past)
5146
logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
5247
logits = top_k_logits(logits, k=top_k)
5348
samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
5449
return [
55-
tf.concat([past, next_outputs['presents']], axis=-2),
56-
tf.squeeze(samples, axis=[1]),
57-
tf.concat([output, samples], axis=1),
50+
next_outputs['presents'] if past is None else tf.concat([past, next_outputs['presents']], axis=-2),
51+
samples,
52+
tf.concat([output, samples], axis=1)
5853
]
5954

55+
past, prev, output = body(None, context, context)
56+
6057
def cond(*args):
6158
return True
6259

6360
_, _, tokens = tf.while_loop(
6461
cond=cond, body=body,
65-
maximum_iterations=length,
62+
maximum_iterations=length - 1,
6663
loop_vars=[
67-
context_output['presents'],
68-
context[:, -1],
69-
context,
64+
past,
65+
prev,
66+
output
7067
],
7168
shape_invariants=[
7269
tf.TensorShape(model.past_shape(hparams=hparams, batch_size=batch_size)),
73-
tf.TensorShape([batch_size]),
70+
tf.TensorShape([batch_size, None]),
7471
tf.TensorShape([batch_size, None]),
7572
],
7673
back_prop=False,

0 commit comments

Comments (0)