[tinker] Support additional LoRA config parameters #1643
Conversation
Code Review
This pull request introduces granular control over LoRA training targets by adding train_attn, train_mlp, and train_unembed flags to the configuration. It includes a new utility to map these flags to specific target modules for different training strategies (FSDP and Megatron) and enhances seed management to ensure consistency across model components. Review feedback identifies several redundant integer casts and points out an inconsistency in default values for the new LoRA flags between the API and internal type definitions.
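The flag-to-module mapping utility itself is not shown in the excerpts below, so here is a minimal sketch of what such a helper could look like. The function name `lora_target_modules` and the concrete module-name lists (other than `output_layer`, which the PR summary mentions) are assumptions for illustration, not the PR's actual implementation.

```python
# Hypothetical sketch: map the LoRA flags onto target-module name lists.
# Module names differ between HF/FSDP-style models and Megatron models,
# so the helper takes the training strategy into account.
from typing import List

# Assumed module names; real models may use different layer names.
_FSDP_MODULES = {
    "attn": ["q_proj", "k_proj", "v_proj", "o_proj"],
    "mlp": ["gate_proj", "up_proj", "down_proj"],
    "unembed": ["lm_head"],
}
_MEGATRON_MODULES = {
    "attn": ["linear_qkv", "linear_proj"],
    "mlp": ["linear_fc1", "linear_fc2"],
    "unembed": ["output_layer"],
}


def lora_target_modules(
    train_attn: bool, train_mlp: bool, train_unembed: bool, strategy: str
) -> List[str]:
    table = _MEGATRON_MODULES if strategy == "megatron" else _FSDP_MODULES
    targets: List[str] = []
    if train_attn:
        targets += table["attn"]
    if train_mlp:
        targets += table["mlp"]
    if train_unembed:
        targets += table["unembed"]
    return targets
```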
Diff excerpts where the review flags redundant `int(...)` casts:

```python
# Apply LoRA configuration
if lora_config is not None and lora_config.rank > 0:
    cfg.trainer.seed = int(lora_config.seed)

cfg.trainer.critic.model.lora.rank = lora_config.rank
cfg.trainer.critic.model.lora.alpha = int(lora_config.alpha)

if lora_config is not None and lora_config.rank > 0:
    cfg.trainer.seed = int(lora_config.seed)
```
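Following up on the redundant-cast feedback, here is a minimal sketch of the cast-free assignments, assuming the API model already declares `alpha` as an `int` and `seed` has been resolved to a concrete integer upstream. The wrapper function name `apply_lora_config` is hypothetical, used only to make the snippet self-contained.

```python
def apply_lora_config(cfg, lora_config) -> None:
    # Sketch: if alpha is already an int on the Pydantic model and seed has
    # already been resolved to an int, the explicit int(...) casts add nothing.
    if lora_config is not None and lora_config.rank > 0:
        cfg.trainer.seed = lora_config.seed
        cfg.trainer.critic.model.lora.rank = lora_config.rank
        cfg.trainer.critic.model.lora.alpha = lora_config.alpha
```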
| "Multi-LoRA with the SkyRLTrainBackend requires identical " | ||
| "(rank, alpha, target_modules, exclude_modules, lora_type) across all adapters." | ||
| ) | ||
| if seed_was_provided and self._base_lora_seed is not None and int(lora_config.seed) != self._base_lora_seed: |
```python
self._build_policy(PolicyWorker, model_id=model_id)
if is_lora:
    self._base_lora_signature = self._lora_signature_from(lora_config)
    self._base_lora_seed = int(lora_config.seed)
```
```python
    pipeline_parallel_size=pipeline_parallel_size,
)
return (
    int(lora_config.rank),
```
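The excerpt above shows only the start of the signature tuple. The following is a minimal sketch of a complete signature helper; the method name `_lora_signature_from` appears in the diff, but the tuple contents beyond `rank` are inferred from the error message above and the attribute names may not match the PR exactly.

```python
def lora_signature(lora_config) -> tuple:
    # Hashable signature used to check that additional adapters match the
    # first ("base") adapter; attribute names follow the error message and
    # are assumptions, not confirmed by the diff.
    return (
        int(lora_config.rank),
        int(lora_config.alpha),
        tuple(lora_config.target_modules or ()),
        tuple(lora_config.exclude_modules or ()),
        lora_config.lora_type,
    )
```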
```python
seed: int | None = Field(
    default=None, description="Seed for LoRA weight initialization. If None, a random seed is used."
)
train_unembed: bool = True
```
The default value for `train_unembed` is inconsistent between the API model (`True`) and the internal type definition in `skyrl/tinker/types.py` (`False`).
Given the PR's objective to preserve legacy behavior (which targeted only attention and MLP layers), the default should likely be `False`. If the intention is to change the default behavior to include unembedding training by default, then `skyrl/tinker/types.py` should be updated to match.
Suggested change:

```diff
- train_unembed: bool = True
+ train_unembed: bool = False
```
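For context, a minimal sketch of how the API model could look with the reviewer's suggested default applied. The class name `LoRAConfig` comes from the PR summary; the exact field set, ordering, and the defaults for `rank` and `alpha` are assumptions.

```python
from pydantic import BaseModel, Field


class LoRAConfig(BaseModel):
    rank: int = 0
    alpha: int = 32
    seed: int | None = Field(
        default=None,
        description="Seed for LoRA weight initialization. If None, a random seed is used.",
    )
    # Reviewer-suggested defaults: legacy behavior trained attention and MLP
    # layers only, so unembedding stays off unless explicitly requested.
    train_attn: bool = True
    train_mlp: bool = True
    train_unembed: bool = False
```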
```python
assert api_cfg.train_attn is True
assert api_cfg.train_mlp is True
assert api_cfg.train_unembed is True
```
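A sketch of how these assertions might sit in a default-value test. The variable name `api_cfg` comes from the excerpt; the import path and constructor arguments are guesses for illustration only.

```python
from skyrl.tinker.api import LoRAConfig  # import path assumed


def test_lora_config_defaults():
    # Construct the API model without any train_* flags and check the
    # defaults currently asserted by the PR's test excerpt.
    api_cfg = LoRAConfig(rank=8, alpha=16)
    assert api_cfg.train_attn is True
    assert api_cfg.train_mlp is True
    assert api_cfg.train_unembed is True
```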
Summary
Fixes #1632 by teaching SkyRL's Tinker-compatible API and SkyRL-Train backend to honor the additional public Tinker LoRA configuration fields: `seed`, `train_attn`, `train_mlp`, and `train_unembed`.
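As a hypothetical usage illustration, a client-side LoRA configuration could pass the new fields like this. The field names come from the PR; the surrounding payload shape and values are assumptions, not part of this change.

```python
# Hypothetical client-side LoRA configuration using the new fields.
lora_config = {
    "rank": 16,
    "alpha": 32,
    "seed": 1234,           # omit (or None) to let the server pick a random seed
    "train_attn": True,     # adapt attention projections
    "train_mlp": False,     # skip MLP layers
    "train_unembed": True,  # also adapt the unembedding / output layer
}
```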
Root Cause
The API accepted only `rank` and `seed`, and SkyRL-Train only copied LoRA `rank`/`alpha` into its config. That left `target_modules` effectively fixed server-side through `cfg.trainer.policy.model.lora.target_modules`, so users could not select attention-only, MLP-only, unembedding, or mixed LoRA training surfaces through the Tinker API.
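Building on the hypothetical `lora_target_modules` helper sketched near the top, a backend-side fix could look roughly like this. The config paths mirror the excerpts above; the wrapper function name and the `strategy` argument are assumptions.

```python
def apply_lora_targets(cfg, lora_config, strategy: str) -> None:
    # Sketch: derive target modules from the new flags instead of leaving
    # cfg.trainer.policy.model.lora.target_modules fixed server-side.
    if lora_config is None or lora_config.rank <= 0:
        return
    cfg.trainer.policy.model.lora.rank = lora_config.rank
    cfg.trainer.policy.model.lora.alpha = lora_config.alpha
    cfg.trainer.policy.model.lora.target_modules = lora_target_modules(
        lora_config.train_attn,
        lora_config.train_mlp,
        lora_config.train_unembed,
        strategy=strategy,
    )
```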
Changes
- Add the new `LoRAConfig` fields and validation to the API model.
- Track whether `seed` was explicitly provided so SkyRL-Train multi-LoRA can distinguish explicit seed mismatches from omitted-seed requests.
- Map the flags to FSDP and Megatron target modules; the default selection (`attn + mlp`, no unembed) preserves the legacy `all-linear` behavior.
- Set `trainer.seed` from the resolved LoRA seed for first model construction (see the seed-resolution sketch after this list).
- Account for `train_unembed=True` with pipeline parallelism greater than one, where `output_layer` only exists on the final pipeline stage.
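A minimal sketch of what the seed resolution referenced above could look like, grounded only in the field description ("If None, a random seed is used"); the function name and the random range are assumptions.

```python
import random


def resolve_lora_seed(seed: int | None) -> int:
    # Use the explicitly provided seed when given; otherwise draw one at
    # random so downstream consumers (e.g. trainer.seed) always see an int.
    if seed is not None:
        return int(seed)
    return random.randint(0, 2**31 - 1)
```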
Validation

```bash
uv run --extra tinker --extra dev python -m pytest tests/tinker/test_api_validation.py tests/tinker/test_skyrl_train_lora_config.py -q
uv run --with ml-dtypes --with transformers --extra tinker --extra dev --extra jax python -m pytest tests/tinker/test_engine.py -q
uv run --with ruff ruff check skyrl/tinker/types.py skyrl/tinker/api.py skyrl/backends/backend.py skyrl/tinker/engine.py skyrl/backends/jax.py skyrl/train/config/config.py skyrl/backends/skyrl_train_backend.py skyrl/backends/skyrl_train_lora.py tests/tinker/test_api_validation.py tests/tinker/test_skyrl_train_lora_config.py
git diff --check
```