Skip to content

Commit b547466

Browse files
committed
hacky way to test aot in jetstream
1 parent 9cb7785 commit b547466

File tree

1 file changed

+99
-11
lines changed

1 file changed

+99
-11
lines changed

jetstream/core/orchestrator.py

Lines changed: 99 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,11 @@
102102
from jetstream.core.metrics.prometheus import JetstreamMetricsCollector
103103
import numpy as np
104104

105-
log_level = os.getenv("LOG_LEVEL", "WARNING").upper()
105+
from jax.experimental import layout as jax_layout
106+
DLL = jax_layout.DeviceLocalLayout
107+
Layout = jax_layout.Layout
108+
109+
log_level = os.getenv("LOG_LEVEL", "DEBUG").upper()
106110

107111
logger = logging.getLogger("JetstreamLogger")
108112
logger.propagate = False
@@ -405,6 +409,26 @@ def __init__(
405409

406410
self._jax_padding = jax_padding
407411

412+
##### Auto layout compile for interleaved engine
413+
self._generate_executables = [None for _ in self._generate_engines]
414+
self._cached_insert = [None for _ in self._generate_engines]
415+
self._cached_prefill = [None for _ in self._prefill_engines]
416+
if self._interleaved_mode:
417+
for idx in range(len(self._generate_engines)):
418+
logger.debug("Compiling interleaved engine {}".format(idx))
419+
engine = self._generate_engines[idx]
420+
params = self._generate_params[idx]
421+
engine, params, gen_fn, prefill_fn, insert_fn = self._auto_layout_compile(engine, params)
422+
423+
self._prefill_engines[idx] = engine
424+
self._generate_engines[idx] = engine
425+
self._prefill_params[idx] = params
426+
self._generate_params[idx] = params
427+
self._cached_prefill[idx] = prefill_fn
428+
self._cached_insert[idx] = insert_fn
429+
self._generate_executables[idx] = gen_fn
430+
431+
408432
# Create all threads
409433
self._prefill_threads = [
410434
JetThread(
@@ -670,6 +694,56 @@ def _do_chunked_prefill(
670694

671695
return prefill_result, first_token
672696

697+
def _auto_layout_compile(self, engine, params):
698+
logger.debug("Compiling generate function")
699+
generate_executable, params, decode_state_executable = engine.aot_compile(
700+
params, pass_rng_shape=False
701+
)
702+
decode_state = decode_state_executable(None)
703+
704+
# prefill
705+
interesting_buckets = [
706+
64,
707+
128,
708+
256,
709+
512,
710+
1024,
711+
]
712+
713+
cached_prefill = {}
714+
cached_insert = {}
715+
for length in interesting_buckets:
716+
i32_scalar = jax.ShapeDtypeStruct((), int)
717+
logger.debug("Compiling prefill: %d", length)
718+
input_data = jax.ShapeDtypeStruct((length,), jax.numpy.dtype("int32"))
719+
720+
cached_prefill[length] = (
721+
jax.jit(
722+
engine.prefill_aot,
723+
in_shardings=(engine.param_layouts, None, None),
724+
out_shardings=(Layout(DLL.AUTO), Layout(DLL.AUTO)),
725+
).lower(params, input_data, i32_scalar)
726+
).compile(compiler_options=None)
727+
728+
logger.debug("Generate dummy prefix: %d", length)
729+
dummy_tokens = jax.numpy.ones(shape=(length,), dtype=jax.numpy.dtype("int32"))
730+
prefix_shapes = jax.eval_shape(engine.prefill_aot, params, dummy_tokens, 1)
731+
732+
logger.debug("Compiling insert: %d", length)
733+
prefill_output_layout, _ = cached_prefill[length].output_layouts
734+
logger.debug("Prefill output layout: {}".format(prefill_output_layout))
735+
logger.debug("Prefix shapes: {}".format(prefix_shapes))
736+
i32_scalar = jax.ShapeDtypeStruct((), int)
737+
cached_insert[length] = (
738+
jax.jit(
739+
engine.insert,
740+
in_shardings=(prefill_output_layout, engine.decode_state_layouts, None),
741+
out_shardings=(engine.decode_state_layouts),
742+
donate_argnames=("decode_state"),
743+
).lower(prefix_shapes[0], engine.decode_state_shapes, i32_scalar)
744+
).compile(compiler_options=None)
745+
return engine, params, generate_executable, cached_prefill, cached_insert
746+
673747
def _prefill_thread(self, idx: int):
674748
"""Thread which runs in the background performing prefills."""
675749
logger.info("Spinning up prefill thread %d.", idx)
@@ -683,6 +757,12 @@ def _prefill_thread(self, idx: int):
683757
thread_name = f"Prefill thread {idx}"
684758
ThreadDebugLog(thread_name, f"Prefill params {idx} loaded.")
685759

760+
if not self.interleaved:
761+
prefill_engine, prefill_params, gen_fn, prefill_fn, insert_fn = self._auto_layout_compile(
762+
prefill_engine, prefill_params
763+
)
764+
self._cached_prefill[idx] = prefill_fn
765+
686766
while self.live:
687767
my_transfer_backlog = self._transfer_backlogs[idx]
688768
# The prefill thread can just sleep until it has work to do.
@@ -759,10 +839,11 @@ def _prefill_thread(self, idx: int):
759839
)
760840
else:
761841
# Compute new kv cache for the prefill_content.
762-
prefill_result, first_token = prefill_engine.prefill(
763-
params=final_prefill_params,
764-
padded_tokens=padded_tokens,
765-
true_length=true_length,
842+
assert padded_tokens.shape[0] in self._cached_prefill[idx]
843+
prefill_result, first_token = self._cached_prefill[idx][padded_tokens.shape[0]](
844+
final_prefill_params,
845+
padded_tokens,
846+
true_length,
766847
)
767848

768849
request.complete = np.zeros(
@@ -967,10 +1048,11 @@ def _insert_if_possible(
9671048
else:
9681049
break
9691050

970-
decode_state = generate_engine.insert(
1051+
length = new_request.prefill_result['cache']['decoder']['layers_0']['self_attention']['KVCache_0']['cache_prefill_segment_id'].value.shape[1]
1052+
decode_state = self._cached_insert[idx][length](
9711053
new_request.prefill_result,
9721054
decode_state,
973-
slot=slot,
1055+
slot,
9741056
# request_id=new_request.request_id,
9751057
)
9761058
ThreadDebugLog(
@@ -1115,9 +1197,15 @@ def _generate_thread(self, idx: int):
11151197
# Keep track of what step tokens were generated at.
11161198
generate_timestep = 0
11171199
# State to store things like running kv cache in.
1118-
decode_state = generate_engine.init_decode_state()
1119-
1200+
decode_state = self.decode_state
11201201
generate_params = self._generate_params[idx]
1202+
1203+
if not self.interleaved:
1204+
generate_engine, generate_params, gen_fn, prefill_fn, insert_fn = self._auto_layout_compile(
1205+
generate_engine, generate_params
1206+
)
1207+
self._generate_executables[idx] = gen_fn
1208+
11211209
thread_name = f"Generate thread {idx}"
11221210
ThreadDebugLog(thread_name, f"Generate params {idx} loaded.")
11231211
time_of_last_generate = time.time()
@@ -1178,8 +1266,8 @@ def _generate_thread(self, idx: int):
11781266
), "At this point we must have some requests inserted into the slots."
11791267

11801268
# Now we actually take a generate step on requests in the slots.
1181-
decode_state, sampled_tokens = generate_engine.generate(
1182-
generate_params, decode_state
1269+
decode_state, sampled_tokens = self._generate_executables[idx](
1270+
generate_params, decode_state, None
11831271
)
11841272
sampled_tokens.copy_to_host_async()
11851273
# Respond to detokenization backpressure.

0 commit comments

Comments
 (0)