JetStream changes for JAX-based implementation of unified_lora_params for decoding a batch of multiple different LoRA adapters #222

Open. Wants to merge 3 commits into base: amangu-lora
78 changes: 77 additions & 1 deletion jetstream/core/lora/adapter_tensorstore.py
@@ -82,14 +82,18 @@ class AdapterTensorStore:
to manage memory usage. It supports asynchronous loading and unloading
of adapters to avoid blocking the main inference thread.

This class also maintains unified_lora_weights, built from all of the adapters that are
in use for decoding at any given time. These unified weights allow the backend model
to serve multiple different LoRA adapters within a single batch.

Args:
hbm_memory_budget (int): The maximum amount of HBM (in bytes) to use for
storing LoRA adapter weights.
cpu_memory_budget (int): The maximum amount of CPU RAM (in bytes) to use
for storing LoRA adapter weights.
total_slots (int): Number of generate slots. This also equals max_concurrent_decodes.
"""


def __init__(self,
engine: engine_api.Engine,
adapters_dir_path: str,
@@ -106,6 +110,8 @@ def __init__(self,
self.current_hbm_usage: int = 0
self.current_cpu_usage: int = 0
self.running_requests: int = 0 # Number of async tasks which are in "loading" state
self.decoding_adapters_cache: Dict[str, Any] = {}
self.total_slots = total_slots
self.lock = asyncio.Lock() # Use an asyncio Lock for thread safety


@@ -211,6 +217,76 @@ async def _transfer_to_cpu(self, adapter_id: str):
metadata.last_accessed = time.time()


def _initialize_decoding_adapters_cache(self, adapter_weights):
"""
Create a new PyTree with zero tensors at the paths corresponding to non-None leaves
in the input PyTree. The zero tensors have an added leading dimension of size `self.total_slots`.

Args:
adapter_weights: The input PyTree, whose structure will be mirrored.

Returns:
A new PyTree with zero Tensors or None values, mirroring the structure of the input PyTree.
"""
def create_zero_leaf(leaf):
if leaf is not None:
original_shape = leaf.shape
if not original_shape: # handle scalar case
zero_tensor_shape = (self.total_slots,)
else:
zero_tensor_shape = (self.total_slots,) + original_shape # Prepend a new dimension

return jnp.zeros(zero_tensor_shape, dtype=leaf.dtype)
else:
return None # Maintain None structure for None leaves

return jax.tree_util.tree_map(create_zero_leaf, adapter_weights)


def insert_adapter_in_cache(self, adapter_id: str, slot_id: int):
"""
Insert the given adapter's tensors into a slot in the decoding_adapters_cache.

Args:
adapter_id: The id of the adapter whose tensors will be inserted.
slot_id: The id of the slot, i.e. the index in the decoding_adapters_cache
at which the adapter tensors will be inserted.
"""

def insert_leaf(dest_leaf, source_leaf):
if dest_leaf is not None and source_leaf is not None:
return dest_leaf.at[slot_id].set(source_leaf) # Insert at the specific index
elif dest_leaf is not None:
return dest_leaf # If source_leaf is None, keep the zero_leaf as is
elif source_leaf is not None: # In this case the adapters have different target modules
original_shape = source_leaf.shape
if not original_shape: # Handle scalar case
zero_tensor_shape = (self.total_slots,)
else:
zero_tensor_shape = (self.total_slots,) + original_shape
new_dest_leaf = jnp.zeros(zero_tensor_shape, dtype=source_leaf.dtype)
return new_dest_leaf.at[slot_id].set(source_leaf)
else:
return None # If both are None, return None

if adapter_id == "":
logging.info("Empty adapter id. So no LoRA tensors inserted into the cache in adapter_tensorStore.")
return

metadata = self.adapter_registry[adapter_id]

asyncio.run(self.load_adapter(adapter_id, True))

adapter_weights = self.loaded_adapters_hbm[adapter_id]

if not self.decoding_adapters_cache:
self.decoding_adapters_cache = self._initialize_decoding_adapters_cache(adapter_weights)

self.decoding_adapters_cache = jax.tree_util.tree_map(insert_leaf,
self.decoding_adapters_cache,
adapter_weights)


async def get_hbm_loaded_adapters(self):
"""Returns a comma separated list of adapters loaded into HBM."""

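As a standalone sketch of what the two new methods do (the adapter PyTree below and its shapes are illustrative assumptions, not tensors from this repository), the unified decoding cache can be built and filled roughly like this: every leaf gains a leading total_slots dimension of zeros, and an adapter's weights are written into the row of the slot it occupies. The real methods additionally handle None leaves, scalar tensors, and adapters whose target modules differ from those already in the cache.

import jax
import jax.numpy as jnp

total_slots = 4  # stand-in for the sum of max_concurrent_decodes across engines

# A toy adapter PyTree with one LoRA A/B pair for a single target module.
adapter_weights = {
    "layer_0": {
        "lora_a": jnp.ones((8, 512), dtype=jnp.bfloat16),  # (rank, d_model)
        "lora_b": jnp.ones((512, 8), dtype=jnp.bfloat16),  # (d_model, rank)
    }
}

def initialize_cache(weights):
  # Mirror the PyTree, prepending a zero-filled slot dimension to every leaf.
  return jax.tree_util.tree_map(
      lambda leaf: jnp.zeros((total_slots,) + leaf.shape, dtype=leaf.dtype),
      weights)

def insert_into_slot(cache, weights, slot_id):
  # Functional update: write this adapter's tensors into row `slot_id`.
  return jax.tree_util.tree_map(
      lambda dest, src: dest.at[slot_id].set(src), cache, weights)

cache = initialize_cache(adapter_weights)
cache = insert_into_slot(cache, adapter_weights, slot_id=2)
print(cache["layer_0"]["lora_a"].shape)  # (4, 8, 512)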
14 changes: 13 additions & 1 deletion jetstream/core/orchestrator.py
@@ -304,6 +304,15 @@ def __init__(
self._metrics_collector = metrics_collector
self._multi_sampling = multi_sampling

total_slots = 0
for engine in self._generate_engines:
total_slots += engine.max_concurrent_decodes

self._adapter_tensorstore = adapter_tensorstore.AdapterTensorStore(
hbm_memory_budget=(20 * (1024 ** 3)), # 20 GB HBM
cpu_memory_budget=(100 * (1024 ** 3)), # 100 GB RAM
total_slots=total_slots)

# Stages 1-4 represent the life cycle of a request.
# Stage 1
# At first, a request is placed here in order to get prefilled.
@@ -969,6 +978,9 @@ def _insert_if_possible(
slot=slot,
#request_id=new_request.request_id,
)

self._adapter_tensorstore.insert_adapter_in_cache(new_request.adapter_id, slot)

ThreadDebugLog(
thread_name,
f"Generate slice {idx} filled slot {slot} at step "
@@ -1175,7 +1187,7 @@ def _generate_thread(self, idx: int):

# Now we actually take a generate step on requests in the slots.
decode_state, sampled_tokens = generate_engine.generate(
generate_params, decode_state
generate_params, decode_state, lora_params=self._adapter_tensorstore.decoding_adapters_cache,
)
sampled_tokens.copy_to_host_async()
# Respond to detokenization backpressure.
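The point of the leading slot dimension is that a single generate step can then apply a different adapter to every row of the decode batch with one batched contraction, instead of looping over adapters. A hedged sketch, assuming one LoRA A/B pair per projection and that slot i of the unified cache corresponds to row i of the decode batch:

import jax.numpy as jnp

def apply_batched_lora(x, lora_a, lora_b, scaling=1.0):
  # x:      (total_slots, d_model)       activations, one row per decode slot
  # lora_a: (total_slots, rank, d_model) per-slot A matrices from the unified cache
  # lora_b: (total_slots, d_model, rank) per-slot B matrices from the unified cache
  # Empty slots hold zeros, so they contribute nothing to the output.
  low_rank = jnp.einsum("sd,srd->sr", x, lora_a)      # project down, per slot
  delta = jnp.einsum("sr,sdr->sd", low_rank, lora_b)  # project back up, per slot
  return scaling * delta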
1 change: 1 addition & 0 deletions jetstream/engine/engine_api.py
@@ -197,6 +197,7 @@ def generate(
params: Params,
decode_state: DecodeState,
sampler: Optional[Callable[[Any], Any]] = None,
lora_params: Optional[Params] = None,
) -> Tuple[DecodeState, ResultTokens]:
"""Generates tokens for each sequence being decoded in parallel.

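With the extra keyword on engine_api.Engine.generate, the orchestrator can forward the unified cache without breaking engines that ignore it, since the default of None preserves the old behavior. A usage sketch; my_engine, generate_params, decode_state, and adapter_store are placeholders for an actual engine, its state, and an AdapterTensorStore instance, not objects defined in this diff:

decode_state, sampled_tokens = my_engine.generate(
    generate_params,
    decode_state,
    lora_params=adapter_store.decoding_adapters_cache,  # per-slot unified LoRA weights
)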