[tinker] Support additional LoRA config parameters #1643

Open

taivu1998 wants to merge 1 commit into NovaSky-AI:main from taivu1998:tdv/issue-1632-tinker-lora-config

Conversation

@taivu1998

Summary

Fixes #1632 by teaching SkyRL's Tinker-compatible API and SkyRL-Train backend to honor the additional public Tinker LoRA configuration fields: seed, train_attn, train_mlp, and train_unembed.

Root Cause

The API accepted only rank and seed, and SkyRL-Train only copied LoRA rank/alpha into its config. That left target_modules effectively fixed server-side through cfg.trainer.policy.model.lora.target_modules, so users could not select attention-only, MLP-only, unembedding, or mixed LoRA training surfaces through the Tinker API.
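
For orientation, a minimal sketch of what the expanded API model could look like, based on the fields discussed in this PR. This is illustrative only: the authoritative definition, defaults, and validation live in skyrl/tinker/api.py, and the rank default here is a placeholder.

```python
from pydantic import BaseModel, Field


class LoRAConfig(BaseModel):
    # Illustrative sketch only -- the authoritative model lives in skyrl/tinker/api.py.
    rank: int = 0  # placeholder default; rank > 0 enables LoRA in this sketch
    seed: int | None = Field(
        default=None, description="Seed for LoRA weight initialization. If None, a random seed is used."
    )
    train_attn: bool = True
    train_mlp: bool = True
    train_unembed: bool = True  # default is under discussion in the review thread below
```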

Changes

  • Add Tinker-compatible LoRAConfig fields and validation to the API model.
  • Preserve whether seed was explicitly provided so SkyRL-Train multi-LoRA can distinguish explicit seed mismatches from omitted-seed requests.
  • Persist and return LoRA train flags through create-model, get-info, and weights-info flows.
  • Add a small SkyRL-Train resolver that maps Tinker train flags to FSDP and Megatron target modules while preserving the legacy all-linear behavior (attention + MLP, no unembedding); see the sketch after this list.
  • Set SkyRL-Train trainer.seed from the resolved LoRA seed for first model construction.
  • Include resolved trainable surfaces in SkyRL-Train multi-LoRA compatibility checks.
  • Reject train_unembed=True for Megatron when pipeline parallelism is greater than one, since output_layer exists only on the final pipeline stage.
  • Add focused API, persistence, resolver, and engine regression coverage.
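
As referenced above, a minimal sketch of how such a resolver could map the train flags to target-module lists. The module names (q_proj, gate_proj, lm_head, and so on) are common Hugging Face conventions and are assumptions here; the actual skyrl/backends/skyrl_train_lora.py resolver, including its all-linear fallback and Megatron naming, may differ.

```python
# Illustrative sketch only; not the actual skyrl/backends/skyrl_train_lora.py resolver.
ATTN_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj"]  # assumed attention projection names
MLP_MODULES = ["gate_proj", "up_proj", "down_proj"]      # assumed MLP projection names
UNEMBED_MODULES = ["lm_head"]                            # assumed unembedding module name


def resolve_target_modules(
    train_attn: bool,
    train_mlp: bool,
    train_unembed: bool,
    megatron: bool = False,
    pipeline_parallel_size: int = 1,
) -> list[str]:
    """Map Tinker-style train flags to a list of LoRA target modules."""
    if megatron and train_unembed and pipeline_parallel_size > 1:
        # The output layer only exists on the final pipeline stage, so reject this combination.
        raise ValueError("train_unembed=True is not supported with Megatron pipeline parallelism > 1")
    targets: list[str] = []
    if train_attn:
        targets += ATTN_MODULES
    if train_mlp:
        targets += MLP_MODULES
    if train_unembed:
        targets += UNEMBED_MODULES
    return targets
```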

Validation

  • uv run --extra tinker --extra dev python -m pytest tests/tinker/test_api_validation.py tests/tinker/test_skyrl_train_lora_config.py -q
  • uv run --with ml-dtypes --with transformers --extra tinker --extra dev --extra jax python -m pytest tests/tinker/test_engine.py -q
  • uv run --with ruff ruff check skyrl/tinker/types.py skyrl/tinker/api.py skyrl/backends/backend.py skyrl/tinker/engine.py skyrl/backends/jax.py skyrl/train/config/config.py skyrl/backends/skyrl_train_backend.py skyrl/backends/skyrl_train_lora.py tests/tinker/test_api_validation.py tests/tinker/test_skyrl_train_lora_config.py
  • git diff --check

@taivu1998 marked this pull request as ready for review on May 11, 2026, 03:11
Contributor

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces granular control over LoRA training targets by adding train_attn, train_mlp, and train_unembed flags to the configuration. It includes a new utility to map these flags to specific target modules for different training strategies (FSDP and Megatron) and enhances seed management to ensure consistency across model components. Review feedback identifies several redundant integer casts and points out an inconsistency in default values for the new LoRA flags between the API and internal type definitions.


# Apply LoRA configuration
if lora_config is not None and lora_config.rank > 0:
    cfg.trainer.seed = int(lora_config.seed)

medium

The int() cast is redundant as lora_config.seed is already defined as an int in the LoraConfig type definition.

Suggested change
cfg.trainer.seed = int(lora_config.seed)
cfg.trainer.seed = lora_config.seed

cfg.trainer.critic.model.lora.rank = lora_config.rank
cfg.trainer.critic.model.lora.alpha = int(lora_config.alpha)
if lora_config is not None and lora_config.rank > 0:
    cfg.trainer.seed = int(lora_config.seed)

medium

The int() cast is redundant here as well.

Suggested change
cfg.trainer.seed = int(lora_config.seed)
cfg.trainer.seed = lora_config.seed

"Multi-LoRA with the SkyRLTrainBackend requires identical "
"(rank, alpha, target_modules, exclude_modules, lora_type) across all adapters."
)
if seed_was_provided and self._base_lora_seed is not None and int(lora_config.seed) != self._base_lora_seed:

medium

The int() cast is redundant here as well.

Suggested change
if seed_was_provided and self._base_lora_seed is not None and int(lora_config.seed) != self._base_lora_seed:
if seed_was_provided and self._base_lora_seed is not None and lora_config.seed != self._base_lora_seed:

self._build_policy(PolicyWorker, model_id=model_id)
if is_lora:
    self._base_lora_signature = self._lora_signature_from(lora_config)
    self._base_lora_seed = int(lora_config.seed)

medium

The int() cast is redundant here as well.

Suggested change
self._base_lora_seed = int(lora_config.seed)
self._base_lora_seed = lora_config.seed

    pipeline_parallel_size=pipeline_parallel_size,
)
return (
    int(lora_config.rank),

medium

The int() cast is redundant as lora_config.rank is already an int.

Suggested change
int(lora_config.rank),
lora_config.rank,

Comment thread: skyrl/tinker/api.py
seed: int | None = Field(
    default=None, description="Seed for LoRA weight initialization. If None, a random seed is used."
)
train_unembed: bool = True

medium

The default value for train_unembed is inconsistent between the API model (True) and the internal type definition in skyrl/tinker/types.py (False).

Given the PR's objective to preserve legacy behavior (which targeted only attention and MLP layers), the default should likely be False. If the intention is to change the default behavior to include unembedding training by default, then skyrl/tinker/types.py should be updated to match.

Suggested change
train_unembed: bool = True
train_unembed: bool = False


assert api_cfg.train_attn is True
assert api_cfg.train_mlp is True
assert api_cfg.train_unembed is True

medium

This assertion should be updated if the default for train_unembed is changed to False in api.py.

Suggested change
assert api_cfg.train_unembed is True
assert api_cfg.train_unembed is False
