
Commit 8839d1a

- JetStream changes for a JAX-based implementation of unified_lora_params, enabling decoding of a batch that mixes multiple different LoRA adapters.
- Creates a cache whose first dimension equals the number of slots, holding each adapter's weights for inference at that slot.
- Adds support for a different scale factor for each adapter in a batch.
1 parent 89acc8c commit 8839d1a

File tree

6 files changed: +150, -3 lines changed

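For intuition, the unified per-slot cache described above can be sketched in a few lines of JAX. This is an illustrative, self-contained example with made-up module names and sizes (q_proj, hidden, rank), not code taken from this commit:

import jax
import jax.numpy as jnp

total_slots, hidden, rank = 4, 16, 2  # hypothetical sizes

# One adapter's weights as a PyTree; module names and shapes are made up.
adapter_weights = {
    "q_proj": {
        "lora_a": jnp.ones((hidden, rank), dtype=jnp.bfloat16),
        "lora_b": jnp.ones((rank, hidden), dtype=jnp.bfloat16),
    },
}

# Mirror the PyTree with zero leaves that gain a leading slot dimension.
unified_lora_params = jax.tree_util.tree_map(
    lambda leaf: jnp.zeros((total_slots,) + leaf.shape, dtype=leaf.dtype),
    adapter_weights,
)

# Insert this adapter's weights at decode slot 2; other slots keep zeros.
slot_id = 2
unified_lora_params = jax.tree_util.tree_map(
    lambda dest, src: dest.at[slot_id].set(src), unified_lora_params, adapter_weights
)

print(jax.tree_util.tree_map(lambda x: x.shape, unified_lora_params))
# {'q_proj': {'lora_a': (4, 16, 2), 'lora_b': (4, 2, 16)}}

Each decode slot thus owns one row of every LoRA leaf, and a per-slot scale-factor vector rides alongside it; the diffs below implement this inside AdapterTensorStore and wire it into the orchestrator.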
jetstream/core/lora/adapter_tensorstore.py

Lines changed: 104 additions & 0 deletions
@@ -99,6 +99,7 @@ def __init__(
       adapters_dir_path: str,
       hbm_memory_budget: int,
       cpu_memory_budget: int,
+      total_slots: int,
   ):
     """Initializes the AdapterTensorStore."""
     self.engine = engine  # Possibly MaxEngine object
@@ -119,8 +120,27 @@ def __init__(
     self.running_requests: int = (
         0  # Number of async tasks which are in "loading" state
     )
+    self.decoding_adapters_cache: Dict[str, Any] = {}
+
+    # TODO: Make dtype configurable for the scale factor array
+    self.adapters_scale_factor = jnp.empty(1, dtype=jnp.bfloat16)
+
+    self.total_slots = total_slots
     self.lock = asyncio.Lock()  # Use an asyncio Lock for thread safety
 
+  def _get_adapter_scale_factor(self, adapter_id: str):
+    """
+    Internal: Get the LoRA scale_factor using the adapter_id.
+    """
+    adapter_config = self.adapter_registry[adapter_id].config
+    lora_scale_factor = float(1)
+
+    if "r" in adapter_config and "lora_alpha" in adapter_config:
+      lora_rank = int(adapter_config["r"])
+      lora_scale_factor = float(adapter_config["lora_alpha"]) / lora_rank
+
+    return lora_scale_factor
+
   # --- Unsafe Internal methods which assumes that lock is held ---
   def _unsafe_transfer_to_hbm(self, adapter_id: str):
     """
@@ -207,6 +227,90 @@ def _unsafe_unload_adapter(self, adapter_id: str):
     metadata.size_hbm = 0
     metadata.size_cpu = 0
 
+  def _initialize_decoding_adapters_cache(self, adapter_weights):
+    """
+    Create a new PyTree with zero tensors at the paths corresponding to
+    non-None leaves in the input PyTree. The zero tensors have an added
+    dimension of size `self.total_slots`.
+    Args:
+      adapter_weights: The input PyTree, whose structure will be mirrored.
+    Returns:
+      A new PyTree with zero Tensors or None values, mirroring the structure
+      of the input PyTree.
+    """
+
+    def create_zero_leaf(leaf):
+      if leaf is not None:
+        original_shape = leaf.shape
+        if not original_shape:  # Handle scalar case
+          zero_tensor_shape = (self.total_slots,)
+        else:
+          zero_tensor_shape = (
+              self.total_slots,
+          ) + original_shape  # Prepend a new dimension
+
+        return jnp.zeros(zero_tensor_shape, dtype=leaf.dtype)
+      else:
+        return None  # Maintain None structure for None leaves
+
+    self.adapters_scale_factor = jnp.ones(self.total_slots, dtype=jnp.bfloat16)
+    return jax.tree_util.tree_map(create_zero_leaf, adapter_weights)
+
+  def insert_adapter_in_cache(self, adapter_id: str, slot_id: int):
+    """
+    Insert the specific adapter's tensors into a slot in the
+    decoding_adapters_cache.
+    Args:
+      adapter_id: The id of the adapter whose tensors will be inserted.
+      slot_id: The id of the slot, which represents the index in the
+        decoding_adapters_cache where the adapter tensors will be inserted.
+    """
+
+    def insert_leaf(dest_leaf, source_leaf):
+      if dest_leaf is not None and source_leaf is not None:
+        return dest_leaf.at[slot_id].set(
+            source_leaf
+        )  # Insert at the specific index
+      elif dest_leaf is not None:
+        return dest_leaf  # If source_leaf is None, keep the zero_leaf as is
+      elif (
+          source_leaf is not None
+      ):  # In this case the adapters have different target modules
+        original_shape = source_leaf.shape
+        if not original_shape:  # Handle scalar case
+          zero_tensor_shape = (self.total_slots,)
+        else:
+          zero_tensor_shape = (self.total_slots,) + original_shape
+        new_dest_leaf = jnp.zeros(zero_tensor_shape, dtype=source_leaf.dtype)
+        return new_dest_leaf.at[slot_id].set(source_leaf)
+      else:
+        return None  # If both are None, return None
+
+    if adapter_id == "":
+      logging.info(
+          "Empty adapter id. No LoRA tensors added to adapter_tensorstore cache"
+      )
+      return
+
+    asyncio.run(self.load_adapter(adapter_id, None, True))
+
+    adapter_weights = self.loaded_adapters_hbm[adapter_id]
+
+    if not self.decoding_adapters_cache:
+      self.decoding_adapters_cache = self._initialize_decoding_adapters_cache(
+          adapter_weights
+      )
+
+    adapter_scale_factor = jnp.bfloat16(
+        self._get_adapter_scale_factor(adapter_id)
+    )
+    self.adapters_scale_factor = self.adapters_scale_factor.at[slot_id].set(
+        adapter_scale_factor
+    )
+    self.decoding_adapters_cache = jax.tree_util.tree_map(
+        insert_leaf, self.decoding_adapters_cache, adapter_weights
+    )
+
   # --- Public Methods (Acquire lock, then call unsafe methods) ---
 
   async def register_adapter(

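As a quick sanity check of _get_adapter_scale_factor above: the scale follows the usual LoRA convention alpha / r. A worked example with a hypothetical adapter config:

# Hypothetical adapter config with the fields read by _get_adapter_scale_factor.
adapter_config = {"r": 8, "lora_alpha": 32}

scale = float(adapter_config["lora_alpha"]) / int(adapter_config["r"])
print(scale)  # 4.0 -- a different alpha/r pair yields a different per-slot scale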
jetstream/core/orchestrator.py

Lines changed: 32 additions & 1 deletion
@@ -1018,6 +1018,10 @@ def _insert_if_possible(
     # Check if there are any free my_slots. We don't want to block here since
     # we can still generate if we can't insert. We do this in a while loop to
     # insert as many sequences as possible.
+    adapter_tensorstore = None
+    if self._generate_adapterstore and idx < len(self._generate_adapterstore):
+      adapter_tensorstore = self._generate_adapterstore[idx]
+
     while True:
       my_slots_size = my_slots.qsize()
 
@@ -1086,8 +1090,13 @@ def _insert_if_possible(
           new_request.prefill_result,
           decode_state,
           slot=slot,
-          # request_id=new_request.request_id,
       )
+
+      if adapter_tensorstore:
+        adapter_tensorstore.insert_adapter_in_cache(
+            new_request.adapter_id, slot
+        )
+
       ThreadDebugLog(
           thread_name,
           f"Generate slice {idx} filled slot {slot} at step "
@@ -1227,6 +1236,10 @@ def _generate_thread(self, idx: int):
     my_generate_backlog = self._generate_backlogs[idx]
     my_detokenize_backlog = self._detokenize_backlogs[idx]
 
+    adapter_tensorstore = None
+    if self._generate_adapterstore and idx < len(self._generate_adapterstore):
+      adapter_tensorstore = self._generate_adapterstore[idx]
+
     # Keep track of what step tokens were generated at.
     generate_timestep = 0
     # State to store things like running kv cache in.
@@ -1292,6 +1305,24 @@ def _generate_thread(self, idx: int):
           my_slots.qsize() < max_concurrent_decodes
       ), "At this point we must have some requests inserted into the slots."
 
+      if adapter_tensorstore:
+        decoding_adapters_params = adapter_tensorstore.decoding_adapters_cache
+        adapters_scale_factor = adapter_tensorstore.adapters_scale_factor
+        b = adapters_scale_factor.shape[0]
+
+        # Reshape the scale_factors array to 4-D to align with the shape of
+        # the vectors `(batch, hidden_size, num_heads, head_dim)`.
+        reshaped_scale_factors = adapters_scale_factor.reshape((b, 1, 1, 1))
+
+        lora_state = {}
+        lora_state["scale_factor"] = reshaped_scale_factors
+        lora_state["lora_params"] = decoding_adapters_params
+
+        if isinstance(decode_state, dict):
+          decode_state["lora_state"] = lora_state
+        else:  # flax.struct.dataclass
+          decode_state = decode_state.replace(lora_state=lora_state)
+
       # Now we actually take a generate step on requests in the slots.
       decode_state, sampled_tokens = generate_engine.generate(
           generate_params, decode_state

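The (b, 1, 1, 1) reshape above exists so the per-slot scale factors broadcast cleanly against activations of shape (batch, hidden_size, num_heads, head_dim) inside the engine's generate step. A minimal sketch of that broadcast with hypothetical sizes (the actual consumption of lora_state happens inside the engine and is not part of this commit):

import jax.numpy as jnp

batch, hidden_size, num_heads, head_dim = 4, 8, 2, 16  # hypothetical sizes

# One scale factor per decode slot, e.g. four slots serving different adapters.
scale_factors = jnp.array([1.0, 4.0, 2.0, 0.5], dtype=jnp.bfloat16)
lora_delta = jnp.ones((batch, hidden_size, num_heads, head_dim), dtype=jnp.bfloat16)

# (b,) -> (b, 1, 1, 1): each slot's scale multiplies only that slot's activations.
scaled = scale_factors.reshape((batch, 1, 1, 1)) * lora_delta
print(scaled.shape)  # (4, 8, 2, 16)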
jetstream/core/server_lib.py

Lines changed: 8 additions & 1 deletion
@@ -174,23 +174,26 @@ def create_driver(
   shared_adapterstore = []
 
   if lora_input_adapters_path:
+    # TODO: Make hbm_memory_budget and cpu_memory_budget configurable
     for pe in engines.prefill_engines:
       prefill_adapterstore.append(
           adapterstore.AdapterTensorStore(
               engine=pe,
               adapters_dir_path=lora_input_adapters_path,
               hbm_memory_budget=20 * (1024**3),  # 20 GB HBM
               cpu_memory_budget=100 * (1024**3),  # 100 GB RAM
+              total_slots=pe.max_concurrent_decodes,
           )
       )
-    # TODO: Make hbm_memory_budget and cpu_memory_budget configurable
+
     for ge in engines.generate_engines:
       generate_adapterstore.append(
           adapterstore.AdapterTensorStore(
               engine=ge,
               adapters_dir_path=lora_input_adapters_path,
               hbm_memory_budget=20 * (1024**3),  # 20 GB HBM
               cpu_memory_budget=100 * (1024**3),  # 100 GB RAM
+              total_slots=ge.max_concurrent_decodes,
           )
       )
 
@@ -201,6 +204,7 @@ def create_driver(
             adapters_dir_path=lora_input_adapters_path,
             hbm_memory_budget=20 * (1024**3),  # 20 GB HBM
             cpu_memory_budget=100 * (1024**3),  # 100 GB RAM
+            total_slots=ie.max_concurrent_decodes,
         )
     )
 
@@ -315,6 +319,9 @@ def run(
         "Not starting Prometheus server: --prometheus_port flag not set"
     )
 
+  if multi_sampling and lora_input_adapters_path:
+    raise ValueError("LoRA adapters are not enabled for multi_sampling mode.")
+
   driver = create_driver(
       config,
       devices,

jetstream/engine/mock_engine.py

Lines changed: 3 additions & 1 deletion
@@ -32,7 +32,7 @@
 
 import functools
 from dataclasses import asdict
-from typing import Any, Callable, Optional, Tuple
+from typing import Any, Dict, Callable, Optional, Tuple
 
 import jax
 import jax.numpy as jnp
@@ -71,6 +71,7 @@ class DecodeState:
   generate_cache_index: int
   generate_lengths: jax.Array
   generate_tokens: jax.Array
+  lora_state: Optional[Dict[str, Any]] = None
 
 
 class TestEngine(engine_api.Engine):
@@ -509,6 +510,7 @@ def init_decode_state(self) -> DecodeState:
         generate_tokens=jnp.zeros(
            (self.generate_cache_batch, 1), dtype=jnp.int32
        ),
+       lora_state={},
    )
 
  @property

jetstream/tests/core/lora/test_adapter_tensorstore.py

Lines changed: 1 addition & 0 deletions
@@ -145,6 +145,7 @@ async def asyncSetUp(self):
         adapters_dir_path=self.adapters_dir_path,
         hbm_memory_budget=self.hbm_budget,
         cpu_memory_budget=self.cpu_budget,
+        total_slots=8,
     )
 
     # Pre-register adapters for most tests to simplify setup

jetstream/tests/core/test_orchestrator.py

Lines changed: 2 additions & 0 deletions
@@ -123,13 +123,15 @@ async def _setup_driver_with_adapterstore(
         adapters_dir_path="/tmp/",
         hbm_memory_budget=20 * (1024**3),  # 20 GB HBM
         cpu_memory_budget=100 * (1024**3),  # 100 GB RAM
+        total_slots=8,
     )
 
     generate_adapterstore = adapterstore.AdapterTensorStore(
         engine=generate_engine,
         adapters_dir_path="/tmp/",
         hbm_memory_budget=20 * (1024**3),  # 20 GB HBM
         cpu_memory_budget=100 * (1024**3),  # 100 GB RAM
+        total_slots=8,
    )
 
    await prefill_adapterstore.register_adapter(
