fix(laguna): match vLLM YaRN attention scaling to fix RL KL mismatch#2802
Draft
S1ro1 wants to merge 1 commit into
Draft
fix(laguna): match vLLM YaRN attention scaling to fix RL KL mismatch#2802S1ro1 wants to merge 1 commit into
S1ro1 wants to merge 1 commit into
Conversation
…s at the floor Laguna's full_attention layers use YaRN rope (factor=64, partial_rotary_factor=0.5). The HF rope init honors the checkpoint's `attention_factor=1.0`, so the trainer applied an mscale of 1.0 to cos/sin. vLLM's `get_rope`, however, only forwards `attn_factor` to its YaRN embedding and drops `attention_factor`, so it applies YaRN's default mscale `0.1*ln(factor)+1 ≈ 1.4159`. The trainer and the vLLM inference engine therefore computed rope-rotated q/k that differed by a constant ~1.4159x on every full_attention layer (10 of 40), inflating the RL importance-sampling mismatch. On the math env (batch_size=64, sign_sgd lr=1e-6, 20 steps) this showed up as a mean `mismatch_kl/all/mean` of ~0.030 — well above the ~0.0015 precision floor the other models reach. Replicate vLLM's behavior: for yarn layers, set the rotary attention scaling to `yarn_get_mscale(factor) * attn_factor` (ignoring `attention_factor`, exactly as vLLM does), in both the eager init path and `init_buffers_post_meta` (the meta-device path used in training). Sliding_attention layers use default rope and are unchanged. Result (poolside/Laguna-XS.2, 20-step mean mismatch_kl/all/mean): 0.030 -> 0.0034. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Problem
Custom Laguna (
poolside/Laguna-XS.2) shows a large RL importance-sampling mismatch from the very first step: meanmismatch_kl/all/mean≈ 0.030 over 20 steps on the math env (batch_size=64,sign_sgdlr=1e-6), versus the ~0.0015 floor the other custom models reach. The mismatch is present at step 0 (systematic), pointing at a trainer-vs-vLLM modeling difference rather than the weight broadcast.Root cause
Laguna's
full_attentionlayers use YaRN rope (factor=64,partial_rotary_factor=0.5,attention_factor=1.0);sliding_attentionlayers use default rope.ROPE_INIT_FUNCTIONS["yarn"], which honors the config'sattention_factor=1.0→ cos/sin scaled by 1.0.get_ropeyarn branch only forwardsattn_factortoYaRNScalingRotaryEmbeddingand dropsattention_factor, so it applies YaRN's default mscale0.1·ln(64)+1 ≈ **1.4159**.So on every full_attention layer (10 of 40) the trainer's rope-rotated q/k differed from vLLM's by a constant ~1.4159×. A local cos/sin parity check confirmed: the only difference is the scaling (inv_freq matches to float32 precision), and after the fix the trainer and vLLM rope match exactly (Δ=0) for both layer types.
Fix
Replicate vLLM's YaRN scaling in
LagunaRotaryEmbedding: for yarn layers set the attention scaling toyarn_get_mscale(factor) * attn_factor(ignoringattention_factor, as vLLM does), applied in both the eager init path andinit_buffers_post_meta(the meta-device path used during training). Non-yarn layers are untouched.Result
poolside/Laguna-XS.2, 20-step meanmismatch_kl/all/mean: 0.030 → 0.0034 (wandb projectlaguna-trinity-kl, runslaguna-base-mainvslaguna-fix).🤖 Generated with Claude Code