@@ -7,11 +7,24 @@ def softmax_sample(key, logits, _, temp=1):
7
7
return jax .random .categorical (key , logits / temp , - 1 ).astype (jnp .uint32 ), None
8
8
9
9
10
- def nucleaus_filter (logits , top_p = 0.9 ):
10
+ def nucleaus_filter (logits , top_p = 0.9 , top_k = None ):
11
11
sorted_logits = jnp .sort (logits )[:, ::- 1 ] # sort descending
12
12
sorted_indices = jnp .argsort (logits )[:, ::- 1 ]
13
13
cumulative_probs = jnp .cumsum (jax .nn .softmax (sorted_logits ), axis = - 1 )
14
14
15
+ if top_k is not None :
16
+ # Keep only top_k tokens
17
+ indices_range = jnp .arange (len (sorted_indices [0 ]))
18
+ indices_range = jnp .stack ([indices_range ] * len (sorted_indices ), axis = 0 )
19
+
20
+ sorted_indices_to_remove = jnp .where (indices_range > top_k , sorted_indices , 0 )
21
+
22
+ _ , indices_to_remove = jax .lax .sort_key_val (sorted_indices , sorted_indices_to_remove )
23
+
24
+ logit_mask = 1e10 * indices_to_remove
25
+
26
+ logits -= logit_mask
27
+
15
28
# Remove tokens with cumulative probability above a threshold
16
29
sorted_indices_to_remove = cumulative_probs > top_p
17
30
sorted_indices_to_remove = jnp .concatenate ((jnp .zeros_like (sorted_indices_to_remove [:, :1 ]), sorted_indices_to_remove ), axis = - 1 )[:, :- 1 ]
@@ -25,8 +38,8 @@ def nucleaus_filter(logits, top_p=0.9):
25
38
return logits
26
39
27
40
28
- def nucleaus_sample (key , logits , _ , top_p = 0.9 , temp = 1 ):
29
- logits = nucleaus_filter (logits , top_p )
41
+ def nucleaus_sample (key , logits , _ , top_p = 0.9 , temp = 1 , top_k = None ):
42
+ logits = nucleaus_filter (logits , top_p , top_k = top_k )
30
43
31
44
return softmax_sample (key , logits , None , temp = temp )
32
45
0 commit comments