Qwen35 sp fix by ulysses#127

Closed
meichangsu1 wants to merge 6 commits into modelscope:main from meichangsu1:qwen35_sp_fix_ljl

Conversation


@meichangsu1 meichangsu1 commented Mar 24, 2026

PR type

  • [x] Bug Fix
  • New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

Implementation approach
This PR uses a layered design: a generic SP strategy plus a Qwen3.5-specific model implementation.

  1. Generic SP lives in the strategy layer
    The generic Ulysses/SP capability lives in sequence_parallel.py and is responsible for:
    splitting inputs along the sequence dimension according to ulysses_size
    performing the seq <-> head all-to-all communication inside attention
    maintaining runtime information such as the sequence parallel runtime context

  2. Qwen3.5-specific parts live in the model layer
    Qwen3.5 uses linear attention, and GatedDeltaNet additionally contains a convolution layer; the Ulysses idea still applies.
    Reference: Linear Attention Ulysses (GatedDeltaNet)
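The seq <-> head exchange described above can be illustrated with a single-process simulation. This is a sketch only: real code would use torch.distributed.all_to_all across the SP group, and the helper name and shapes here are illustrative, not the PR's API. Each simulated "rank" starts with a sequence shard of shape (B, S/P, H, D) and ends with all tokens for its head slice, shape (B, S, H/P, D).

```python
import torch

def seq_to_head_shard_sim(shards, num_heads):
    """Simulate the Ulysses all-to-all with slicing: shards is a list of
    per-rank tensors (B, S/P, H, D); returns per-rank tensors (B, S, H/P, D)."""
    world = len(shards)
    heads_per_rank = num_heads // world
    out = []
    for rank in range(world):
        # Each rank gathers its head slice from every sequence shard,
        # then concatenates along the sequence dimension.
        pieces = [s[:, :, rank * heads_per_rank:(rank + 1) * heads_per_rank, :] for s in shards]
        out.append(torch.cat(pieces, dim=1))
    return out

B, S, H, D, P = 2, 8, 4, 16, 2
x = torch.randn(B, S, H, D)
shards = list(x.chunk(P, dim=1))            # sequence-sharded inputs, one per rank
headed = seq_to_head_shard_sim(shards, H)   # each rank now holds full sequence, H/P heads
assert headed[0].shape == (B, S, H // P, D)
assert torch.equal(headed[0], x[:, :, :H // P, :])  # rank 0 holds the first head slice
```

The inverse head -> seq exchange applies the same permutation in reverse after attention, which is what restores sequence sharding before the MLP.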

In modeling_qwen3_5.py, the following classes are currently overridden:
TwinkleQwen3_5TextModel
TwinkleQwen3_5DecoderLayer
TwinkleQwen3_5GatedDeltaNet
TwinkleQwen3_5ForCausalLM

At usage time, users pass model_cls=TwinkleQwen3_5ForCausalLM:
model = TransformersModel(
    model_id=MODEL_ID,
    model_cls=TwinkleQwen3_5ForCausalLM,
    device_mesh=device_mesh,
    strategy='native_fsdp',
)

In summary:
softmax attention continues to reuse the generic SP, while
linear attention is handled explicitly inside the model, so that the Qwen3.5 language-model backbone itself is SP-aware.

Experiment results

With SP enabled (ulysses=2):
[2026-03-23 15:31:43][INFO:twinkle] Current is optimizer step 60 of 63 (micro step 121 of 125), metric: {'loss': '1.1285', 'grad_norm': '0.139006', 'accuracy': '0.68', 'correct_tokens': 21312, 'total_tokens': 31280, 'learning rate(param group 1)': '6.586739e-07', 'learning rate(param group 2)': '6.586739e-07', 'iters': 60, 'total time elapse': '187 seconds', 'speed': '0.36 iters/s', 'rank': 0, 'local_rank': 0, 'device': 'cuda:0', 'mem_allocated': '2200.0 MiB', 'mem_reserved': '55992.0 MiB', 'mem_peak_allocated': '27645.0 MiB', 'mem_peak_reserved': '55992.0 MiB'}

With SP disabled:
[2026-03-23 13:47:27][INFO:twinkle] Current is step 60 of 63, metric: {'loss': '1.1290', 'grad_norm': '0.123370', 'accuracy': '0.68', 'correct_tokens': 21283, 'total_tokens': 31280, 'learning rate(param group 1)': '6.586739e-07', 'learning rate(param group 2)': '6.586739e-07', 'iters': 60, 'total time elapse': '129 seconds', 'speed': '0.53 iters/s', 'rank': 0, 'local_rank': 0, 'device': 'cuda:0', 'mem_allocated': '2183.5 MiB', 'mem_reserved': '62422.0 MiB', 'mem_peak_allocated': '36788.4 MiB', 'mem_peak_reserved': '77750.0 MiB'}

@gemini-code-assist

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the Twinkle framework by adding robust Sequence Parallelism support for the Qwen3.5 model. It involves deep integration of SP into the model's attention mechanisms, particularly optimizing linear attention with specialized kernels. The changes also include necessary adjustments to data handling and training configurations to seamlessly enable SP, alongside extensive testing to ensure functional correctness and memory benefits.

Highlights

  • Qwen3.5 Sequence Parallelism (SP) Implementation: Introduced a layered approach for Qwen3.5 SP, with generic SP capabilities in the strategy layer and Qwen3.5-specific optimizations in the model layer, particularly for linear attention.
  • Optimized Linear Attention: Integrated Flash-Linear-Attention (FLA) for TwinkleQwen3_5GatedDeltaNet to enhance performance and memory efficiency of linear attention layers when SP is enabled.
  • Dynamic Batch Size Adjustment: Implemented automatic adjustment of batch sizes in the DataLoader to correctly account for Ulysses Sequence Parallelism, ensuring proper data distribution across devices.
  • Comprehensive Testing and Benchmarking: Added new memory benchmarking scripts (qwen3_5_sp_memory_bench.py) and distributed parity tests to validate the correctness, memory usage, and gradient consistency of the Qwen3.5 SP implementation.
  • Refined Attention Mask Handling: Improved the logic for resolving and passing attention masks, especially for FlashAttention2/3, to prevent incorrect behavior with padded sequences in SP mode.
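The dynamic batch-size adjustment in the highlights can be sketched as follows. This is an illustrative sketch: the function name and signature are assumptions, not the PR's actual API. The idea is that all ranks inside one Ulysses SP group consume the same samples, so the number of independent data-parallel shards shrinks by the SP factor.

```python
def resolve_runtime_batch_size(global_batch_size: int, world_size: int, ulysses_size: int) -> int:
    """Hypothetical helper: ranks inside one Ulysses SP group share samples,
    so the effective data-parallel world size is world_size // ulysses_size."""
    if world_size % ulysses_size != 0:
        raise ValueError('ulysses_size must divide world_size')
    data_world_size = world_size // ulysses_size
    if global_batch_size % data_world_size != 0:
        raise ValueError('global batch size must be divisible by the data world size')
    # Each data shard (one SP group) receives this many samples per step.
    return global_batch_size // data_world_size

# e.g. 8 GPUs with ulysses_size=2 -> 4 data shards; a global batch of 16 means 4 per shard
```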



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This PR introduces Sequence Parallelism (SP) support for the Qwen3.5 model, an important and complex feature. The implementation uses a clean layered design that separates the generic SP strategy from the model-specific adaptation logic; the custom handling of linear attention is particularly commendable. The PR also includes comprehensive unit tests and multi-GPU parity tests, which greatly strengthen confidence in the correctness of the change.

My review comments mainly contain a few refactoring suggestions aimed at improving maintainability by reducing complexity and code duplication. Overall, this is a high-quality contribution.

Comment on lines +281 to +433
def forward(
    self,
    hidden_states: torch.Tensor,
    cache_params: hf_qwen35.Qwen3_5DynamicCache | None = None,
    cache_position: torch.LongTensor | None = None,
    attention_mask: torch.Tensor | None = None,
    cu_seq_lens_q: torch.Tensor | None = None,
    sequence_parallel_context: Any | None = None,
):
    attention_mask = _resolve_local_padding_mask(attention_mask, hidden_states.shape[1], sequence_parallel_context)
    hidden_states = hf_qwen35.apply_mask_to_padding_states(hidden_states, attention_mask)
    batch_size, seq_len, _ = hidden_states.shape
    use_precomputed_states = (
        cache_params is not None
        and cache_params.has_previous_state
        and seq_len == 1
        and cache_position is not None
    )

    if cache_params is not None:
        conv_state = cache_params.conv_states[self.layer_idx]
        recurrent_state = cache_params.recurrent_states[self.layer_idx]
    else:
        conv_state = None
        recurrent_state = None

    mixed_qkv = self.in_proj_qkv(hidden_states)
    z = self.in_proj_z(hidden_states).reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim)
    b = self.in_proj_b(hidden_states)
    a = self.in_proj_a(hidden_states)
    full_attention_mask = attention_mask

    sp_enabled = _sp_is_enabled(sequence_parallel_context)
    if sp_enabled:
        sp_world_size = int(sequence_parallel_context.sp_world_size)
        if self.num_k_heads % sp_world_size != 0 or self.num_v_heads % sp_world_size != 0:
            raise RuntimeError(
                'TwinkleQwen3_5 linear attention requires sp_world_size to divide both '
                f'linear_num_key_heads ({self.num_k_heads}) and linear_num_value_heads ({self.num_v_heads}).'
            )
        local_num_k_heads = self.num_k_heads // sp_world_size
        local_num_v_heads = self.num_v_heads // sp_world_size
        local_key_dim = local_num_k_heads * self.head_k_dim
        local_value_dim = local_num_v_heads * self.head_v_dim

        q_proj, k_proj, v_proj = torch.split(mixed_qkv, [self.key_dim, self.key_dim, self.value_dim], dim=-1)
        q_proj = q_proj.reshape(batch_size, seq_len, self.num_k_heads, self.head_k_dim)
        k_proj = k_proj.reshape(batch_size, seq_len, self.num_k_heads, self.head_k_dim)
        v_proj = v_proj.reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim)
        q_proj = _seq_to_head_shard(q_proj, sequence_parallel_context)
        k_proj = _seq_to_head_shard(k_proj, sequence_parallel_context)
        v_proj = _seq_to_head_shard(v_proj, sequence_parallel_context)
        b = _seq_to_head_shard(b.reshape(batch_size, seq_len, self.num_v_heads), sequence_parallel_context)
        a = _seq_to_head_shard(a.reshape(batch_size, seq_len, self.num_v_heads), sequence_parallel_context)

        mixed_qkv = torch.cat(
            (
                q_proj.reshape(batch_size, q_proj.shape[1], local_key_dim),
                k_proj.reshape(batch_size, k_proj.shape[1], local_key_dim),
                v_proj.reshape(batch_size, v_proj.shape[1], local_value_dim),
            ),
            dim=-1,
        )
        conv_weight = self._get_local_conv1d_weight(_get_sp_rank(sequence_parallel_context), local_key_dim, local_value_dim)
    else:
        local_num_k_heads = self.num_k_heads
        local_num_v_heads = self.num_v_heads
        local_key_dim = self.key_dim
        local_value_dim = self.value_dim
        b = b.reshape(batch_size, seq_len, self.num_v_heads)
        a = a.reshape(batch_size, seq_len, self.num_v_heads)
        conv_weight = self.conv1d.weight.squeeze(1)

    packed_valid_mask = None
    packed_cu_seqlens = cu_seq_lens_q
    packed_seq_len = mixed_qkv.shape[1]
    use_varlen_pack = cu_seq_lens_q is not None and not use_precomputed_states
    if use_varlen_pack:
        full_position_ids = getattr(sequence_parallel_context, 'real_position_ids', None)
        packed_valid_mask, packed_cu_seqlens = _build_varlen_metadata(
            position_ids=full_position_ids,
            attention_mask=full_attention_mask,
            full_seq_len=packed_seq_len,
        )
        mixed_qkv = _pack_varlen_tensor(mixed_qkv, packed_valid_mask)
        b = _pack_varlen_tensor(b, packed_valid_mask)
        a = _pack_varlen_tensor(a, packed_valid_mask)

    if use_precomputed_states:
        if conv_state is None:
            raise RuntimeError('Qwen3.5 decode requires initialized convolution state.')
        mixed_qkv = self._apply_decode_conv(mixed_qkv, conv_state, conv_weight)
    else:
        if cache_params is not None:
            cache_params.conv_states[self.layer_idx] = F.pad(
                mixed_qkv.transpose(1, 2).contiguous(),
                (self.conv_kernel_size - mixed_qkv.shape[1], 0),
            )
        mixed_qkv = self._apply_varlen_conv(mixed_qkv, conv_weight, packed_cu_seqlens)

    query, key, value = torch.split(mixed_qkv, [local_key_dim, local_key_dim, local_value_dim], dim=-1)
    qkv_batch_size = 1 if use_varlen_pack else batch_size
    query = query.reshape(qkv_batch_size, query.shape[1], local_num_k_heads, self.head_k_dim)
    key = key.reshape(qkv_batch_size, key.shape[1], local_num_k_heads, self.head_k_dim)
    value = value.reshape(qkv_batch_size, value.shape[1], local_num_v_heads, self.head_v_dim)

    beta = b.sigmoid()
    if sp_enabled:
        head_offset = _get_sp_rank(sequence_parallel_context) * local_num_v_heads
        head_slice = slice(head_offset, head_offset + local_num_v_heads)
        g = -self.A_log[head_slice].float().exp() * F.softplus(a.float() + self.dt_bias[head_slice])
    else:
        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)

    if self.num_v_heads // self.num_k_heads > 1:
        repeat = self.num_v_heads // self.num_k_heads
        query = query.repeat_interleave(repeat, dim=2)
        key = key.repeat_interleave(repeat, dim=2)

    if use_precomputed_states:
        core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
            query,
            key,
            value,
            g=g,
            beta=beta,
            initial_state=recurrent_state,
            output_final_state=cache_params is not None,
            use_qk_l2norm_in_kernel=True,
        )
    else:
        core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
            query,
            key,
            value,
            g=g,
            beta=beta,
            initial_state=None,
            output_final_state=cache_params is not None,
            use_qk_l2norm_in_kernel=True,
            cu_seqlens=packed_cu_seqlens,
        )

    if cache_params is not None:
        cache_params.recurrent_states[self.layer_idx] = last_recurrent_state

    if use_varlen_pack:
        core_attn_out = _unpack_varlen_tensor(core_attn_out, packed_valid_mask, batch_size, packed_seq_len)
    core_attn_out = _head_to_seq_shard(core_attn_out, sequence_parallel_context)
    core_attn_out = self.norm(core_attn_out.reshape(-1, self.head_v_dim), z.reshape(-1, self.head_v_dim))
    core_attn_out = core_attn_out.reshape(batch_size, seq_len, self.value_dim)
    return self.out_proj(core_attn_out)

medium

The forward method of TwinkleQwen3_5GatedDeltaNet is currently quite long: it handles sequence parallelism, variable-length packing, and the prefill and decode paths all in one place, which hurts readability and maintainability. Consider refactoring it into several smaller, single-purpose helper methods, for example one for the sequence-parallel tensor manipulation (such as _apply_sequence_parallelism), one for the convolution step (such as _apply_convolution), and one for the core attention computation (such as _apply_gated_delta_rule). The main forward would then read as a clear sequence of these steps.
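One possible shape of such a decomposition, as a minimal runnable skeleton. Helper names follow the review's suggestion; the bodies are placeholder pass-throughs, not the PR's implementation.

```python
class GatedDeltaNetForwardSketch:
    """Skeleton only: shows forward reduced to three single-purpose stages."""

    def forward(self, hidden_states, sequence_parallel_context=None, cache_params=None):
        qkv = self._apply_sequence_parallelism(hidden_states, sequence_parallel_context)
        qkv = self._apply_convolution(qkv, cache_params)
        return self._apply_gated_delta_rule(qkv, sequence_parallel_context)

    def _apply_sequence_parallelism(self, x, ctx):
        # Real code: shard heads via the seq<->head all-to-all when ctx is enabled.
        return x

    def _apply_convolution(self, x, cache_params):
        # Real code: decode conv using cached state, or var-len causal conv for prefill.
        return x

    def _apply_gated_delta_rule(self, x, ctx):
        # Real code: chunked (prefill) or recurrent (decode) gated delta rule.
        return x
```

With this structure, the SP branch and the varlen branch each live in exactly one helper, and the decode/prefill split is confined to _apply_convolution and _apply_gated_delta_rule.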

Comment on lines +26 to +47
def get_flattened_cu_seqlens_from_position_ids(position_ids: torch.LongTensor):
    if position_ids.dim() == 1:
        position_ids = position_ids.unsqueeze(0)
    if position_ids.dim() != 2:
        raise ValueError(f'Expected 1D or 2D position_ids, got shape={tuple(position_ids.shape)}')

    device = position_ids.device
    cu_seqlens = [0]
    total = 0
    for row in position_ids:
        row = row.clone()
        row[row < 0] = 0
        seq_start_indices = torch.where(row == 0)[0]
        if seq_start_indices.numel() == 0 or seq_start_indices[0].item() != 0:
            seq_start_indices = torch.cat([torch.tensor([0], device=device, dtype=seq_start_indices.dtype), seq_start_indices])
        seq_end_indices = torch.cat([seq_start_indices[1:], torch.tensor([len(row)], device=device)])
        seq_lengths = (seq_end_indices - seq_start_indices).tolist()
        for seq_length in seq_lengths:
            total += int(seq_length)
            cu_seqlens.append(total)
    return torch.tensor(cu_seqlens, device=device, dtype=torch.long)

medium

The get_flattened_cu_seqlens_from_position_ids function overlaps functionally with _build_varlen_metadata in src/twinkle/model/transformers/models/qwen3_5/modeling_qwen3_5.py: both compute cu_seqlens from position_ids for variable-length sequence handling. To avoid duplication and improve maintainability, consider consolidating this logic into a single shared utility, for example in sequence_parallel.py, returning cu_seqlens and, where needed, valid_mask.
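A sketch of what such a shared utility could look like (the name and return convention are assumptions, not the PR's actual API). It keeps the convention of the function above: each 0 in position_ids starts a new packed sequence, and negative values mark padding.

```python
import torch

def varlen_metadata_from_position_ids(position_ids: torch.Tensor):
    """Hypothetical shared helper: returns (valid_mask, cu_seqlens) so both
    call sites can consume one implementation."""
    if position_ids.dim() == 1:
        position_ids = position_ids.unsqueeze(0)
    valid_mask = position_ids >= 0          # padding positions are negative
    row_len = position_ids.shape[1]
    cu_seqlens = [0]
    for row in position_ids.clamp_min(0):
        starts = torch.where(row == 0)[0].tolist()
        if not starts or starts[0] != 0:
            starts = [0] + starts           # every row must begin a sequence
        bounds = starts[1:] + [row_len]
        for s, e in zip(starts, bounds):
            cu_seqlens.append(cu_seqlens[-1] + (e - s))
    return valid_mask, torch.tensor(cu_seqlens, dtype=torch.long, device=position_ids.device)
```

For example, position_ids [[0, 1, 2, 0, 1, 2, 3, 4]] yields cu_seqlens [0, 3, 8].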

…string

- Change double quotes to single quotes for consistency in `attn_implementation` parameter
- Reformat multi-line imports to single line for better readability
- Remove unnecessary import error message in linear attention validation
- Maintain code style consistency across the codebase
self.processor: Optional[InputProcessor] = None
self._set_work_init_fn()

@staticmethod
Collaborator

data_world_size should arguably take ulysses into account here.

Collaborator
Putting the world_size check inside individual components doesn't seem right; it would be better to consolidate it into device_mesh.

self.batch_size = batch_size
required_world_size = self._required_data_world_size(device_mesh)
assert batch_size >= required_world_size and batch_size % required_world_size == 0
self.batch_size = self._resolve_runtime_batch_size(batch_size, device_mesh)
Collaborator
What is the reasoning behind this logic?

DEFAULT_WEIGHT_DECAY = 0.01


def _default_gradient_accumulation_steps_for_device_mesh(device_mesh: Optional[DeviceMesh]) -> int:
Collaborator
Could this be left for the user to decide?
