-
Notifications
You must be signed in to change notification settings - Fork 37
/
Copy pathtest_orchestrator.py
266 lines (237 loc) · 9.9 KB
/
test_orchestrator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Integration test of the orchestrator.
This test exercises the multithreaded orchestrator, where a prefill request is
pushed onto a prefill queue, prefilled, sent to a generation queue and run for
a number of decoding steps.
In operation, it will use gRPC so we can 'yield' from the function to get return
values in the same way that they would be streamed back.
Similar to 'mock_engine_test' we can use known token values and a singleton
weight to test our operation.
Let the prefill engine have a weight of [2] and the generate engine have a
weight of [4].
I.e. if we prefill [2, 65, 66] (i.e. <BOS>, 'A', 'B') using an ASCII vocab,
we should get [4, 130, 132].
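If we then insert that and run three generation steps, we should see the sum
of the prefill result (4 + 130 + 132 = 266) plus the previously generated
tokens divided by the generate weight, rounded down:
266 + [] / 4 = 266
266 + [266] / 4 = 332
266 + [266, 332] / 4 = 415
I.e. ['Ċ', 'Ō', 'Ɵ'] when converted back with chr().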
Therefore we should get back the character sequence 'ĊŌƟ' if we request 3 tokens
decoded (these are the characters at those code points, which is what the test
tokenizer returns).
"""
import unittest
from parameterized import parameterized
from jetstream.core import orchestrator
from jetstream.core.proto import jetstream_pb2
from jetstream.core.utils.return_sample import ReturnSample
from jetstream.engine import mock_engine
class OrchestratorTest(unittest.IsolatedAsyncioTestCase):
  """Integration tests for `orchestrator.LLMOrchestrator` over a real Driver.

  Prefill and generate use mock engines with different weights (2.0 and 4.0),
  so the expected token ids also check that the right engine is used at each
  stage.
  """

  # Stream for prompt "AB" with max_tokens=3 on the standard drivers: chr of
  # [266, 332, 415], then a final response with empty text and no token ids.
  _EXPECTED_TEXT = ("Ċ", "Ō", "Ɵ", "")
  _EXPECTED_TOKEN_IDS = (266, 332, 415, None)
  # Same request on the chunked-prefill drivers: chr of [133, 166, 207], then
  # a final empty response.
  _EXPECTED_CHUNKED_TEXT = ("\x85", "¦", "Ï", "")
  _EXPECTED_CHUNKED_TOKEN_IDS = (133, 166, 207, None)
  # In client-side tokenization mode only token ids are returned, never text.
  _NO_TEXT = ("", "", "", "")

  def _setup_driver(
      self, interleaved_mode: bool = True, multi_sampling: bool = False
  ):
    """Builds a Driver with one prefill and one generate mock engine.

    Args:
      interleaved_mode: forwarded to `orchestrator.Driver`.
      multi_sampling: forwarded to `orchestrator.Driver`.

    Returns:
      The constructed `orchestrator.Driver`; the caller is responsible for
      stopping it (see `_make_client`).
    """
    prefill_engine = mock_engine.TestEngine(
        batch_size=32, cache_length=256, weight=2.0
    )
    # Create a generate engine with a different set of weights
    # so that we can test that the right one is in use at a given time.
    generate_engine = mock_engine.TestEngine(
        batch_size=4, cache_length=32, weight=4.0
    )
    return orchestrator.Driver(
        prefill_engines=[prefill_engine],
        generate_engines=[generate_engine],
        prefill_params=[prefill_engine.load_params()],
        generate_params=[generate_engine.load_params()],
        interleaved_mode=interleaved_mode,
        multi_sampling=multi_sampling,
    )

  def _setup_driver_chunked_prefill(self, interleaved_mode: bool = True):
    """Builds a Driver whose mock engines have chunked prefill enabled.

    Args:
      interleaved_mode: forwarded to `orchestrator.Driver`.

    Returns:
      The constructed `orchestrator.Driver`; the caller is responsible for
      stopping it (see `_make_client`).
    """
    prefill_engine = mock_engine.TestEngine(
        batch_size=32, cache_length=256, weight=2.0, use_chunked_prefill=True
    )
    # Create a generate engine with a different set of weights
    # so that we can test that the right one is in use at a given time.
    generate_engine = mock_engine.TestEngine(
        batch_size=4,
        cache_length=32,
        weight=4.0,
        use_chunked_prefill=True,
    )
    return orchestrator.Driver(
        prefill_engines=[prefill_engine],
        generate_engines=[generate_engine],
        prefill_params=[prefill_engine.load_params_dict()],
        generate_params=[generate_engine.load_params()],
        interleaved_mode=interleaved_mode,
    )

  def _make_client(self, driver):
    """Wraps `driver` in an LLMOrchestrator and schedules `driver.stop()`.

    Registering the stop as a test cleanup guarantees the driver is shut down
    even when an assertion fails part-way through a test.
    """
    self.addCleanup(driver.stop)
    return orchestrator.LLMOrchestrator(driver=driver)

  async def _assert_decode_stream(
      self, iterator, expected_text, expected_token_ids, *, all_samples=False
  ):
    """Consumes a Decode stream and checks each response against expectations.

    Args:
      iterator: async iterator of responses returned by `client.Decode`.
      expected_text: expected sample text, one entry per response.
      expected_token_ids: expected first token id of the sample (None when the
        sample carries no token ids), one entry per response.
      all_samples: if True, check every sample of each response (used for
        multi-sampling); otherwise check only the first sample.
    """
    num_responses = 0
    async for resp in iterator:
      self.assertLess(
          num_responses,
          len(expected_text),
          "Decode produced more responses than expected",
      )
      samples = resp.stream_content.samples
      checked = samples if all_samples else [samples[0]]
      for sample in checked:
        token_ids = sample.token_ids
        output_token_id = token_ids[0] if token_ids else None
        self.assertEqual(
            sample.text,
            expected_text[num_responses],
            f"text of response {num_responses}",
        )
        self.assertEqual(
            output_token_id,
            expected_token_ids[num_responses],
            f"token id of response {num_responses}",
        )
      # Advance once per response (not per sample): with multi-sampling every
      # sample in a response is expected to carry the same step's token.
      num_responses += 1
    # Without this, a stream that ends early would pass vacuously.
    self.assertEqual(
        num_responses,
        len(expected_text),
        "Decode produced fewer responses than expected",
    )

  # `parameterized.expand` generates the per-parameter test methods from the
  # function it decorates, so `unittest.skip` must be applied first (below
  # `expand`) for the skip marker to reach the generated tests.
  @parameterized.expand([True, False])
  @unittest.skip("Rewrite mock engine to test chunked prefill call correctly.")
  async def test_orchestrator_chunked_prefill(self, interleaved_mode: bool):
    """Decodes a text prompt through the chunked-prefill engines."""
    client = self._make_client(
        self._setup_driver_chunked_prefill(interleaved_mode)
    )
    # "AB" is token ids [65, 66]; [2] is prepended as BOS.
    request = jetstream_pb2.DecodeRequest(
        text_content=jetstream_pb2.DecodeRequest.TextContent(text="AB"),
        max_tokens=3,
    )
    await self._assert_decode_stream(
        client.Decode(request),
        self._EXPECTED_CHUNKED_TEXT,
        self._EXPECTED_CHUNKED_TOKEN_IDS,
    )

  @parameterized.expand([True, False])
  async def test_orchestrator(self, interleaved_mode: bool):
    """Decodes a text prompt and checks the streamed text and token ids."""
    client = self._make_client(self._setup_driver(interleaved_mode))
    # "AB" is token ids [65, 66]; [2] is prepended as BOS.
    request = jetstream_pb2.DecodeRequest(
        text_content=jetstream_pb2.DecodeRequest.TextContent(text="AB"),
        max_tokens=3,
    )
    await self._assert_decode_stream(
        client.Decode(request),
        self._EXPECTED_TEXT,
        self._EXPECTED_TOKEN_IDS,
    )

  @parameterized.expand([1, 2, 3, 4])
  async def test_orchestrator_multi_sampling(self, num_samples: int):
    """Every sample of a multi-sampling request follows the expected stream."""
    client = self._make_client(
        self._setup_driver(interleaved_mode=True, multi_sampling=True)
    )
    # "AB" is token ids [65, 66]; [2] is prepended as BOS.
    request = jetstream_pb2.DecodeRequest(
        text_content=jetstream_pb2.DecodeRequest.TextContent(text="AB"),
        max_tokens=3,
        num_samples=num_samples,
    )
    await self._assert_decode_stream(
        client.Decode(request),
        self._EXPECTED_TEXT,
        self._EXPECTED_TOKEN_IDS,
        all_samples=True,
    )

  # See the note above test_orchestrator_chunked_prefill on decorator order.
  @parameterized.expand([True, False])
  @unittest.skip("Rewrite mock engine to test chunked prefill call correctly.")
  async def test_orchestrator_client_tokenization_chunked_prefill(
      self, interleaved_mode: bool
  ):
    """Client-side tokenization through the chunked-prefill engines."""
    client = self._make_client(
        self._setup_driver_chunked_prefill(interleaved_mode)
    )
    # Token ids of "AB"; [2] is prepended as BOS.
    request = jetstream_pb2.DecodeRequest(
        token_content=jetstream_pb2.DecodeRequest.TokenContent(
            token_ids=[65, 66]
        ),
        max_tokens=3,
    )
    await self._assert_decode_stream(
        client.Decode(request),
        self._NO_TEXT,
        self._EXPECTED_CHUNKED_TOKEN_IDS,
    )

  @parameterized.expand([True, False])
  async def test_orchestrator_client_tokenization(self, interleaved_mode: bool):
    """Token-id requests get token ids back and no detokenized text."""
    client = self._make_client(self._setup_driver(interleaved_mode))
    # Token ids of "AB"; [2] is prepended as BOS.
    request = jetstream_pb2.DecodeRequest(
        token_content=jetstream_pb2.DecodeRequest.TokenContent(
            token_ids=[65, 66]
        ),
        max_tokens=3,
    )
    await self._assert_decode_stream(
        client.Decode(request),
        self._NO_TEXT,
        self._EXPECTED_TOKEN_IDS,
    )

  @parameterized.expand([True, False])
  def test_should_buffer_response(self, interleaved_mode: bool):
    """A sample whose text is "<0xAB>" must be buffered before streaming."""
    client = self._make_client(self._setup_driver(interleaved_mode))
    self.assertTrue(
        client.should_buffer_response(
            [ReturnSample(text=["<0xAB>"], token_ids=[13])]
        )
    )