[lora][tinker] Add pause and resume for multi-tenant lora #1657

Open

hao-aaron wants to merge 3 commits into NovaSky-AI:main from hao-aaron:multi-lora-pause

Conversation

@hao-aaron (Collaborator)

Addresses #1647

Today's pause_generation() is a global vLLM keep-mode pause: when one LoRA tenant syncs new weights, every other tenant's in-flight generation freezes for the duration of the swap. This blocks practical multi-tenant LoRA RL training.

This PR adds a per-LoRA lora_name arg to pause/resume so that a weight sync for adapter A only aborts A's requests; other adapters keep generating. Aborted requests come back with finish_reason="abort" and partial tokens, and a new sample_with_retry() client method accumulates those partial tokens, awaits resume, and resubmits prompt + accumulated tokens with the remaining max_tokens budget until completion.

This is a stopgap. If we can eventually upstream a LoRA-specific pause to vLLM, we delete sample_with_retry, the per-LoRA gate, and the abort endpoint; the lora_name kwarg stays and routes to the new vLLM API.

What changed

New: /skyrl/v1/abort_lora_requests server endpoint

vllm_server_actor.py gets a small custom endpoint that iterates engine.output_processor.request_states, filters by lora_name, and calls engine.abort(ids, internal=True). internal=True is load-bearing — the states dict is keyed by internal IDs.
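
A minimal sketch of the endpoint, assuming a FastAPI app on the server actor; the attribute carrying the adapter name on each request state is an assumption (shown as a hypothetical lora_name field), and awaiting engine.abort assumes vLLM's async v1 engine API:

```python
# Sketch only: `engine` is the actor's vLLM AsyncLLM handle.
from fastapi import FastAPI

app = FastAPI()

@app.post("/skyrl/v1/abort_lora_requests")
async def abort_lora_requests(lora_name: str):
    ids = [
        req_id
        for req_id, state in engine.output_processor.request_states.items()
        if getattr(state, "lora_name", None) == lora_name  # assumed field
    ]
    # request_states is keyed by vLLM-internal IDs, hence internal=True.
    await engine.abort(ids, internal=True)
    return {"aborted": len(ids)}
```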

pause_generation/resume_generation gain lora_name: Optional[str] = None

  • lora_name=None → unchanged: vLLM /pause?mode=keep global pause.
  • lora_name="X" → clears a per-LoRA asyncio.Event (gates retries client-side), sleeps for a 5 s grace period, then fans out to /skyrl/v1/abort_lora_requests. Both branches are sketched below.
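
A rough sketch of both branches, assuming a dict of per-LoRA asyncio.Event gates (_lora_pause_events is named in this PR; _http_post, _server_urls, and the /resume path are illustrative):

```python
import asyncio
from typing import Optional

async def pause_generation(self, lora_name: Optional[str] = None) -> None:
    if lora_name is None:
        # Unchanged path: vLLM's global keep-mode pause.
        await self._http_post("/pause", params={"mode": "keep"})
        return
    # Close the client-side gate so sample_with_retry() blocks for this LoRA.
    self._lora_pause_events.setdefault(lora_name, asyncio.Event()).clear()
    await asyncio.sleep(5)  # grace period before aborting in-flight requests
    # Fan the targeted abort out to every server.
    await asyncio.gather(
        *(
            self._http_post(f"{url}/skyrl/v1/abort_lora_requests",
                            params={"lora_name": lora_name})
            for url in self._server_urls
        )
    )

async def resume_generation(self, lora_name: Optional[str] = None) -> None:
    if lora_name is None:
        await self._http_post("/resume")  # assumed global-resume counterpart
        return
    # Reopen the gate; blocked sample_with_retry() calls resubmit.
    self._lora_pause_events.setdefault(lora_name, asyncio.Event()).set()
```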

The new inference path (_SKYRL_USE_NEW_INFERENCE=1, the default) goes through RemoteInferenceClient → vllm_server_actor, which is where the new /skyrl/v1/abort_lora_requests endpoint lives, so that's the only path that actually supports targeted pause. Legacy-path classes (InferenceEngineClient, RayWrappedInferenceEngine, RemoteInferenceEngine, AsyncVLLMInferenceEngine) accept the lora_name kwarg as a guardrail but raise NotImplementedError("targeted pause is HTTP-only") when it's non-None. The legacy path keeps its existing global-pause behavior unchanged.
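
The legacy guardrail is just a kwarg check; a sketch of the pattern (the async signature is assumed to match the existing interface):

```python
from typing import Optional

class RayWrappedInferenceEngine:  # same pattern on the other legacy classes
    async def pause_generation(self, lora_name: Optional[str] = None) -> None:
        if lora_name is not None:
            raise NotImplementedError("targeted pause is HTTP-only")
        ...  # existing global keep-mode pause, unchanged
```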

RemoteInferenceClient: sample() refactor + new sample_with_retry()

This is the only data-plane surface that gained retry logic.

sample() is split:

  • sample() — renders the prompt and dispatches once. Public behavior unchanged.
  • _sample_with_rendered_tokens() — the post-render half, parameterized on rendered token_ids. Pure refactor (regression-checked by the existing TestSample tests).
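
The shape of the split, with the render helper name as a placeholder:

```python
async def sample(self, prompt, sampling_params, model):
    # Render once, then dispatch; public behavior is unchanged.
    token_ids = self._render_prompt(prompt, model)  # hypothetical helper name
    return await self._sample_with_rendered_tokens(token_ids, sampling_params, model)
```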

sample_with_retry() is new — the only retry-bearing method. Renders once, then runs a while stop_reason == "abort" loop:

  1. await _lora_pause_events[model].wait() (no-op if no event exists).
  2. Build a body with token_ids = original_prompt + accum_tokens and max_tokens = original_max_tokens - len(accum_tokens).
  3. Dispatch _sample_with_rendered_tokens.
  4. Extend accum_tokens + accum_logprobs with the returned segment.
  5. Repeat until non-abort.

Returns the same SampleResponse shape as sample(). Asserts num_samples == 1 (current Tinker callers all do this; multi-sample retry is straightforward but deferred).
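
A condensed sketch of that loop; apart from sample_with_retry, _sample_with_rendered_tokens, and _lora_pause_events, the helper and field names are illustrative:

```python
async def sample_with_retry(self, prompt, sampling_params, model):
    assert sampling_params.num_samples == 1  # multi-sample retry deferred
    original = self._render_prompt(prompt, model)  # render exactly once
    budget = sampling_params.max_tokens
    accum_tokens, accum_logprobs = [], []
    stop_reason = "abort"
    while stop_reason == "abort":
        # 1. Block while this adapter is paused (no-op if it never was).
        event = self._lora_pause_events.get(model)
        if event is not None:
            await event.wait()
        # 2-3. Resubmit prompt + accumulated tokens with the remaining budget.
        resp = await self._sample_with_rendered_tokens(
            token_ids=original + accum_tokens,
            max_tokens=budget - len(accum_tokens),
            sampling_params=sampling_params,
            model=model,
        )
        # 4. Extend the accumulators with the returned segment.
        accum_tokens.extend(resp.token_ids)
        accum_logprobs.extend(resp.logprobs)
        stop_reason = resp.stop_reason
    # 5. Non-abort stop: return the same SampleResponse shape as sample().
    return self._build_response(accum_tokens, accum_logprobs, stop_reason)
```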

generate(), chat_completion(), completion() are untouched — no retry, no per-LoRA gating. The multi-tenant Tinker path is the only production caller that flips to retry mode.

worker_dispatch.save_weights_for_sampler threads lora_name

worker_dispatch.py non-colocate branch forwards model_id (set by SkyRLTrainBackend.save_sampler_checkpoint to the LoRA name for multi-tenant, None for FFT) to the pause/resume calls.
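
The threading itself is small; a sketch of the non-colocate branch, with the weight-sync step elided and the client/helper names illustrative:

```python
# model_id is the LoRA name for multi-tenant runs and None for FFT, so the
# FFT path still gets the global keep-mode pause.
await inference_client.pause_generation(lora_name=model_id)
try:
    await sync_weights_to_sampler()  # existing weight-sync call, elided here
finally:
    await inference_client.resume_generation(lora_name=model_id)
```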

API impact

| API | Change | Migration |
| --- | --- | --- |
| InferenceEngineInterface.pause_generation / resume_generation | New optional lora_name: Optional[str] = None kwarg. | None; the default preserves existing behavior on every subclass. |
| RemoteInferenceClient.sample() | Refactored into render + dispatch (pure refactor; public behavior unchanged). | None. |
| RemoteInferenceClient.sample_with_retry() | New. Same signature/return as sample(). | Optional; auto-used by the multi-LoRA path in SkyRLTrainBackend. |
| generate(), chat_completion(), completion() | Unchanged. No retry, no per-LoRA gate. | None. |
| worker_dispatch.save_weights_for_sampler | Now forwards lora_name=model_id to pause/resume on the non-colocate path. | None; model_id=None for FFT preserves the global keep-mode pause. |

Tests

Unit (no GPU) — tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py

7 new cases in TestTargetedLoraPause:

  • pause_generation(lora_name=X) clears the event and fans out abort.
  • pause_generation() (no lora_name) still drives keep-mode.
  • sample_with_retry accumulates partial tokens across an abort, max_tokens decrements correctly, logprobs concatenate, final response shape OK.
  • sample_with_retry no-abort path is a single shot (refactor regression).
  • sample_with_retry blocks on the per-LoRA event until resume_generation is called.
  • sample_with_retry rejects num_samples > 1.
  • LoRAs that were never paused never block (no spurious event creation).
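
For flavor, the rough shape of the accumulation case (pytest-asyncio), with the fixture and param-builder names as placeholders:

```python
import pytest

@pytest.mark.asyncio
async def test_retry_accumulates_partial_tokens(abort_once_client):
    # Fixture stubs the transport: the first dispatch returns a partial
    # segment with stop_reason="abort", the second completes normally.
    resp = await abort_once_client.sample_with_retry(
        prompt="hello",
        sampling_params=make_params(max_tokens=16, num_samples=1),
        model="lora-a",
    )
    assert resp.stop_reason != "abort"
    # Partial tokens survive the abort, and the retry's max_tokens budget
    # shrinks by the accumulated length, so the total stays within 16.
    assert len(resp.token_ids) <= 16
```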

GPU integration — tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_pause_lora.py

4 end-to-end cases, run against a real vLLM server with two LoRA adapters (Meow + Woof) loaded:

  1. test_pause_lora_does_not_affect_other_lora — while lora-meow is paused, 4 concurrent sample_with_retry(model="lora-woof") calls complete promptly and contain "woof" content (proves the gate doesn't spill across adapters and weights aren't mixed up).
  2. test_sample_with_retry_recovers_from_abort — 4+4 concurrent in-flight samples on meow + woof; pausing lora-meow mid-flight aborts all 4 meow requests, retry resubmits after resume, all 8 complete with non-abort stop reasons and correct content. Hard-asserts 0/8 tasks completed pre-pause and 0/4 escaped during the pause window — without these assertions the test would silently no-op if the LoRA emitted EOS too fast.
  3. test_pause_swap_weights_resume_mid_sample — single sample call spans a real weight swap: starts with Meow weights, mid-flight calls pause → load_lora_adapter("lora-target", woof_path) → resume, and the merged output literally shows meow meow meow ... woof woof woof. Proves the abort/retry boundary preserves accumulated state AND that the retried request observes the newly-loaded weights.
  4. test_global_pause_still_works — pause_generation() with no lora_name still drives keep-mode pause (FFT regression).

Signed-off-by: ahao-anyscale <ahao@anyscale.com>
@gemini-code-assist (Bot) left a comment

Code Review

This pull request introduces targeted per-LoRA pause and resume functionality to support multi-tenant LoRA training, allowing weight swaps for one adapter without freezing others. Key changes include a new sample_with_retry mechanism in the RemoteInferenceClient to handle aborted requests and a server-side endpoint for selective request abortion. Review feedback identified several high-severity issues, primarily focusing on the need for persistent state tracking to prevent asyncio.Event data loss across different event loops. Additionally, corrections were suggested for accessing LoRA metadata in vLLM, managing max_tokens during retries to prevent over-generation, and truncating logprobs to ensure consistency with the original prompt length.

Comment thread: skyrl/backends/skyrl_train/inference_servers/vllm_server_actor.py
Comment thread: skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py (Outdated)
@erictang000 (Collaborator)

@hao-aaron btw have you tested this with multiple async rl runs in tinker mode?

hao-aaron added 2 commits May 13, 2026 05:45
x
Signed-off-by: ahao-anyscale <ahao@anyscale.com>
x
Signed-off-by: ahao-anyscale <ahao@anyscale.com>
@hao-aaron (Collaborator, Author)

@erictang000 i have some runs here https://wandb.ai/sky-posttraining-uc-berkeley/lora_multi_tenant?nw=nwuserahao
A and B have been run concurrently.

@erictang000 (Collaborator) left a comment

looks good to me, thanks for getting this out so quick
