[lora][tinker] Add pause and resume for multi-tenant lora #1657
hao-aaron wants to merge 3 commits
Conversation
Signed-off-by: ahao-anyscale <ahao@anyscale.com>
Code Review
This pull request introduces targeted per-LoRA pause and resume functionality to support multi-tenant LoRA training, allowing weight swaps for one adapter without freezing others. Key changes include a new `sample_with_retry` mechanism in the `RemoteInferenceClient` to handle aborted requests and a server-side endpoint for selective request abortion. Review feedback identified several high-severity issues, primarily the need for persistent state tracking to prevent `asyncio.Event` loss across different event loops. Corrections were also suggested for accessing LoRA metadata in vLLM, managing `max_tokens` during retries to prevent over-generation, and truncating logprobs to keep them consistent with the original prompt length.
@hao-aaron btw have you tested this with multiple async rl runs in tinker mode?
@erictang000 i have some runs here https://wandb.ai/sky-posttraining-uc-berkeley/lora_multi_tenant?nw=nwuserahao |
erictang000 left a comment
looks good to me, thanks for getting this out so quick
Addresses #1647
Today's `pause_generation()` is a global vLLM keep-mode pause: when one LoRA tenant syncs new weights, every other tenant's in-flight generation freezes for the duration of the swap. This blocks practical multi-tenant LoRA RL training.

This PR adds a per-LoRA `lora_name` arg to pause/resume so that a weight sync for adapter A only aborts A's requests; other adapters keep generating. The aborted requests come back with `finish_reason="abort"` and partial tokens, and a new `sample_with_retry()` client method accumulates those partial tokens, awaits resume, and resubmits with `prompt + accumulated` and the remaining `max_tokens` until completion.

This is a transient fix. If we can upstream a LoRA-specific pause in the future, we delete `sample_with_retry`, the per-LoRA gate, and the abort endpoint; the `lora_name` kwarg stays and routes to the new vLLM API.

What changed
New: `/skyrl/v1/abort_lora_requests` server endpoint

`vllm_server_actor.py` gets a small custom endpoint that iterates `engine.output_processor.request_states`, filters by `lora_name`, and calls `engine.abort(ids, internal=True)`. `internal=True` is load-bearing — the states dict is keyed by internal IDs.
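For orientation, a minimal sketch of the endpoint's shape, assuming the actor exposes a FastAPI `app` and has a vLLM `AsyncLLM` handle named `engine` in scope (both names are illustrative, not necessarily the PR's); the attribute carrying the adapter name on each request state is hedged with `getattr`, since the review thread flagged that metadata access as subtle:

```python
from fastapi import FastAPI

app = FastAPI()  # stand-in for the actor's existing HTTP app
# `engine` is assumed to be the actor's vLLM AsyncLLM handle, in scope here.

@app.post("/skyrl/v1/abort_lora_requests")
async def abort_lora_requests(lora_name: str):
    # request_states is keyed by vLLM-internal request IDs, which is why the
    # abort below must pass internal=True.
    ids = [
        req_id
        for req_id, state in engine.output_processor.request_states.items()
        # The exact attribute path to the adapter name is version-dependent.
        if getattr(state, "lora_name", None) == lora_name
    ]
    await engine.abort(ids, internal=True)
    return {"aborted": len(ids)}
```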
`pause_generation`/`resume_generation` gain `lora_name: Optional[str] = None`

- `lora_name=None` → unchanged: vLLM `/pause?mode=keep` global pause.
- `lora_name="X"` → clears a per-LoRA `asyncio.Event` (gates retries client-side), sleeps a 5 s grace period, then fans out to `/skyrl/v1/abort_lora_requests` (sketched below).
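A hedged sketch of that client-side sequence (`_lora_pause_events` is named in the PR; `server_urls`, the `httpx` calls, and the resume handling are assumptions for illustration):

```python
import asyncio
from typing import Optional

import httpx

# Per-LoRA gates; sample_with_retry awaits these before resubmitting.
_lora_pause_events: dict[str, asyncio.Event] = {}

async def pause_generation(server_urls: list[str], lora_name: Optional[str] = None) -> None:
    if lora_name is None:
        # Unchanged global path: vLLM's keep-mode pause on every server.
        async with httpx.AsyncClient() as http:
            for url in server_urls:
                await http.post(f"{url}/pause", params={"mode": "keep"})
        return
    # Close the gate first so client-side retries block before any abort lands,
    # then give in-flight scheduling a grace period before fanning out.
    _lora_pause_events.setdefault(lora_name, asyncio.Event()).clear()
    await asyncio.sleep(5.0)
    async with httpx.AsyncClient() as http:
        for url in server_urls:
            await http.post(
                f"{url}/skyrl/v1/abort_lora_requests",
                params={"lora_name": lora_name},
            )

async def resume_generation(lora_name: Optional[str] = None) -> None:
    if lora_name is not None and lora_name in _lora_pause_events:
        _lora_pause_events[lora_name].set()  # release blocked retry loops
```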
The new inference path (`_SKYRL_USE_NEW_INFERENCE=1`, the default) goes through `RemoteInferenceClient` → `vllm_server_actor`, which is where the new `/skyrl/v1/abort_lora_requests` endpoint lives — so that's the only path that actually supports targeted pause. Legacy-path classes (`InferenceEngineClient`, `RayWrappedInferenceEngine`, `RemoteInferenceEngine`, `AsyncVLLMInferenceEngine`) accept the `lora_name` kwarg as a guardrail but raise `NotImplementedError("targeted pause is HTTP-only")` when it's non-None. The legacy path keeps its existing global-pause behavior unchanged.

`RemoteInferenceClient`: `sample()` refactor + new `sample_with_retry()`

This is the only data-plane surface that gained retry logic.
`sample()` is split:

- `sample()` — renders the prompt and dispatches once. Public behavior unchanged.
- `_sample_with_rendered_tokens()` — the post-render half, parameterized on rendered `token_ids`. Pure refactor (regression-checked by the existing `TestSample` tests).

`sample_with_retry()` is new — the only retry-bearing method (see the sketch after this list). It renders once, then runs a `while stop_reason == "abort"` loop:

- `await _lora_pause_events[model].wait()` (no-op if no event exists).
- `token_ids = original_prompt + accum_tokens` and `max_tokens = original_max_tokens - len(accum_tokens)`.
- Calls `_sample_with_rendered_tokens`.
- Extends `accum_tokens` + `accum_logprobs` with the returned segment.

Returns the same `SampleResponse` shape as `sample()`. Asserts `num_samples == 1` (current Tinker callers all do this; multi-sample retry is straightforward but deferred).

`generate()`, `chat_completion()`, `completion()` are untouched — no retry, no per-LoRA gating. The multi-tenant Tinker path is the only production caller that flips to retry mode.
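A condensed sketch of that loop; the `_sample_with_rendered_tokens` signature and the response field names are stand-ins for the PR's internals:

```python
async def sample_with_retry(self, model: str, prompt_ids: list[int],
                            max_tokens: int, num_samples: int = 1):
    assert num_samples == 1, "multi-sample retry is deferred"
    accum_tokens: list[int] = []
    accum_logprobs: list[float] = []
    stop_reason = "abort"  # primed so the loop runs at least once
    resp = None
    while stop_reason == "abort":
        # Block while this adapter is paused; a no-op if it was never paused.
        event = self._lora_pause_events.get(model)
        if event is not None:
            await event.wait()
        resp = await self._sample_with_rendered_tokens(
            model=model,
            token_ids=prompt_ids + accum_tokens,        # prompt + partial output so far
            max_tokens=max_tokens - len(accum_tokens),  # never over-generate on retry
        )
        accum_tokens.extend(resp.token_ids)
        accum_logprobs.extend(resp.logprobs)
        stop_reason = resp.stop_reason
    # Stitch the accumulated segments back into a single sample()-shaped response.
    resp.token_ids = accum_tokens
    resp.logprobs = accum_logprobs
    return resp
```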
`worker_dispatch.save_weights_for_sampler` threads `lora_name`

`worker_dispatch.py`'s non-colocate branch forwards `model_id` (set by `SkyRLTrainBackend.save_sampler_checkpoint` to the LoRA name for multi-tenant, `None` for FFT) to the pause/resume calls.

API impact
- `InferenceEngineInterface.pause_generation`/`resume_generation` — new `lora_name: Optional[str] = None` kwarg.
- `RemoteInferenceClient.sample()` — public behavior unchanged (internally refactored).
- `RemoteInferenceClient.sample_with_retry()` — new; same return shape as `sample()`.
- `SkyRLTrainBackend.generate()`, `chat_completion()`, `completion()` — untouched.
- `worker_dispatch.save_weights_for_sampler` — passes `lora_name=model_id` to pause/resume on the non-colocate path; `model_id=None` for FFT preserves the global keep-mode pause.
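For illustration, a hypothetical multi-tenant weight-sync driver built on the new kwarg (`client` is a `RemoteInferenceClient`; `push_new_lora_weights` is a stand-in for whatever performs the actual swap):

```python
async def sync_lora_weights(client, lora_name: str = "lora-meow"):
    # Aborts only this adapter's in-flight requests; other adapters keep generating.
    await client.pause_generation(lora_name=lora_name)
    await push_new_lora_weights(lora_name)  # stand-in for the real weight swap
    # Reopens this adapter's gate so sample_with_retry resubmits aborted requests.
    await client.resume_generation(lora_name=lora_name)
```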
Tests

Unit (no GPU) — `tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py`

7 new cases in `TestTargetedLoraPause`:

- `pause_generation(lora_name=X)` clears the event and fans out the abort.
- `pause_generation()` (no `lora_name`) still drives keep-mode.
- `sample_with_retry` accumulates partial tokens across an abort; `max_tokens` decrements correctly, logprobs concatenate, and the final response shape is OK.
- `sample_with_retry`'s no-abort path is a single shot (refactor regression).
- `sample_with_retry` blocks on the per-LoRA event until `resume_generation` is called.
- `sample_with_retry` rejects `num_samples > 1`.

GPU integration — `tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_pause_lora.py`
4 end-to-end cases, run against a real vLLM server with two LoRA adapters (Meow + Woof) loaded:
- `test_pause_lora_does_not_affect_other_lora` — while lora-meow is paused, 4 concurrent `sample_with_retry(model="lora-woof")` calls complete promptly and contain "woof" content (proves the gate doesn't spill across adapters and weights aren't mixed up).
- `test_sample_with_retry_recovers_from_abort` — 4+4 concurrent in-flight samples on meow + woof; pausing lora-meow mid-flight aborts all 4 meow requests, retry resubmits after resume, and all 8 complete with non-abort stop reasons and correct content. Hard-asserts that 0/8 tasks completed pre-pause and 0/4 escaped during the pause window — without these assertions the test would silently no-op if the LoRA emitted EOS too fast (see the sketch below for the shape of those guards).
- `test_pause_swap_weights_resume_mid_sample` — a single sample call spans a real weight swap: it starts with Meow weights, mid-flight calls pause → `load_lora_adapter("lora-target", woof_path)` → resume, and the merged output literally shows `meow meow meow ... woof woof woof`. Proves the abort/retry boundary preserves accumulated state AND that the retried request observes the newly-loaded weights.
- `test_global_pause_still_works` — `pause_generation()` with no `lora_name` still drives keep-mode pause (FFT regression).
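A hedged sketch of the timing guards the second test relies on; `start_sample` and the response fields are illustrative, not the test file's actual helpers:

```python
import asyncio

async def run_retry_recovery_check(client, start_sample):
    # 4 meow + 4 woof samples in flight concurrently.
    tasks = [asyncio.create_task(start_sample(client, m))
             for m in ["lora-meow"] * 4 + ["lora-woof"] * 4]
    await asyncio.sleep(0.5)  # let the requests actually get in-flight
    # Guard 1: nothing may finish before the pause, otherwise the abort path
    # was never exercised and the test silently proves nothing.
    assert sum(t.done() for t in tasks) == 0
    await client.pause_generation(lora_name="lora-meow")
    # Guard 2: no meow task may complete while its adapter is paused.
    assert sum(t.done() for t in tasks[:4]) == 0
    await client.resume_generation(lora_name="lora-meow")
    results = await asyncio.gather(*tasks)
    assert all(r.stop_reason != "abort" for r in results)
```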