
Commit 56626ff

Implement top_k sampling (kingoflolz#13)
1 parent 6a020fa

File tree

2 files changed: +18 -5 lines


mesh_transformer/sampling.py

Lines changed: 16 additions & 3 deletions
@@ -7,11 +7,24 @@ def softmax_sample(key, logits, _, temp=1):
     return jax.random.categorical(key, logits/temp, -1).astype(jnp.uint32), None
 
 
-def nucleaus_filter(logits, top_p=0.9):
+def nucleaus_filter(logits, top_p=0.9, top_k=None):
     sorted_logits = jnp.sort(logits)[:, ::-1] # sort descending
     sorted_indices = jnp.argsort(logits)[:, ::-1]
     cumulative_probs = jnp.cumsum(jax.nn.softmax(sorted_logits), axis=-1)
 
+    if top_k is not None:
+        # Keep only top_k tokens
+        indices_range = jnp.arange(len(sorted_indices[0]))
+        indices_range = jnp.stack([indices_range] * len(sorted_indices), axis=0)
+
+        sorted_indices_to_remove = jnp.where(indices_range > top_k, sorted_indices, 0)
+
+        _, indices_to_remove = jax.lax.sort_key_val(sorted_indices, sorted_indices_to_remove)
+
+        logit_mask = 1e10 * indices_to_remove
+
+        logits -= logit_mask
+
     # Remove tokens with cumulative probability above a threshold
     sorted_indices_to_remove = cumulative_probs > top_p
     sorted_indices_to_remove = jnp.concatenate((jnp.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove), axis=-1)[:, :-1]

@@ -25,8 +38,8 @@ def nucleaus_filter(logits, top_p=0.9):
     return logits
 
 
-def nucleaus_sample(key, logits, _, top_p=0.9, temp=1):
-    logits = nucleaus_filter(logits, top_p)
+def nucleaus_sample(key, logits, _, top_p=0.9, temp=1, top_k=None):
+    logits = nucleaus_filter(logits, top_p, top_k=top_k)
 
     return softmax_sample(key, logits, None, temp=temp)
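
The new top_k branch relies on a sort/unsort trick: tokens are ranked by logit, ranks past top_k are flagged, and jax.lax.sort_key_val, keyed on the sorted token indices, scatters the flags back into vocabulary order so that a large constant can push the flagged logits far below the survivors. Below is a minimal standalone sketch of the same trick (not part of the commit; top_k_mask and ranks are illustrative names, and it keeps exactly k tokens via ranks >= k with a 0/1 flag, whereas the hunk above compares with > and reuses the sorted token indices as the flag values):

import jax
import jax.numpy as jnp

def top_k_mask(logits, k):
    # Token indices ordered from highest to lowest logit, per batch row.
    sorted_indices = jnp.argsort(logits)[:, ::-1]
    # Rank of each sorted position: 0 = best token, 1 = runner-up, ...
    ranks = jnp.broadcast_to(jnp.arange(logits.shape[-1]), logits.shape)
    # Flag ranks at or beyond k for removal (1 = remove, 0 = keep).
    remove_sorted = jnp.where(ranks >= k, 1, 0)
    # Sorting on the token index un-permutes the flags into vocab order.
    _, remove = jax.lax.sort_key_val(sorted_indices, remove_sorted)
    # Push removed logits far below the kept ones.
    return logits - 1e10 * remove

logits = jnp.array([[0.5, 2.0, -1.0, 1.0]])
print(top_k_mask(logits, 2))  # tokens 1 and 3 (the top two) stay finite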

resharding_example.py

Lines changed: 2 additions & 2 deletions
@@ -52,7 +52,7 @@
 # move the state to CPU/system memory so it's not duplicated by xmap
 network.state = jax.device_put(network.state, jax.devices("cpu")[0])
 
-def infer(context, top_p=0.9, temp=1.0, gen_len=512):
+def infer(context, top_k=40, top_p=0.9, temp=1.0, gen_len=512):
     tokens = tokenizer.encode(context)
 
     provided_ctx = len(tokens)

@@ -63,7 +63,7 @@ def infer(context, top_p=0.9, temp=1.0, gen_len=512):
     length = np.ones(per_replica_batch, dtype=np.uint32) * len(tokens)
 
     start = time.time()
-    output = network.generate(batched_tokens, length, gen_len, {"top_p": np.ones(per_replica_batch) * top_p, "temp": np.ones(per_replica_batch) * temp})
+    output = network.generate(batched_tokens, length, gen_len, {"top_p": np.ones(per_replica_batch) * top_p, "top_k": top_k is not None and (np.ones(per_replica_batch, dtype=np.int32) * top_k) or None, "temp": np.ones(per_replica_batch) * temp})
 
     samples = []
     decoded_tokens = output[1][0]
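
In the updated generate() call, the expression "top_k is not None and (np.ones(per_replica_batch, dtype=np.int32) * top_k) or None" is the legacy and/or spelling of a conditional: it forwards a per-replica int32 array when top_k is set and None otherwise. A modern ternary restates it more clearly (an illustration, not code from the commit; note the and/or form also evaluates the array's truth value, which numpy only allows for single-element arrays, so it assumes a per_replica_batch of 1):

import numpy as np

top_k = 40             # the new infer() default
per_replica_batch = 1  # illustrative value

# Ternary equivalent of the and/or chain in the generate() call
top_k_arg = np.ones(per_replica_batch, dtype=np.int32) * top_k if top_k is not None else None

With the new signature, a caller can pass top_k=40 (the default) to infer(), or top_k=None to fall back to pure nucleus sampling.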
