fix(laguna): match vLLM YaRN attention scaling to fix RL KL mismatch #2802

Draft
S1ro1 wants to merge 1 commit into main from fix/laguna-yarn-mscale

@S1ro1 S1ro1 commented Jun 13, 2026

Collaborator

Problem

Custom Laguna (poolside/Laguna-XS.2) shows a large RL importance-sampling mismatch from the very first step: mismatch_kl/all/mean averages ≈ 0.030 over 20 steps on the math env (batch_size=64, sign_sgd lr=1e-6), versus the ~0.0015 floor the other custom models reach. Because the mismatch is already present at step 0, it is systematic, which points at a trainer-vs-vLLM modeling difference rather than the weight broadcast.

Root cause

Laguna's full_attention layers use YaRN rope (factor=64, partial_rotary_factor=0.5, attention_factor=1.0); sliding_attention layers use default rope.

  • The trainer builds rope via HF's ROPE_INIT_FUNCTIONS["yarn"], which honors the config's attention_factor=1.0 → cos/sin scaled by 1.0.
  • vLLM's get_rope yarn branch only forwards attn_factor to YaRNScalingRotaryEmbedding and drops attention_factor, so it applies YaRN's default mscale 0.1·ln(64)+1 ≈ 1.4159.

So on every full_attention layer (10 of 40) the trainer's rope-rotated q/k differed from vLLM's by a constant ~1.4159×. A local cos/sin parity check confirmed that the scaling is the only difference (inv_freq already matches to float32 precision), and that after the fix the trainer and vLLM rope match exactly (Δ=0) for both layer types.
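The scaling gap can be reproduced in a few lines of plain Python. This is a minimal sketch, not the trainer's or vLLM's code: `yarn_mscale` re-implements the default formula quoted above, `cos_sin` is a hypothetical helper, and the frequency and position values are illustrative assumptions (a plain, non-YaRN frequency is enough here because only the scaling differs between the two sides).

```python
import math


def yarn_mscale(factor: float) -> float:
    """Default YaRN attention scaling: 0.1 * ln(factor) + 1 (1.0 when factor <= 1)."""
    if factor <= 1.0:
        return 1.0
    return 0.1 * math.log(factor) + 1.0


def cos_sin(position: int, inv_freq: float, scaling: float) -> tuple[float, float]:
    """Hypothetical helper: one rope cos/sin pair, multiplied by the attention scaling."""
    angle = position * inv_freq
    return math.cos(angle) * scaling, math.sin(angle) * scaling


factor = 64.0
trainer_scaling = 1.0               # config attention_factor=1.0, honored by the HF rope init
vllm_scaling = yarn_mscale(factor)  # default mscale applied on the vLLM side

# Illustrative frequency and position; the base (10000) and dim (64) are assumptions.
inv_freq = 1.0 / (10000.0 ** (2 / 64))
trainer_cos, trainer_sin = cos_sin(7, inv_freq, trainer_scaling)
vllm_cos, vllm_sin = cos_sin(7, inv_freq, vllm_scaling)

print(round(vllm_scaling, 4))            # 1.4159
print(round(vllm_cos / trainer_cos, 4))  # 1.4159, the same constant ratio for cos ...
print(round(vllm_sin / trainer_sin, 4))  # 1.4159, ... and for sin
```

Because the factor multiplies cos and sin uniformly, the ratio is independent of position and frequency, which is consistent with a mismatch that is already visible at step 0.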

Fix

Replicate vLLM's YaRN scaling in LagunaRotaryEmbedding: for yarn layers set the attention scaling to yarn_get_mscale(factor) * attn_factor (ignoring attention_factor, as vLLM does), applied in both the eager init path and init_buffers_post_meta (the meta-device path used during training). Non-yarn layers are untouched.
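The selection logic described above can be sketched as follows. This is an illustration under stated assumptions, not the actual `LagunaRotaryEmbedding` code: `rotary_attention_scaling` is a hypothetical helper, the dict keys mirror the config values quoted in this PR, `attn_factor` is assumed to default to 1.0 when absent, and only the `yarn` and `default` rope types are considered.

```python
import math


def yarn_get_mscale(scale: float = 1.0) -> float:
    """Local re-implementation of the default YaRN mscale formula."""
    if scale <= 1.0:
        return 1.0
    return 0.1 * math.log(scale) + 1.0


def rotary_attention_scaling(rope_params: dict) -> float:
    """Hypothetical helper: pick the cos/sin scaling for one layer's rope config."""
    if rope_params.get("rope_type", "default") != "yarn":
        # Default-rope layers (sliding_attention) keep a scaling of 1.0.
        return 1.0
    # Assumption: attn_factor defaults to 1.0 when the config does not set it.
    attn_factor = rope_params.get("attn_factor", 1.0)
    # attention_factor is deliberately ignored here to match the vLLM side.
    return yarn_get_mscale(rope_params["factor"]) * attn_factor


full_attention = {"rope_type": "yarn", "factor": 64.0, "attention_factor": 1.0}
sliding_attention = {"rope_type": "default"}

print(round(rotary_attention_scaling(full_attention), 4))  # 1.4159
print(rotary_attention_scaling(sliding_attention))         # 1.0
```

One possible design choice is to route both the eager init path and `init_buffers_post_meta` through a single helper like this, so the two paths cannot drift apart.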

Result

poolside/Laguna-XS.2, 20-step mean mismatch_kl/all/mean: 0.030 → 0.0034 (wandb project laguna-trinity-kl, runs laguna-base-main vs laguna-fix).

🤖 Generated with Claude Code

…s at the floor

Laguna's full_attention layers use YaRN rope (factor=64, partial_rotary_factor=0.5).
The HF rope init honors the checkpoint's `attention_factor=1.0`, so the trainer applied
an mscale of 1.0 to cos/sin. vLLM's `get_rope`, however, only forwards `attn_factor`
to its YaRN embedding and drops `attention_factor`, so it applies YaRN's default
mscale `0.1*ln(factor)+1 ≈ 1.4159`. The trainer and the vLLM inference engine therefore
computed rope-rotated q/k that differed by a constant ~1.4159x on every full_attention
layer (10 of 40), inflating the RL importance-sampling mismatch.

On the math env (batch_size=64, sign_sgd lr=1e-6, 20 steps) this showed up as a
mean `mismatch_kl/all/mean` of ~0.030 — well above the ~0.0015 precision floor the
other models reach.

Replicate vLLM's behavior: for yarn layers, set the rotary attention scaling to
`yarn_get_mscale(factor) * attn_factor` (ignoring `attention_factor`, exactly as vLLM
does), in both the eager init path and `init_buffers_post_meta` (the meta-device path
used in training). `sliding_attention` layers use default rope and are unchanged.

Result (poolside/Laguna-XS.2, 20-step mean mismatch_kl/all/mean): 0.030 -> 0.0034.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>