JetStream changes for a JAX-based implementation of unified_lora_params for decoding a batch of multiple different LoRA adapters. #222

Merged · 1 commit · merged May 23, 2025
104 changes: 104 additions & 0 deletions jetstream/core/lora/adapter_tensorstore.py
@@ -99,6 +99,7 @@ def __init__(
adapters_dir_path: str,
hbm_memory_budget: int,
cpu_memory_budget: int,
total_slots: int,
):
"""Initializes the AdapterTensorStore."""
self.engine = engine # Possibly MaxEngine object
@@ -119,8 +120,27 @@ def __init__(
self.running_requests: int = (
0 # Number of async tasks which are in "loading" state
)
self.decoding_adapters_cache: Dict[str, Any] = {}

# TODO: Make dtype configurable for the scale factor array
self.adapters_scale_factor = jnp.empty(1, dtype=jnp.bfloat16)

self.total_slots = total_slots
self.lock = asyncio.Lock() # Use an asyncio Lock for thread safety

def _get_adapter_scale_factor(self, adapter_id: str):
"""
Internal: Get the LoRA scale_factor using the adapter_id.
"""
adapter_config = self.adapter_registry[adapter_id].config
lora_scale_factor = float(1)

if "r" in adapter_config and "lora_alpha" in adapter_config:
lora_rank = int(adapter_config["r"])
lora_scale_factor = float(adapter_config["lora_alpha"]) / lora_rank

return lora_scale_factor

# --- Unsafe Internal methods which assumes that lock is held ---
def _unsafe_transfer_to_hbm(self, adapter_id: str):
"""
@@ -207,6 +227,90 @@ def _unsafe_unload_adapter(self, adapter_id: str):
metadata.size_hbm = 0
metadata.size_cpu = 0

def _initialize_decoding_adapters_cache(self, adapter_weights):
"""
Create a new PyTree with zero tensors at the paths corresponding to
non-None leaves in the input PyTree. The zero tensors have an added
dimension of size `self.total_slots`.
Args:
adapter_weights: The input PyTree, whose structure will be mirrored.
Returns:
A new PyTree with zero Tensors or None values, mirroring the structure
of the input PyTree.
"""

def create_zero_leaf(leaf):
if leaf is not None:
original_shape = leaf.shape
if not original_shape: # handle scalar case
zero_tensor_shape = (self.total_slots,)
else:
zero_tensor_shape = (
self.total_slots,
) + original_shape # Prepend a new dimension

return jnp.zeros(zero_tensor_shape, dtype=leaf.dtype)
else:
return None # Maintain None structure for None leaves

self.adapters_scale_factor = jnp.ones(self.total_slots, dtype=jnp.bfloat16)
return jax.tree_util.tree_map(create_zero_leaf, adapter_weights)

def insert_adapter_in_cache(self, adapter_id: str, slot_id: int):
"""
Insert the specific adapter tensors into a slot in the
serving_adapters_cache.
Args:
adapter_id: The id of the adapter, whose tensors will be inserted
slot_id: The id of slot, which represents the index in the
serving_adapter_cache where the adapter tensors will be inserted.
"""

def insert_leaf(dest_leaf, source_leaf):
if dest_leaf is not None and source_leaf is not None:
return dest_leaf.at[slot_id].set(
source_leaf
) # Insert at the specific index
elif dest_leaf is not None:
return dest_leaf # If source_leaf is None, keep the zero_leaf as is
elif (
source_leaf is not None
): # In this case the adapters have different target modules
original_shape = source_leaf.shape
if not original_shape: # Handle scalar case
zero_tensor_shape = (self.total_slots,)
else:
zero_tensor_shape = (self.total_slots,) + original_shape
new_dest_leaf = jnp.zeros(zero_tensor_shape, dtype=source_leaf.dtype)
return new_dest_leaf.at[slot_id].set(source_leaf)
else:
return None # If both are None, return None

if adapter_id == "":
logging.info(
"Empty adapter id. No LoRA tensors added to adapter_tensorstore cache"
)
return

asyncio.run(self.load_adapter(adapter_id, None, True))

adapter_weights = self.loaded_adapters_hbm[adapter_id]

if not self.decoding_adapters_cache:
self.decoding_adapters_cache = self._initialize_decoding_adapters_cache(
adapter_weights
)

adapter_scale_factor = jnp.bfloat16(
self._get_adapter_scale_factor(adapter_id)
)
self.adapters_scale_factor = self.adapters_scale_factor.at[slot_id].set(
adapter_scale_factor
)
self.decoding_adapters_cache = jax.tree_util.tree_map(
insert_leaf, self.decoding_adapters_cache, adapter_weights
)

# --- Public Methods (Acquire lock, then call unsafe methods) ---

async def register_adapter(
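For reference, a minimal standalone sketch of the unified decoding cache built above (not the MaxEngine integration; the toy PyTree and shapes are assumptions): every leaf gets a prepended slot dimension of size `total_slots`, and one adapter's tensors are written into a slot with `.at[slot_id].set(...)`, mirroring `_initialize_decoding_adapters_cache` and `insert_adapter_in_cache`. The per-slot entry in `adapters_scale_factor` is simply `lora_alpha / r` from the adapter config, as computed by `_get_adapter_scale_factor`.

```python
# Standalone sketch only — toy shapes, not JetStream/MaxEngine code.
import jax
import jax.numpy as jnp

total_slots = 4  # plays the role of engine.max_concurrent_decodes

# A toy single-adapter PyTree (LoRA A/B matrices for one projection).
adapter_weights = {
    "q_proj": {
        "lora_a": jnp.ones((16, 2), dtype=jnp.bfloat16),  # (hidden, rank)
        "lora_b": jnp.ones((2, 16), dtype=jnp.bfloat16),  # (rank, hidden)
    }
}

def create_zero_leaf(leaf):
  # Prepend the slot dimension so each decode slot owns its own copy.
  return jnp.zeros((total_slots,) + leaf.shape, dtype=leaf.dtype)

# Zero-initialized cache mirroring the adapter structure.
decoding_adapters_cache = jax.tree_util.tree_map(create_zero_leaf, adapter_weights)

# Insert one adapter's tensors into slot 2.
slot_id = 2
decoding_adapters_cache = jax.tree_util.tree_map(
    lambda dest, src: dest.at[slot_id].set(src),
    decoding_adapters_cache,
    adapter_weights,
)

print(decoding_adapters_cache["q_proj"]["lora_a"].shape)  # (4, 16, 2)
```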
33 changes: 32 additions & 1 deletion jetstream/core/orchestrator.py
@@ -1018,6 +1018,10 @@ def _insert_if_possible(
# Check if there are any free my_slots. We don't want to block here since
# we can still generate if we can't insert. We do this in a while loop to
# insert as many sequences as possible.
adapter_tensorstore = None
if self._generate_adapterstore and idx < len(self._generate_adapterstore):
adapter_tensorstore = self._generate_adapterstore[idx]

while True:
my_slots_size = my_slots.qsize()

@@ -1086,8 +1090,13 @@ def _insert_if_possible(
new_request.prefill_result,
decode_state,
slot=slot,
# request_id=new_request.request_id,
)

if adapter_tensorstore:
adapter_tensorstore.insert_adapter_in_cache(
new_request.adapter_id, slot
)

ThreadDebugLog(
thread_name,
f"Generate slice {idx} filled slot {slot} at step "
@@ -1227,6 +1236,10 @@ def _generate_thread(self, idx: int):
my_generate_backlog = self._generate_backlogs[idx]
my_detokenize_backlog = self._detokenize_backlogs[idx]

adapter_tensorstore = None
if self._generate_adapterstore and idx < len(self._generate_adapterstore):
adapter_tensorstore = self._generate_adapterstore[idx]

# Keep track of what step tokens were generated at.
generate_timestep = 0
# State to store things like running kv cache in.
@@ -1292,6 +1305,24 @@ def _generate_thread(self, idx: int):
my_slots.qsize() < max_concurrent_decodes
), "At this point we must have some requests inserted into the slots."

if adapter_tensorstore:
decoding_adapters_params = adapter_tensorstore.decoding_adapters_cache
adapters_scale_factor = adapter_tensorstore.adapters_scale_factor
b = adapters_scale_factor.shape[0]

# Reshape the scale_factors array to 4-D to align with the shape of
# the vectors `(batch, hidden_size, num_heads, head_dim)`.
reshaped_scale_factors = adapters_scale_factor.reshape((b, 1, 1, 1))

lora_state = {}
lora_state["scale_factor"] = reshaped_scale_factors
lora_state["lora_params"] = decoding_adapters_params

if isinstance(decode_state, dict):
decode_state["lora_state"] = lora_state
else: # flax.struct.dataclass
decode_state = decode_state.replace(lora_state=lora_state)

# Now we actually take a generate step on requests in the slots.
decode_state, sampled_tokens = generate_engine.generate(
generate_params, decode_state
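The actual application of `lora_state` happens inside the engine's `generate` (e.g. in MaxText) and is not part of this PR. The sketch below is only a hypothetical illustration, with assumed activation shapes, of why the per-slot scale factors are reshaped to `(batch, 1, 1, 1)`: it lets them broadcast over a batched LoRA delta that is indexed by slot.

```python
# Hypothetical illustration only — the real consumer lives in the engine.
import jax.numpy as jnp

batch, hidden, num_heads, head_dim, rank = 4, 16, 2, 8, 2

# One decode token per slot: (batch, seq=1, hidden). Shapes are assumptions.
x = jnp.ones((batch, 1, hidden), dtype=jnp.bfloat16)
base_w = jnp.ones((hidden, num_heads, head_dim), dtype=jnp.bfloat16)

# Slot-indexed LoRA params, as stored in decoding_adapters_cache.
lora_a = jnp.ones((batch, hidden, rank), dtype=jnp.bfloat16)
lora_b = jnp.ones((batch, rank, num_heads, head_dim), dtype=jnp.bfloat16)
scale = jnp.ones((batch, 1, 1, 1), dtype=jnp.bfloat16)  # reshaped scale factors

base_out = jnp.einsum("bsh,hnd->bsnd", x, base_w)
lora_out = jnp.einsum("bsh,bhr,brnd->bsnd", x, lora_a, lora_b)

# (batch, 1, 1, 1) broadcasts against (batch, 1, num_heads, head_dim),
# so each slot's output gets its own adapter's alpha/r scaling.
out = base_out + scale * lora_out
print(out.shape)  # (4, 1, 2, 8)
```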
9 changes: 8 additions & 1 deletion jetstream/core/server_lib.py
@@ -174,23 +174,26 @@ def create_driver(
shared_adapterstore = []

if lora_input_adapters_path:
# TODO: Make hbm_memory_budget and cpu_memory_budget configurable
for pe in engines.prefill_engines:
prefill_adapterstore.append(
adapterstore.AdapterTensorStore(
engine=pe,
adapters_dir_path=lora_input_adapters_path,
hbm_memory_budget=20 * (1024**3), # 20 GB HBM
cpu_memory_budget=100 * (1024**3), # 100 GB RAM
total_slots=pe.max_concurrent_decodes,
)
)
# TODO: Make hbm_memory_budget and cpu_memory_budget configurable

for ge in engines.generate_engines:
generate_adapterstore.append(
adapterstore.AdapterTensorStore(
engine=ge,
adapters_dir_path=lora_input_adapters_path,
hbm_memory_budget=20 * (1024**3), # 20 GB HBM
cpu_memory_budget=100 * (1024**3), # 100 GB RAM
total_slots=ge.max_concurrent_decodes,
)
)

@@ -201,6 +204,7 @@
adapters_dir_path=lora_input_adapters_path,
hbm_memory_budget=20 * (1024**3), # 20 GB HBM
cpu_memory_budget=100 * (1024**3), # 100 GB RAM
total_slots=ie.max_concurrent_decodes,
)
)

@@ -315,6 +319,9 @@ def run(
"Not starting Prometheus server: --prometheus_port flag not set"
)

if multi_sampling and lora_input_adapters_path:
raise ValueError("LoRA adapters is not enabled for multi_sampling mode.")

driver = create_driver(
config,
devices,
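The TODOs above note that the HBM and host-RAM budgets are still hard-coded. A minimal sketch of one way to expose them, assuming an absl-style flag setup (the `--prometheus_port` message above implies flags are already in use); the flag names here are made up for illustration and are not part of this PR:

```python
# Sketch only — hypothetical flag names, not part of this PR.
from absl import flags

_LORA_HBM_BUDGET = flags.DEFINE_integer(
    "lora_hbm_memory_budget",
    20 * (1024**3),
    "HBM budget in bytes for the LoRA adapter_tensorstore.",
)
_LORA_CPU_BUDGET = flags.DEFINE_integer(
    "lora_cpu_memory_budget",
    100 * (1024**3),
    "Host RAM budget in bytes for the LoRA adapter_tensorstore.",
)

# create_driver(...) would then pass _LORA_HBM_BUDGET.value and
# _LORA_CPU_BUDGET.value instead of the hard-coded literals above.
```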
4 changes: 3 additions & 1 deletion jetstream/engine/mock_engine.py
@@ -32,7 +32,7 @@

import functools
from dataclasses import asdict
from typing import Any, Callable, Optional, Tuple
from typing import Any, Dict, Callable, Optional, Tuple

import jax
import jax.numpy as jnp
@@ -71,6 +71,7 @@ class DecodeState:
generate_cache_index: int
generate_lengths: jax.Array
generate_tokens: jax.Array
lora_state: Optional[Dict[str, Any]] = None


class TestEngine(engine_api.Engine):
@@ -509,6 +510,7 @@ def init_decode_state(self) -> DecodeState:
generate_tokens=jnp.zeros(
(self.generate_cache_batch, 1), dtype=jnp.int32
),
lora_state={},
)

@property
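As a side note on the `isinstance(decode_state, dict)` branch in the orchestrator: a `flax.struct.dataclass` like the mock `DecodeState` above is immutable, so `lora_state` has to be attached via `.replace()`. A tiny sketch with toy fields (not the real engine state):

```python
# Toy sketch of the dataclass path — fields are illustrative only.
from typing import Any, Dict, Optional

import jax.numpy as jnp
from flax import struct


@struct.dataclass
class ToyDecodeState:
  generate_tokens: jnp.ndarray
  lora_state: Optional[Dict[str, Any]] = None


state = ToyDecodeState(generate_tokens=jnp.zeros((4, 1), dtype=jnp.int32))
# Immutable: item assignment would fail, so the orchestrator uses replace().
state = state.replace(lora_state={"scale_factor": jnp.ones((4, 1, 1, 1))})
print(state.lora_state["scale_factor"].shape)  # (4, 1, 1, 1)
```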
1 change: 1 addition & 0 deletions jetstream/tests/core/lora/test_adapter_tensorstore.py
@@ -145,6 +145,7 @@ async def asyncSetUp(self):
adapters_dir_path=self.adapters_dir_path,
hbm_memory_budget=self.hbm_budget,
cpu_memory_budget=self.cpu_budget,
total_slots=8,
)

# Pre-register adapters for most tests to simplify setup
2 changes: 2 additions & 0 deletions jetstream/tests/core/test_orchestrator.py
@@ -123,13 +123,15 @@ async def _setup_driver_with_adapterstore(
adapters_dir_path="/tmp/",
hbm_memory_budget=20 * (1024**3), # 20 GB HBM
cpu_memory_budget=100 * (1024**3), # 100 GB RAM
total_slots=8,
)

generate_adapterstore = adapterstore.AdapterTensorStore(
engine=generate_engine,
adapters_dir_path="/tmp/",
hbm_memory_budget=20 * (1024**3), # 20 GB HBM
cpu_memory_budget=100 * (1024**3), # 100 GB RAM
total_slots=8,
)

await prefill_adapterstore.register_adapter(