102
102
from jetstream .core .metrics .prometheus import JetstreamMetricsCollector
103
103
import numpy as np
104
104
105
- log_level = os .getenv ("LOG_LEVEL" , "WARNING" ).upper ()
105
+ from jax .experimental import layout as jax_layout
106
+ DLL = jax_layout .DeviceLocalLayout
107
+ Layout = jax_layout .Layout
108
+
109
+ log_level = os .getenv ("LOG_LEVEL" , "DEBUG" ).upper ()
106
110
107
111
logger = logging .getLogger ("JetstreamLogger" )
108
112
logger .propagate = False
@@ -405,6 +409,26 @@ def __init__(
405
409
406
410
self ._jax_padding = jax_padding
407
411
412
+ ##### Auto layout compile for interleaved engine
413
+ self ._generate_executables = [None for _ in self ._generate_engines ]
414
+ self ._cached_insert = [None for _ in self ._generate_engines ]
415
+ self ._cached_prefill = [None for _ in self ._prefill_engines ]
416
+ if self ._interleaved_mode :
417
+ for idx in range (len (self ._generate_engines )):
418
+ logger .debug ("Compiling interleaved engine {}" .format (idx ))
419
+ engine = self ._generate_engines [idx ]
420
+ params = self ._generate_params [idx ]
421
+ engine , params , gen_fn , prefill_fn , insert_fn = self ._auto_layout_compile (engine , params )
422
+
423
+ self ._prefill_engines [idx ] = engine
424
+ self ._generate_engines [idx ] = engine
425
+ self ._prefill_params [idx ] = params
426
+ self ._generate_params [idx ] = params
427
+ self ._cached_prefill [idx ] = prefill_fn
428
+ self ._cached_insert [idx ] = insert_fn
429
+ self ._generate_executables [idx ] = gen_fn
430
+
431
+
408
432
# Create all threads
409
433
self ._prefill_threads = [
410
434
JetThread (
@@ -670,6 +694,56 @@ def _do_chunked_prefill(
670
694
671
695
return prefill_result , first_token
672
696
697
def _auto_layout_compile(self, engine, params, prefill_lengths=None):
  """Ahead-of-time compile generate, prefill and insert for one engine.

  The generate executable comes from `engine.aot_compile`. For every prefill
  length, `engine.prefill_aot` is jitted with `Layout(DLL.AUTO)` requested for
  both outputs, lowered and compiled; the matching `engine.insert` is then
  compiled so that its first input uses the layout reported by the compiled
  prefill.

  Args:
    engine: Engine object. This method uses its `aot_compile`, `prefill_aot`,
      `insert`, `param_layouts`, `decode_state_layouts` and
      `decode_state_shapes` members.
    params: Parameters handed to `engine.aot_compile`. The value returned by
      that call replaces `params` for all subsequent lowering and is what
      this method returns.
    prefill_lengths: Optional iterable of padded prompt lengths to compile
      prefill/insert executables for. Defaults to (64, 128, 256, 512, 1024).

  Returns:
    A tuple `(engine, params, generate_executable, cached_prefill,
    cached_insert)` where `cached_prefill` and `cached_insert` are dicts
    mapping each prefill length to its compiled executable.
  """
  logger.debug("Compiling generate function")
  generate_executable, params, decode_state_executable = engine.aot_compile(
      params, pass_rng_shape=False
  )
  # NOTE(review): the result of this call is not used anywhere in this
  # method, while `_generate_thread` reads `self.decode_state`, which nothing
  # visible here assigns. The call is kept in case it has side effects;
  # confirm whether its result was meant to be stored.
  decode_state_executable(None)

  if prefill_lengths is None:
    prefill_lengths = (64, 128, 256, 512, 1024)

  # Abstract inputs shared by every lowering below; built once because they do
  # not depend on the bucket length.
  i32_scalar = jax.ShapeDtypeStruct((), int)
  token_dtype = jax.numpy.dtype("int32")

  cached_prefill = {}
  cached_insert = {}
  for length in prefill_lengths:
    logger.debug("Compiling prefill: %d", length)
    input_data = jax.ShapeDtypeStruct((length,), token_dtype)

    cached_prefill[length] = (
        jax.jit(
            engine.prefill_aot,
            in_shardings=(engine.param_layouts, None, None),
            out_shardings=(Layout(DLL.AUTO), Layout(DLL.AUTO)),
        )
        .lower(params, input_data, i32_scalar)
        .compile(compiler_options=None)
    )

    # Trace prefill abstractly to learn the structure/shapes of its outputs,
    # which are needed to lower `insert`.
    logger.debug("Generate dummy prefix: %d", length)
    dummy_tokens = jax.numpy.ones(shape=(length,), dtype=token_dtype)
    prefix_shapes = jax.eval_shape(engine.prefill_aot, params, dummy_tokens, 1)

    logger.debug("Compiling insert: %d", length)
    # Only the layout of the first prefill output is fed into insert.
    prefill_output_layout, _ = cached_prefill[length].output_layouts
    # Lazy %-formatting: these pytrees can be large, so avoid rendering them
    # unless DEBUG logging is actually enabled.
    logger.debug("Prefill output layout: %s", prefill_output_layout)
    logger.debug("Prefix shapes: %s", prefix_shapes)
    cached_insert[length] = (
        jax.jit(
            engine.insert,
            in_shardings=(
                prefill_output_layout,
                engine.decode_state_layouts,
                None,
            ),
            out_shardings=engine.decode_state_layouts,
            donate_argnames=("decode_state",),
        )
        .lower(prefix_shapes[0], engine.decode_state_shapes, i32_scalar)
        .compile(compiler_options=None)
    )
  return engine, params, generate_executable, cached_prefill, cached_insert
746
+
673
747
def _prefill_thread (self , idx : int ):
674
748
"""Thread which runs in the background performing prefills."""
675
749
logger .info ("Spinning up prefill thread %d." , idx )
@@ -683,6 +757,12 @@ def _prefill_thread(self, idx: int):
683
757
thread_name = f"Prefill thread { idx } "
684
758
ThreadDebugLog (thread_name , f"Prefill params { idx } loaded." )
685
759
760
+ if not self .interleaved :
761
+ prefill_engine , prefill_params , gen_fn , prefill_fn , insert_fn = self ._auto_layout_compile (
762
+ prefill_engine , prefill_params
763
+ )
764
+ self ._cached_prefill [idx ] = prefill_fn
765
+
686
766
while self .live :
687
767
my_transfer_backlog = self ._transfer_backlogs [idx ]
688
768
# The prefill thread can just sleep until it has work to do.
@@ -759,10 +839,11 @@ def _prefill_thread(self, idx: int):
759
839
)
760
840
else :
761
841
# Compute new kv cache for the prefill_content.
762
- prefill_result , first_token = prefill_engine .prefill (
763
- params = final_prefill_params ,
764
- padded_tokens = padded_tokens ,
765
- true_length = true_length ,
842
+ assert padded_tokens .shape [0 ] in self ._cached_prefill [idx ]
843
+ prefill_result , first_token = self ._cached_prefill [idx ][padded_tokens .shape [0 ]](
844
+ final_prefill_params ,
845
+ padded_tokens ,
846
+ true_length ,
766
847
)
767
848
768
849
request .complete = np .zeros (
@@ -967,10 +1048,11 @@ def _insert_if_possible(
967
1048
else :
968
1049
break
969
1050
970
- decode_state = generate_engine .insert (
1051
+ length = new_request .prefill_result ['cache' ]['decoder' ]['layers_0' ]['self_attention' ]['KVCache_0' ]['cache_prefill_segment_id' ].value .shape [1 ]
1052
+ decode_state = self ._cached_insert [idx ][length ](
971
1053
new_request .prefill_result ,
972
1054
decode_state ,
973
- slot = slot ,
1055
+ slot ,
974
1056
# request_id=new_request.request_id,
975
1057
)
976
1058
ThreadDebugLog (
@@ -1115,9 +1197,15 @@ def _generate_thread(self, idx: int):
1115
1197
# Keep track of what step tokens were generated at.
1116
1198
generate_timestep = 0
1117
1199
# State to store things like running kv cache in.
1118
- decode_state = generate_engine .init_decode_state ()
1119
-
1200
+ decode_state = self .decode_state
1120
1201
generate_params = self ._generate_params [idx ]
1202
+
1203
+ if not self .interleaved :
1204
+ generate_engine , generate_params , gen_fn , prefill_fn , insert_fn = self ._auto_layout_compile (
1205
+ generate_engine , generate_params
1206
+ )
1207
+ self ._generate_executables [idx ] = gen_fn
1208
+
1121
1209
thread_name = f"Generate thread { idx } "
1122
1210
ThreadDebugLog (thread_name , f"Generate params { idx } loaded." )
1123
1211
time_of_last_generate = time .time ()
@@ -1178,8 +1266,8 @@ def _generate_thread(self, idx: int):
1178
1266
), "At this point we must have some requests inserted into the slots."
1179
1267
1180
1268
# Now we actually take a generate step on requests in the slots.
1181
- decode_state , sampled_tokens = generate_engine . generate (
1182
- generate_params , decode_state
1269
+ decode_state , sampled_tokens = self . _generate_executables [ idx ] (
1270
+ generate_params , decode_state , None
1183
1271
)
1184
1272
sampled_tokens .copy_to_host_async ()
1185
1273
# Respond to detokenization backpressure.
0 commit comments