Skip to content

Commit 2c4f6d9

Browse files
authored
Add unit tests to increase coverage for multi-sampling (#220)
* Add unit tests to increase coverage for multi-sampling
1 parent 698f33f commit 2c4f6d9

File tree

3 files changed

+67
-16
lines changed

3 files changed

+67
-16
lines changed

jetstream/engine/mock_engine.py

+20-14
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,9 @@ def prefill(
192192
)
193193
return (prefix, result_tokens)
194194

195-
@functools.partial(jax.jit, static_argnums=(0,))
195+
@functools.partial(
196+
jax.jit, static_argnums=(0,), static_argnames=("num_samples",)
197+
)
196198
def prefill_multisampling(
197199
self,
198200
*,
@@ -216,26 +218,30 @@ def prefill_multisampling(
216218
# Generate dummy prefill cache content
217219
prefill_cache = padded_tokens[None, :] * params
218220

219-
# Create a dummy first generated token.
220-
first_generated_token = (prefill_cache.sum(axis=-1).astype(jnp.int32))[
221-
:, jnp.newaxis
222-
]
221+
# Create dummy first generated tokens.
222+
first_generated_tokens = []
223+
for _ in range(num_samples):
224+
first_generated_token = (prefill_cache.sum(axis=-1).astype(jnp.int32))[
225+
:, jnp.newaxis
226+
]
227+
first_generated_tokens.append(first_generated_token)
228+
first_generated_tokens = jnp.concatenate(first_generated_tokens, axis=0)
223229

224230
prefix = Prefix(
225231
logits=jax.random.normal(self._prng_key, (1, self.vocab_size)),
226232
cache=prefill_cache,
227233
next_pos=jnp.full((1, 1), true_length, dtype=jnp.int32),
228-
num_generated_tokens=jnp.zeros((1, 1), dtype=jnp.int32),
229-
first_token=first_generated_token,
234+
num_generated_tokens=jnp.zeros((num_samples, 1), dtype=jnp.int32),
235+
first_token=first_generated_tokens,
230236
)
231237

232238
speculations = first_generated_token.shape[1]
233239
result_tokens = engine_api.ResultTokens(
234240
data=jnp.concatenate(
235241
(
236-
first_generated_token,
237-
jnp.ones_like(first_generated_token),
238-
jnp.ones_like(first_generated_token),
242+
first_generated_tokens,
243+
jnp.ones_like(first_generated_tokens),
244+
jnp.ones_like(first_generated_tokens),
239245
),
240246
axis=-1,
241247
),
@@ -244,7 +250,7 @@ def prefill_multisampling(
244250
valid_idx=(speculations, 2 * speculations),
245251
# And lengths is rank 1.
246252
length_idx=(2 * speculations, 2 * speculations + 1),
247-
samples_per_slot=self.generate_cache_batch // self.prefill_cache_batch,
253+
samples_per_slot=num_samples,
248254
)
249255
return (prefix, result_tokens)
250256

@@ -398,21 +404,21 @@ def bulk_insert(
398404
"""Insert a single computed prefill cache into multiple slots in
399405
KV cache.
400406
"""
401-
prefill_cache = prefix.cache
407+
prefill_cache = decode_state.prefill_cache
402408
generate_cache = decode_state.generate_cache
403409
generate_lengths = decode_state.generate_lengths
404410
generate_tokens = decode_state.generate_tokens
405411
for slot in slots:
406412
prefill_cache = jax.lax.dynamic_update_slice_in_dim(
407-
decode_state.prefill_cache, prefill_cache, slot, axis=0
413+
prefill_cache, prefix.cache, slot, axis=0
408414
)
409415
generate_cache = jax.lax.dynamic_update_slice_in_dim(
410416
generate_cache,
411417
jnp.zeros((1, self.cache_length)),
412418
slot,
413419
axis=0,
414420
)
415-
samples_per_slot = self.generate_cache_batch // self.prefill_cache_batch
421+
samples_per_slot = 1
416422
generate_lengths = jax.lax.dynamic_update_slice_in_dim(
417423
generate_lengths,
418424
jnp.ones((samples_per_slot), dtype=jnp.int32),

jetstream/engine/warmup_utils.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,13 @@ def initialize_prefill_jit_cache(
9999
def compile_prefill(length):
100100
padded_tokens, true_length = jnp.ones((length), dtype="int32"), length
101101

102-
_, _ = prefill_engine._downstream_engine.prefill( # pylint: disable=protected-access
102+
_, _ = prefill_engine.prefill(
103+
params=prefill_params,
104+
padded_tokens=padded_tokens,
105+
true_length=true_length,
106+
)
107+
108+
_, _ = prefill_engine.prefill_multisampling(
103109
params=prefill_params,
104110
padded_tokens=padded_tokens,
105111
true_length=true_length,
@@ -166,6 +172,10 @@ def compile_insert(length):
166172

167173
generate_engine.insert(prefix=prefill, decode_state=decode_state, slot=0)
168174

175+
generate_engine.bulk_insert(
176+
prefix=prefill, decode_state=decode_state, slots=[0]
177+
)
178+
169179
logging.info(
170180
"---------Generate engine %d compiled for insert length %d.---------",
171181
generate_idx,

jetstream/tests/core/test_orchestrator.py

+36-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@
5151

5252
class OrchestratorTest(unittest.IsolatedAsyncioTestCase):
5353

54-
def _setup_driver(self, interleaved_mode: bool = True):
54+
def _setup_driver(
55+
self, interleaved_mode: bool = True, multi_sampling: bool = False
56+
):
5557
prefill_engine = mock_engine.TestEngine(
5658
batch_size=32, cache_length=256, weight=2.0
5759
)
@@ -66,6 +68,7 @@ def _setup_driver(self, interleaved_mode: bool = True):
6668
prefill_params=[prefill_engine.load_params()],
6769
generate_params=[generate_engine.load_params()],
6870
interleaved_mode=interleaved_mode,
71+
multi_sampling=multi_sampling,
6972
)
7073
return driver
7174

@@ -150,6 +153,38 @@ async def test_orchestrator(self, interleaved_mode: bool):
150153
driver.stop()
151154
print("Orchestrator driver stopped.")
152155

156+
@parameterized.expand([1, 2, 3, 4])
157+
async def test_orchestrator_multi_sampling(self, num_samples: int):
158+
"""Test the multithreaded orchestration."""
159+
driver = self._setup_driver(interleaved_mode=True, multi_sampling=True)
160+
client = orchestrator.LLMOrchestrator(driver=driver)
161+
162+
# The string representation of np.array([[65, 66]]), [2] will be prepend
163+
# as BOS.
164+
text = "AB"
165+
166+
request = jetstream_pb2.DecodeRequest(
167+
text_content=jetstream_pb2.DecodeRequest.TextContent(text=text),
168+
max_tokens=3,
169+
num_samples=num_samples,
170+
)
171+
iterator = client.Decode(request)
172+
# chr of [266, 332, 415].
173+
expected_text = ["Ċ", "Ō", "Ɵ", ""]
174+
expected_token_ids = [266, 332, 415, None]
175+
counter = 0
176+
async for resp in iterator:
177+
for sample in resp.stream_content.samples:
178+
output_text = sample.text
179+
token_ids = sample.token_ids
180+
output_token_id = token_ids[0] if len(token_ids) > 0 else None
181+
print(f"actual output: {output_text=} {output_token_id=}")
182+
assert output_text == expected_text[counter]
183+
assert output_token_id == expected_token_ids[counter]
184+
counter += 1
185+
driver.stop()
186+
print("Orchestrator driver stopped.")
187+
153188
@parameterized.expand([True, False])
154189
async def test_orchestrator_client_tokenization_chunked_prefill(
155190
self, interleaved_mode: bool

0 commit comments

Comments (0)