-
Notifications
You must be signed in to change notification settings - Fork 37
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Supporting Multi-LoRA inferencing via JetStream server #221
base: main
Are you sure you want to change the base?
Conversation
…ion to Orbax format.
…e2e with LoRA paths.
…tAdapters, LoadAdapter and UnloadAdapter. 2) Driver which is holding list of all loaded base-parameters is now storing the list of lora updated paramters for loaded lora. Implemented methods for loading, unloading and listing LoRA adapters into the Driver object. Original base model params are intact and saved into the params dictionary with key . 3) Created a proxy-client to make MultiAdapterManager service requests to JetStream server.
…pters. Its functionality includes loading, unloading of adapters between CPU RAM and HBM. It also follows LRU policy to evict the adapter if a new load_adapter request comes up. Currently it is only storing the adapter as separate tensors (lora_a and lora_b). Calculation of lora_b x lora_a is being done in prefill() and generate() during decode request. Adapter_tensorstore can be configured with a max_limit on HBM and RAM. 2) Functionality to load from a catalog file at the start of the server is added. If no file is given, it will just load the base params. Loading from the catalog file is done on CPU RAM. After that based on incoming requests, those params are moved/evicted to/from HBM. 3) Some proto updates to get only single path for each adapter, and that path is expected to have an adapter_config.json and Orbax format weights in 0/items folder.
…n API (https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/docs/proposals/003-model-server-protocol/README.md#inference-api-protocol), & . 2) Added a flag to explicitly run the JetStream server with these APIs when . Else only expose older Decode() & HealthCheck() APIs of the JetStream Server. 3) Fixed a bug in the adapter_tensorstore while converting jnp_array and np_array. 4) Added a which made requests to the new APIs (v1/load_lora_adapter, v1/unload_lora_adapter, v1/models, v1/completions)
1) kv_cache_utilization: This refers to percentage of memory in the allocated kv-cache on TPU HBM, that is actually used during decode. It is based on the percentage of slots used. 2) num_requests_waiting: Total number of requests which are waiting to be decoded. 3) lora_requests_info: List of LoRA adapters that are loaded into the TPU HBM for serving the requests.
2) Fixing model_ckpt_conversion.sh after refactoring and merging from main.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looked at it at high level, left some comments. Will take a deeper look again.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR, its a bit longish and I'd have preferred you to send the adapter_tensorstore.py
and related code as a separate PR since its isolated enough along with the unittests before sending the the PR to integrate it into orchestrator.
I've some initial comments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How is this different from what multi_adapter_service_client.py
is doing?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
multi_lora_decode_requester.py: This is a derived from benchmark_serving to run multiple decode requests at the same time for adapters.
multi_adapter_service_client.py: This was added when the requirement APIs includes v1/load_lora_adapter
, v1/unload_lora_adapter
. I have added this client to test all those API endpoints. It also tests v1/models
for listing and v1/completions
for decoding right now. But it only make 1 request at a time.
I will refactor this into one once we completely deprecate load_lora_adapter and unload_lora_adapter endpoints.
lru_time = float('inf') | ||
|
||
for adapter_id, metadata in self.adapter_registry.items(): | ||
if metadata.status == "loaded_hbm" if from_hbm else metadata.status == "loaded_cpu": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can't understand this logic here, can you add a code comment after fixing it? Also, I don't see any unitests for the AdapterTensorstore. Definitely need some to assert the functionality, feel free to mock the gcs calls.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same evict function is getting used for eviction either form DRAM or HBM. So for eviction from HBM, it finds the LRU from those which are in HBM else look only for the adapters in the DRAM only.
- Implemented unapply lora from base_params - Fixed some comments from the PR
Supporting Multi-LoRA inferencing via JetStream server following LLM Inference gateway API protocols.