Skip to content

Commit a7213ce

Browse files
author
maxtext authors
committed
[maxtext] improve profiling in decode
Limit the profiling of decode to `config.profiler_steps` steps. A user may request the generation of many tokens, but since each generation step is identical, it is not necessary to profile all of them; profiling only the first few should be sufficient. Also added a method in Profiler to do post-processing on the collected trace. PiperOrigin-RevId: 773929904
1 parent 35f3325 commit a7213ce

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

MaxText/decode.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,12 +157,20 @@ def main(argv: Sequence[str]) -> None:
157157
decode_state = engine.insert(prefill_result_list[i], decode_state, slot=i)
158158

159159
# Generate
160+
prof_deactivated = False
160161
steps = range(config.max_prefill_predict_length, config.max_target_length)
161162
sampled_tokens_list.append(_batch_first_result_token(first_token_list, batch_size))
162163
for i in steps:
163164
rng, rng_generate = jax.random.split(rng)
164165
with jax.profiler.StepTraceAnnotation("generate", step=i):
165166
decode_state, sampled_tokens = engine.generate(params, decode_state, rng=rng_generate)
167+
168+
# Automatically deactivate profiler after profiler_steps steps
169+
if i > config.max_prefill_predict_length + config.profiler_steps:
170+
jax.block_until_ready(sampled_tokens)
171+
prof.deactivate()
172+
prof_deactivated = True
173+
166174
sampled_tokens_list.append(sampled_tokens)
167175

168176
# Get results
@@ -176,7 +184,11 @@ def main(argv: Sequence[str]) -> None:
176184
), f"generated text mismatch {output=}, {config.autoregressive_decode_assert=}"
177185

178186
# Deactivate profiler
179-
prof.deactivate()
187+
if not prof_deactivated:
188+
jax.block_until_ready(output)
189+
prof.deactivate()
190+
191+
prof.post_process()
180192

181193
def _validate_config(config):
182194
assert config.load_full_state_path == "", (

MaxText/profiler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,6 @@ def should_activate_periodic_profile(self, step):
9292

9393
def should_deactivate_periodic_profile(self, step):
9494
return self.profile_period > 0 and (step - self.finished_initial_profile_step) % self.profile_period == 0
95+
96+
def post_process(self):
97+
pass

0 commit comments

Comments
 (0)