Conversation
Summary of Changes

This pull request enhances the Twinkle framework by adding Sequence Parallelism (SP) support for the Qwen3.5 model. It integrates SP into the model's attention mechanisms, particularly optimizing linear attention with specialized kernels. The changes also include the necessary adjustments to data handling and training configurations to enable SP seamlessly, alongside extensive testing to ensure functional correctness and memory benefits.
Code Review
This PR introduces Sequence Parallelism (SP) support for the Qwen3.5 model, an important and complex feature. The implementation uses a well-layered design that cleanly separates the generic SP strategy from the model-specific adaptation logic; the custom handling of linear attention is particularly commendable. The PR also includes comprehensive unit tests and multi-GPU parity tests, which go a long way toward guaranteeing the correctness of the changes.

My review comments mainly consist of a few refactoring suggestions aimed at improving maintainability by reducing complexity and code duplication. Overall, this is a high-quality contribution.
```python
def forward(
    self,
    hidden_states: torch.Tensor,
    cache_params: hf_qwen35.Qwen3_5DynamicCache | None = None,
    cache_position: torch.LongTensor | None = None,
    attention_mask: torch.Tensor | None = None,
    cu_seq_lens_q: torch.Tensor | None = None,
    sequence_parallel_context: Any | None = None,
):
    attention_mask = _resolve_local_padding_mask(attention_mask, hidden_states.shape[1], sequence_parallel_context)
    hidden_states = hf_qwen35.apply_mask_to_padding_states(hidden_states, attention_mask)
    batch_size, seq_len, _ = hidden_states.shape
    use_precomputed_states = (
        cache_params is not None
        and cache_params.has_previous_state
        and seq_len == 1
        and cache_position is not None
    )

    if cache_params is not None:
        conv_state = cache_params.conv_states[self.layer_idx]
        recurrent_state = cache_params.recurrent_states[self.layer_idx]
    else:
        conv_state = None
        recurrent_state = None

    mixed_qkv = self.in_proj_qkv(hidden_states)
    z = self.in_proj_z(hidden_states).reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim)
    b = self.in_proj_b(hidden_states)
    a = self.in_proj_a(hidden_states)
    full_attention_mask = attention_mask

    sp_enabled = _sp_is_enabled(sequence_parallel_context)
    if sp_enabled:
        sp_world_size = int(sequence_parallel_context.sp_world_size)
        if self.num_k_heads % sp_world_size != 0 or self.num_v_heads % sp_world_size != 0:
            raise RuntimeError(
                'TwinkleQwen3_5 linear attention requires sp_world_size to divide both '
                f'linear_num_key_heads ({self.num_k_heads}) and linear_num_value_heads ({self.num_v_heads}).'
            )
        local_num_k_heads = self.num_k_heads // sp_world_size
        local_num_v_heads = self.num_v_heads // sp_world_size
        local_key_dim = local_num_k_heads * self.head_k_dim
        local_value_dim = local_num_v_heads * self.head_v_dim

        q_proj, k_proj, v_proj = torch.split(mixed_qkv, [self.key_dim, self.key_dim, self.value_dim], dim=-1)
        q_proj = q_proj.reshape(batch_size, seq_len, self.num_k_heads, self.head_k_dim)
        k_proj = k_proj.reshape(batch_size, seq_len, self.num_k_heads, self.head_k_dim)
        v_proj = v_proj.reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim)
        q_proj = _seq_to_head_shard(q_proj, sequence_parallel_context)
        k_proj = _seq_to_head_shard(k_proj, sequence_parallel_context)
        v_proj = _seq_to_head_shard(v_proj, sequence_parallel_context)
        b = _seq_to_head_shard(b.reshape(batch_size, seq_len, self.num_v_heads), sequence_parallel_context)
        a = _seq_to_head_shard(a.reshape(batch_size, seq_len, self.num_v_heads), sequence_parallel_context)

        mixed_qkv = torch.cat(
            (
                q_proj.reshape(batch_size, q_proj.shape[1], local_key_dim),
                k_proj.reshape(batch_size, k_proj.shape[1], local_key_dim),
                v_proj.reshape(batch_size, v_proj.shape[1], local_value_dim),
            ),
            dim=-1,
        )
        conv_weight = self._get_local_conv1d_weight(_get_sp_rank(sequence_parallel_context), local_key_dim, local_value_dim)
    else:
        local_num_k_heads = self.num_k_heads
        local_num_v_heads = self.num_v_heads
        local_key_dim = self.key_dim
        local_value_dim = self.value_dim
        b = b.reshape(batch_size, seq_len, self.num_v_heads)
        a = a.reshape(batch_size, seq_len, self.num_v_heads)
        conv_weight = self.conv1d.weight.squeeze(1)

    packed_valid_mask = None
    packed_cu_seqlens = cu_seq_lens_q
    packed_seq_len = mixed_qkv.shape[1]
    use_varlen_pack = cu_seq_lens_q is not None and not use_precomputed_states
    if use_varlen_pack:
        full_position_ids = getattr(sequence_parallel_context, 'real_position_ids', None)
        packed_valid_mask, packed_cu_seqlens = _build_varlen_metadata(
            position_ids=full_position_ids,
            attention_mask=full_attention_mask,
            full_seq_len=packed_seq_len,
        )
        mixed_qkv = _pack_varlen_tensor(mixed_qkv, packed_valid_mask)
        b = _pack_varlen_tensor(b, packed_valid_mask)
        a = _pack_varlen_tensor(a, packed_valid_mask)

    if use_precomputed_states:
        if conv_state is None:
            raise RuntimeError('Qwen3.5 decode requires initialized convolution state.')
        mixed_qkv = self._apply_decode_conv(mixed_qkv, conv_state, conv_weight)
    else:
        if cache_params is not None:
            cache_params.conv_states[self.layer_idx] = F.pad(
                mixed_qkv.transpose(1, 2).contiguous(),
                (self.conv_kernel_size - mixed_qkv.shape[1], 0),
            )
        mixed_qkv = self._apply_varlen_conv(mixed_qkv, conv_weight, packed_cu_seqlens)

    query, key, value = torch.split(mixed_qkv, [local_key_dim, local_key_dim, local_value_dim], dim=-1)
    qkv_batch_size = 1 if use_varlen_pack else batch_size
    query = query.reshape(qkv_batch_size, query.shape[1], local_num_k_heads, self.head_k_dim)
    key = key.reshape(qkv_batch_size, key.shape[1], local_num_k_heads, self.head_k_dim)
    value = value.reshape(qkv_batch_size, value.shape[1], local_num_v_heads, self.head_v_dim)

    beta = b.sigmoid()
    if sp_enabled:
        head_offset = _get_sp_rank(sequence_parallel_context) * local_num_v_heads
        head_slice = slice(head_offset, head_offset + local_num_v_heads)
        g = -self.A_log[head_slice].float().exp() * F.softplus(a.float() + self.dt_bias[head_slice])
    else:
        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)

    if self.num_v_heads // self.num_k_heads > 1:
        repeat = self.num_v_heads // self.num_k_heads
        query = query.repeat_interleave(repeat, dim=2)
        key = key.repeat_interleave(repeat, dim=2)

    if use_precomputed_states:
        core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
            query,
            key,
            value,
            g=g,
            beta=beta,
            initial_state=recurrent_state,
            output_final_state=cache_params is not None,
            use_qk_l2norm_in_kernel=True,
        )
    else:
        core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
            query,
            key,
            value,
            g=g,
            beta=beta,
            initial_state=None,
            output_final_state=cache_params is not None,
            use_qk_l2norm_in_kernel=True,
            cu_seqlens=packed_cu_seqlens,
        )

    if cache_params is not None:
        cache_params.recurrent_states[self.layer_idx] = last_recurrent_state

    if use_varlen_pack:
        core_attn_out = _unpack_varlen_tensor(core_attn_out, packed_valid_mask, batch_size, packed_seq_len)
    core_attn_out = _head_to_seq_shard(core_attn_out, sequence_parallel_context)
    core_attn_out = self.norm(core_attn_out.reshape(-1, self.head_v_dim), z.reshape(-1, self.head_v_dim))
    core_attn_out = core_attn_out.reshape(batch_size, seq_len, self.value_dim)
    return self.out_proj(core_attn_out)
```
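The `_seq_to_head_shard` / `_head_to_seq_shard` helpers named above are not shown in this diff; their implementation is presumably an Ulysses-style all-to-all. As a rough single-process simulation of what such an exchange does (the function below is a hypothetical sketch using plain lists, not the PR's `torch.distributed` implementation), each rank trades its local sequence chunk of all heads for the full sequence of its local head slice:

```python
def seq_to_head_shard_sim(per_rank_chunks, world):
    """Simulated Ulysses exchange.
    Input: per rank, a list over its LOCAL seq positions, each a list over ALL heads.
    Output: per rank, a list over ALL seq positions, each a list over its LOCAL heads."""
    # Gather: concatenating the per-rank chunks recovers the full sequence.
    full_seq = [tok for chunk in per_rank_chunks for tok in chunk]
    n_heads = len(full_seq[0])
    assert n_heads % world == 0, 'head count must divide evenly across SP ranks'
    h_local = n_heads // world
    # Scatter: each rank keeps a contiguous slice of heads at every position.
    return [[tok[r * h_local:(r + 1) * h_local] for tok in full_seq] for r in range(world)]


# 2 ranks, full sequence of 4 positions, 4 heads; entry 's{i}h{j}' tracks (position, head).
seq = [[f's{i}h{j}' for j in range(4)] for i in range(4)]
chunks = [seq[:2], seq[2:]]               # sequence-sharded inputs, one chunk per rank
shards = seq_to_head_shard_sim(chunks, 2)
print(shards[0][3])                        # ['s3h0', 's3h1']: rank 0 now sees every position, heads 0-1
```

The inverse exchange (`_head_to_seq_shard`) would undo this mapping after the attention kernel, which is why the output above can be reshaped back to the full `[batch, seq_len, value_dim]` layout.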
```python
def get_flattened_cu_seqlens_from_position_ids(position_ids: torch.LongTensor):
    if position_ids.dim() == 1:
        position_ids = position_ids.unsqueeze(0)
    if position_ids.dim() != 2:
        raise ValueError(f'Expected 1D or 2D position_ids, got shape={tuple(position_ids.shape)}')

    device = position_ids.device
    cu_seqlens = [0]
    total = 0
    for row in position_ids:
        row = row.clone()
        row[row < 0] = 0
        seq_start_indices = torch.where(row == 0)[0]
        if seq_start_indices.numel() == 0 or seq_start_indices[0].item() != 0:
            seq_start_indices = torch.cat([torch.tensor([0], device=device, dtype=seq_start_indices.dtype), seq_start_indices])
        seq_end_indices = torch.cat([seq_start_indices[1:], torch.tensor([len(row)], device=device)])
        seq_lengths = (seq_end_indices - seq_start_indices).tolist()
        for seq_length in seq_lengths:
            total += int(seq_length)
            cu_seqlens.append(total)
    return torch.tensor(cu_seqlens, device=device, dtype=torch.long)
```
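As a quick sanity check of the boundary logic: a position resetting to 0 marks the start of a new packed sequence, so packed `position_ids` map to cumulative lengths as follows (a standalone recomputation with plain lists, mirroring the tensor logic above rather than calling the PR's helper):

```python
# Three packed sequences of lengths 3, 2 and 4, flattened into one row.
position_ids = [0, 1, 2, 0, 1, 0, 1, 2, 3]

# A position value of 0 marks the start of a new sequence.
starts = [i for i, p in enumerate(position_ids) if p == 0]   # [0, 3, 5]
ends = starts[1:] + [len(position_ids)]                      # [3, 5, 9]
cu_seqlens = [0]
for s, e in zip(starts, ends):
    cu_seqlens.append(cu_seqlens[-1] + (e - s))
print(cu_seqlens)  # [0, 3, 5, 9]
```

This is the cumulative-length (`cu_seqlens`) convention that varlen attention kernels expect for packed batches.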
…string
- Change double quotes to single quotes for consistency in the `attn_implementation` parameter
- Reformat multi-line imports to single lines for better readability
- Remove the unnecessary import error message in linear attention validation
- Maintain code style consistency across the codebase
```python
self.processor: Optional[InputProcessor] = None
self._set_work_init_fn()

@staticmethod
```
data_world_size should arguably take the ulysses dimension into account as well.

Putting the world_size check inside individual components is not a good fit; it would be better to consolidate it into device_mesh.
```python
self.batch_size = batch_size
required_world_size = self._required_data_world_size(device_mesh)
assert batch_size >= required_world_size and batch_size % required_world_size == 0
self.batch_size = self._resolve_runtime_batch_size(batch_size, device_mesh)
```
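A hedged sketch of what a helper like `_resolve_runtime_batch_size` might centralize — the name comes from the diff, but the body below is an assumption, not the PR's implementation: validate the batch against the required data-parallel world size once, raising a descriptive error instead of a bare assert:

```python
def resolve_runtime_batch_size(batch_size: int, required_world_size: int) -> int:
    """Hypothetical sketch: check that the global batch divides evenly across
    the data-parallel ranks (i.e. the mesh size after excluding ulysses/SP ranks)."""
    if batch_size < required_world_size or batch_size % required_world_size != 0:
        raise ValueError(
            f'batch_size ({batch_size}) must be a positive multiple of the '
            f'required data world size ({required_world_size})')
    return batch_size


print(resolve_runtime_batch_size(8, 4))  # 8
```

Centralizing this in one helper (or in device_mesh itself, as the review suggests) keeps the divisibility rule in a single place when SP changes the effective data-parallel size.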
```python
DEFAULT_WEIGHT_DECAY = 0.01


def _default_gradient_accumulation_steps_for_device_mesh(device_mesh: Optional[DeviceMesh]) -> int:
```
PR type

PR information

Implementation approach

This PR uses a layered design: a generic SP strategy plus a Qwen3.5-specific model implementation.

The generic SP part lives in the strategy layer

The generic Ulysses/SP capability lives in sequence_parallel.py and is responsible for:
- splitting the input along the sequence dimension according to ulysses_size
- performing the seq <-> head all-to-all communication inside attention
- maintaining runtime information such as the sequence parallel runtime context
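As a sketch of the sequence-dimension split (the function name and contiguous-chunk policy here are illustrative, not the actual sequence_parallel.py API), each SP rank keeps one contiguous chunk of the sequence:

```python
def shard_sequence(tokens, ulysses_size, rank):
    """Keep the rank-local contiguous chunk of the sequence dimension.
    Hypothetical sketch: real inputs would be [batch, seq, hidden] tensors."""
    assert len(tokens) % ulysses_size == 0, 'seq length must divide ulysses_size'
    chunk = len(tokens) // ulysses_size
    return tokens[rank * chunk:(rank + 1) * chunk]


tokens = list(range(8))
print(shard_sequence(tokens, 2, 1))  # [4, 5, 6, 7]
```

Concatenating the per-rank chunks in rank order recovers the full sequence, which is what the seq <-> head all-to-all inside attention relies on.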
The Qwen3.5-specific part lives in the model layer

Qwen3.5 uses linear attention, and GatedDeltaNet additionally contains a convolution layer; the implementation still follows the Ulysses idea.

Reference approach: Linear Attention Ulysses (GatedDeltaNet). Only modeling_qwen3_5.py is touched; the following classes are currently rewritten:
- TwinkleQwen3_5TextModel
- TwinkleQwen3_5DecoderLayer
- TwinkleQwen3_5GatedDeltaNet
- TwinkleQwen3_5ForCausalLM

At usage time, specify model_cls=TwinkleQwen3_5ForCausalLM:

```python
model = TransformersModel(
    model_id=MODEL_ID,
    model_cls=TwinkleQwen3_5ForCausalLM,
    device_mesh=device_mesh,
    strategy='native_fsdp',
)
```

In summary:
- softmax attention continues to reuse the generic SP path
- linear attention is handled explicitly inside the model, making the Qwen3.5 language-model backbone itself SP-aware
Experiment results

With SP enabled (ulysses=2):

```
[2026-03-23 15:31:43][INFO:twinkle] Current is optimizer step 60 of 63 (micro step 121 of 125), metric: {'loss': '1.1285', 'grad_norm': '0.139006', 'accuracy': '0.68', 'correct_tokens': 21312, 'total_tokens': 31280, 'learning rate(param group 1)': '6.586739e-07', 'learning rate(param group 2)': '6.586739e-07', 'iters': 60, 'total time elapse': '187 seconds', 'speed': '0.36 iters/s', 'rank': 0, 'local_rank': 0, 'device': 'cuda:0', 'mem_allocated': '2200.0 MiB', 'mem_reserved': '55992.0 MiB', 'mem_peak_allocated': '27645.0 MiB', 'mem_peak_reserved': '55992.0 MiB'}
```

With SP disabled:

```
[2026-03-23 13:47:27][INFO:twinkle] Current is step 60 of 63, metric: {'loss': '1.1290', 'grad_norm': '0.123370', 'accuracy': '0.68', 'correct_tokens': 21283, 'total_tokens': 31280, 'learning rate(param group 1)': '6.586739e-07', 'learning rate(param group 2)': '6.586739e-07', 'iters': 60, 'total time elapse': '129 seconds', 'speed': '0.53 iters/s', 'rank': 0, 'local_rank': 0, 'device': 'cuda:0', 'mem_allocated': '2183.5 MiB', 'mem_reserved': '62422.0 MiB', 'mem_peak_allocated': '36788.4 MiB', 'mem_peak_reserved': '77750.0 MiB'}
```

With matching loss (1.1285 vs 1.1290) and accuracy (0.68 in both runs), SP reduces peak allocated memory from ~36.8 GiB to ~27.6 GiB and peak reserved memory from ~77.8 GiB to ~56.0 GiB, at the cost of throughput (0.36 vs 0.53 iters/s).