@@ -192,7 +192,9 @@ def prefill(
192
192
)
193
193
return (prefix , result_tokens )
194
194
195
- @functools .partial (jax .jit , static_argnums = (0 ,))
195
+ @functools .partial (
196
+ jax .jit , static_argnums = (0 ,), static_argnames = ("num_samples" ,)
197
+ )
196
198
def prefill_multisampling (
197
199
self ,
198
200
* ,
@@ -216,26 +218,30 @@ def prefill_multisampling(
216
218
# Generate dummy prefill cache content
217
219
prefill_cache = padded_tokens [None , :] * params
218
220
219
- # Create a dummy first generated token.
220
- first_generated_token = (prefill_cache .sum (axis = - 1 ).astype (jnp .int32 ))[
221
- :, jnp .newaxis
222
- ]
221
+ # Create dummy first generated tokens.
222
+ first_generated_tokens = []
223
+ for _ in range (num_samples ):
224
+ first_generated_token = (prefill_cache .sum (axis = - 1 ).astype (jnp .int32 ))[
225
+ :, jnp .newaxis
226
+ ]
227
+ first_generated_tokens .append (first_generated_token )
228
+ first_generated_tokens = jnp .concatenate (first_generated_tokens , axis = 0 )
223
229
224
230
prefix = Prefix (
225
231
logits = jax .random .normal (self ._prng_key , (1 , self .vocab_size )),
226
232
cache = prefill_cache ,
227
233
next_pos = jnp .full ((1 , 1 ), true_length , dtype = jnp .int32 ),
228
- num_generated_tokens = jnp .zeros ((1 , 1 ), dtype = jnp .int32 ),
229
- first_token = first_generated_token ,
234
+ num_generated_tokens = jnp .zeros ((num_samples , 1 ), dtype = jnp .int32 ),
235
+ first_token = first_generated_tokens ,
230
236
)
231
237
232
238
speculations = first_generated_token .shape [1 ]
233
239
result_tokens = engine_api .ResultTokens (
234
240
data = jnp .concatenate (
235
241
(
236
- first_generated_token ,
237
- jnp .ones_like (first_generated_token ),
238
- jnp .ones_like (first_generated_token ),
242
+ first_generated_tokens ,
243
+ jnp .ones_like (first_generated_tokens ),
244
+ jnp .ones_like (first_generated_tokens ),
239
245
),
240
246
axis = - 1 ,
241
247
),
@@ -244,7 +250,7 @@ def prefill_multisampling(
244
250
valid_idx = (speculations , 2 * speculations ),
245
251
# And lengths is rank 1.
246
252
length_idx = (2 * speculations , 2 * speculations + 1 ),
247
- samples_per_slot = self . generate_cache_batch // self . prefill_cache_batch ,
253
+ samples_per_slot = num_samples ,
248
254
)
249
255
return (prefix , result_tokens )
250
256
@@ -398,21 +404,21 @@ def bulk_insert(
398
404
"""Insert a single computed prefill cache into multiple slots in
399
405
KV cache.
400
406
"""
401
- prefill_cache = prefix . cache
407
+ prefill_cache = decode_state . prefill_cache
402
408
generate_cache = decode_state .generate_cache
403
409
generate_lengths = decode_state .generate_lengths
404
410
generate_tokens = decode_state .generate_tokens
405
411
for slot in slots :
406
412
prefill_cache = jax .lax .dynamic_update_slice_in_dim (
407
- decode_state . prefill_cache , prefill_cache , slot , axis = 0
413
+ prefill_cache , prefix . cache , slot , axis = 0
408
414
)
409
415
generate_cache = jax .lax .dynamic_update_slice_in_dim (
410
416
generate_cache ,
411
417
jnp .zeros ((1 , self .cache_length )),
412
418
slot ,
413
419
axis = 0 ,
414
420
)
415
- samples_per_slot = self . generate_cache_batch // self . prefill_cache_batch
421
+ samples_per_slot = 1
416
422
generate_lengths = jax .lax .dynamic_update_slice_in_dim (
417
423
generate_lengths ,
418
424
jnp .ones ((samples_per_slot ), dtype = jnp .int32 ),
0 commit comments