Supporting Multi-LoRA inferencing via JetStream server #221

Open · wants to merge 20 commits into base: main
Changes from all commits (20 commits)
459 changes: 459 additions & 0 deletions jetstream/core/lora/adapter_tensorstore.py

Large diffs are not rendered by default.
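
The collapsed adapter_tensorstore.py implements the store that the driver calls into (list_adapters_from_tensorstore, load_adapter_to_tensorstore, unload_adapter_from_tensorstore below). As rough orientation only, here is a minimal sketch of the per-adapter bookkeeping it plausibly keeps, inferred from the status strings and AdapterInfo fields used elsewhere in this PR; all names below are illustrative assumptions, not the actual 459-line file.

# Illustrative sketch only: inferred, not the real adapter_tensorstore.py.
import dataclasses
import time


@dataclasses.dataclass
class AdapterMetadata:
  """Fields mirroring what MultiLoraManager reads per adapter."""

  adapter_id: str
  adapter_path: str
  status: str = "unloaded"    # "loaded_hbm" | "loaded_cpu" | "unloaded"
  size_hbm: int = 0           # bytes occupied while resident in HBM
  size_cpu: int = 0           # bytes occupied while resident in host memory
  last_accessed: float = 0.0  # unix timestamp of last use


class AdapterTensorStore:
  """Tracks LoRA adapters as they move between HBM, CPU, and storage."""

  def __init__(self):
    self._registry: dict[str, AdapterMetadata] = {}

  def load_adapter(self, adapter_id: str, adapter_path: str) -> None:
    # Weight loading elided; record metadata so the adapter can be listed.
    self._registry[adapter_id] = AdapterMetadata(
        adapter_id=adapter_id,
        adapter_path=adapter_path,
        status="loaded_cpu",
        last_accessed=time.time(),
    )

  def unload_adapter(self, adapter_id: str) -> None:
    self._registry.pop(adapter_id, None)

  def list_adapters(self) -> dict[str, AdapterMetadata]:
    return dict(self._registry)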

101 changes: 101 additions & 0 deletions jetstream/core/lora/multi_lora_inference_api.py
@@ -0,0 +1,101 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Manages the list of fine-tuned adapters loaded on top of the base model for serving.
"""

import logging
import grpc

from typing import Optional
from jetstream.core import orchestrator
from jetstream.core.proto import multi_lora_decoding_pb2_grpc
from jetstream.core.proto import multi_lora_decoding_pb2


class MultiLoraManager(multi_lora_decoding_pb2_grpc.v1Servicer):
  """Manages multiple LoRA adapters and their status/lifetimes."""

  _driver: orchestrator.Driver

  def __init__(self, driver: orchestrator.Driver):
    self._driver = driver

  def models(
      self,
      request: multi_lora_decoding_pb2.ListAdaptersRequest,
      context: Optional[grpc.aio.ServicerContext] = None,
  ) -> multi_lora_decoding_pb2.ListAdaptersResponse:
    """Lists all loaded LoRA adapters."""

    try:
      adapters = self._driver.list_adapters_from_tensorstore()

      adapter_infos = []
      for adapter_id, adapter_data in adapters.items():
        # Map adapter status to a relative loading cost: HBM-resident is
        # free, CPU-resident is cheap, unloaded must be fetched, and any
        # unknown status is flagged with -1.
        if adapter_data.status == "loaded_hbm":
          loading_cost = 0
        elif adapter_data.status == "loaded_cpu":
          loading_cost = 1
        elif adapter_data.status == "unloaded":
          loading_cost = 2
        else:
          loading_cost = -1

        adapter_info = multi_lora_decoding_pb2.AdapterInfo(
            adapter_id=adapter_id,
            loading_cost=loading_cost,
            size_hbm=adapter_data.size_hbm,
            size_cpu=adapter_data.size_cpu,
            last_accessed=adapter_data.last_accessed,
            status=adapter_data.status,
        )

        adapter_infos.append(adapter_info)

      return multi_lora_decoding_pb2.ListAdaptersResponse(
          success=True, adapter_infos=adapter_infos
      )
    except Exception as e:
      logging.error(f"Listing of adapters failed with error: {str(e)}")
      return multi_lora_decoding_pb2.ListAdaptersResponse(
          success=False, error_message=str(e)
      )

  def load_lora_adapter(
      self,
      request: multi_lora_decoding_pb2.LoadAdapterRequest,
      context: Optional[grpc.aio.ServicerContext] = None,
  ) -> multi_lora_decoding_pb2.LoadAdapterResponse:
    """Loads the LoRA adapter specified in the request."""

    try:
      self._driver.load_adapter_to_tensorstore(
          request.adapter_id, request.adapter_path
      )
      return multi_lora_decoding_pb2.LoadAdapterResponse(success=True)
    except Exception as e:
      logging.error(
          f"Loading of adapter_id={request.adapter_id} failed with error: {str(e)}"
      )
      return multi_lora_decoding_pb2.LoadAdapterResponse(
          success=False, error_message=str(e)
      )

  def unload_lora_adapter(
      self,
      request: multi_lora_decoding_pb2.UnloadAdapterRequest,
      context: Optional[grpc.aio.ServicerContext] = None,
  ) -> multi_lora_decoding_pb2.UnloadAdapterResponse:
    """Unloads the LoRA adapter specified in the request."""

    try:
      self._driver.unload_adapter_from_tensorstore(request.adapter_id)
      return multi_lora_decoding_pb2.UnloadAdapterResponse(success=True)
    except Exception as e:
      logging.error(
          f"Unloading of adapter_id={request.adapter_id} failed with error: {str(e)}"
      )
      return multi_lora_decoding_pb2.UnloadAdapterResponse(
          success=False, error_message=str(e)
      )
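
For context, a minimal sketch of how MultiLoraManager might be wired into an asyncio gRPC server. The add_v1Servicer_to_server helper name is assumed from standard grpc-python codegen for the v1 service, and the port is illustrative; none of this is part of the diff.

# Hypothetical wiring sketch, not part of this PR.
import grpc

from jetstream.core import orchestrator
from jetstream.core.lora import multi_lora_inference_api
from jetstream.core.proto import multi_lora_decoding_pb2_grpc


async def serve(driver: orchestrator.Driver, port: int = 9000) -> None:
  server = grpc.aio.server()
  # Register the servicer using the conventional generated helper.
  multi_lora_decoding_pb2_grpc.add_v1Servicer_to_server(
      multi_lora_inference_api.MultiLoraManager(driver=driver), server
  )
  server.add_insecure_port(f"[::]:{port}")
  await server.start()
  await server.wait_for_termination()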

35 changes: 35 additions & 0 deletions jetstream/core/metrics/prometheus.py
@@ -245,6 +245,31 @@ def __init__(self, model_name: Optional[str] = None):
],
)

  _num_requests_waiting = Gauge(
      name="num_requests_waiting",
      documentation="Number of requests waiting to be processed for inference.",
      labelnames=["id"],
      multiprocess_mode="sum",
  )

  _kv_cache_utilization = Gauge(
      name="kv_cache_utilization_perc",
      documentation="Percentage of the kv-cache used by requests currently being processed.",
      labelnames=["id"],
      multiprocess_mode="sum",
  )

  _lora_request_info = Gauge(
      name="lora_request_info",
      documentation="LoRA adapters loaded into HBM for processing requests.",
      labelnames=[
          "id",
          "max_lora",
          "running_lora_adapters",
      ],
      multiprocess_mode="livemostrecent",
  )

  def get_prefill_backlog_metric(self):
    return self._prefill_backlog.labels(**self.universal_labels)

@@ -289,3 +314,13 @@ def get_request_output_length(self):

  def get_request_success_count_metric(self):
    return self._request_success_count.labels(**self.universal_labels)

  def get_num_requests_waiting_metric(self):
    return self._num_requests_waiting.labels(**self.universal_labels)

  def get_kv_cache_utilization_metric(self):
    return self._kv_cache_utilization.labels(**self.universal_labels)

  def get_lora_request_info_metric(self, max_lora: int, loaded_adapters: str):
    return self._lora_request_info.labels(
        **self.universal_labels,
        max_lora=max_lora,
        running_lora_adapters=loaded_adapters,
    )
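
As a usage illustration of the three new gauges, sketched call sites might look like the following. The collector instance, queue, and page counters (metrics, prefill_backlog, used_pages, total_pages) are assumptions for illustration, as is the collector class name; the real update points live in the orchestrator.

# Hypothetical call sites for the new metrics, not taken from this diff.
metrics = JetstreamMetricsCollector(model_name="my-model")

# Requests still queued for processing.
metrics.get_num_requests_waiting_metric().set(prefill_backlog.qsize())

# kv-cache utilization as a percentage of available pages.
metrics.get_kv_cache_utilization_metric().set(100.0 * used_pages / total_pages)

# Which LoRA adapters are resident in HBM, with the configured cap as a label.
metrics.get_lora_request_info_metric(
    max_lora=8, loaded_adapters="adapter_a,adapter_b"
).set(1)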