Commit 2e0f6c3

Update infos for KDA
1 parent 05b567c commit 2e0f6c3

4 files changed: +92 -105 lines changed


README.md

Lines changed: 4 additions & 3 deletions
@@ -25,10 +25,11 @@ This repo aims at providing a collection of efficient Triton-based implementatio
 * [Benchmarks](#benchmarks)
 * [Citation](#citation)
 * [Star History](#star-history)
-* [Acknowledgments](#acknowledgments)
+* [Acknowledgements](#acknowledgements)
 
 ## News
 
+- **$\texttt{[2025-10]}$:** 🌑 Add Kimi Delta Attention implementation to `fla`.
 - **$\texttt{[2025-09]}$:** 🌲 Add DeltaFormer implementation to `fla` ([paper](https://arxiv.org/abs/2505.19488v1)).
 - **$\texttt{[2025-09]}$:** 🐻 Thrilled to announce that [GDN](fla/ops/gated_delta_rule) has been integrated into Qwen3-Next. Check out their [blog post](https://qwen.ai/blog?id=4074cca80393150c248e508aa62983f9cb7d27cd&from=research.latest-advancements-list) for more infos!
 - **$\texttt{[2025-08]}$:** 🌲 Add Log-Linear Attention implementation to `fla` ([paper](https://arxiv.org/abs/2506.04761)).
@@ -561,6 +562,6 @@ If you find this repository helpful, please cite our work:
 
 [![Star History Chart](https://api.star-history.com/svg?repos=fla-org/flash-linear-attention&type=Date)](https://star-history.com/#fla-org/flash-linear-attention&Date)
 
-## Acknowledgments
+## Acknowledgements
 
-We extend our gratitude to [Bitdeer](https://www.bitdeer.com/) for providing CI server resources that power our infrastructure.
+We extend our gratitude to [Bitdeer](https://www.bitdeer.com/) and [Moonshot AI](https://www.moonshot.ai/) for their support in maintaining and powering our project infrastructure.

fla/layers/kda.py

Lines changed: 28 additions & 40 deletions
@@ -1,10 +1,9 @@
-# -*- coding: utf-8 -*-
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
 
 from __future__ import annotations
 
 import math
-from typing import TYPE_CHECKING, Dict, Optional, Tuple
+from typing import TYPE_CHECKING
 
 import torch
 import torch.nn as nn
@@ -26,19 +25,6 @@ class KimiDeltaAttention(nn.Module):
     """
     Kimi Delta Attention (KDA) layer implementation.
 
-    Each layer contains approximately 6*hidden_size*hidden_size parameters.
-
-    Parameter allocation:
-    - q_proj: hidden_size * key_dim (where key_dim = num_heads * head_dim)
-    - k_proj: hidden_size * key_dim
-    - v_proj: hidden_size * value_dim (where value_dim = num_v_heads * head_dim * expand_v)
-    - o_proj: value_dim * hidden_size
-    - f_proj: hidden_size * head_v_dim + head_v_dim * key_dim (with bias)
-    - b_proj: hidden_size * num_heads
-    - g_proj: hidden_size * head_v_dim + head_v_dim * value_dim (with bias)
-    - A: num_heads parameters
-    - Plus convolution layers when use_short_conv=True
-
     Args:
         hidden_size (int, Optional):
             The hidden size of the input. Default: 2048.
@@ -52,7 +38,7 @@ class KimiDeltaAttention(nn.Module):
             The number of heads for the value projection, equal to `num_heads` if `None`.
             GVA (Grouped Value Attention) is applied if `num_v_heads` > `num_heads`. Default: `None`.
         mode (str, Optional):
-            Which Gated DeltaNet kernel to use.
+            Which Kimi Delta Attention kernel to use.
             Currently available: `chunk` and `fused_recurrent`.
             Default: `chunk`.
         use_short_conv (bool, Optional):
@@ -85,7 +71,7 @@ def __init__(
         conv_bias: bool = False,
         layer_idx: int = None,
         norm_eps: float = 1e-5,
-        **kwargs
+        **kwargs,
     ) -> KimiDeltaAttention:
         super().__init__()
 
@@ -112,17 +98,17 @@ def __init__(
         if not math.isclose(self.num_v_heads * self.head_dim * expand_v, self.value_dim, rel_tol=1e-5):
             raise ValueError(
                 f"expand_v={expand_v} does not produce an integer value when multiplied by key_dim={self.key_dim}. "
-                f"Resulting value_dim would be {self.num_v_heads * self.head_dim * expand_v}, which is invalid for nn.Linear."
+                f"Resulting value_dim would be {self.num_v_heads * self.head_dim * expand_v}, which is invalid for nn.Linear.",
             )
         if self.num_v_heads > self.num_heads and self.num_v_heads % self.num_heads != 0:
             raise ValueError(
-                f"num_v_heads={self.num_v_heads} must be divisible by num_heads={self.num_heads}."
+                f"num_v_heads={self.num_v_heads} must be divisible by num_heads={self.num_heads}.",
             )
 
         if not math.isclose(head_dim * expand_v, self.head_v_dim, rel_tol=1e-5):
             raise ValueError(
                 f"expand_v={expand_v} does not produce an integer value when multiplied by head_dim={head_dim}. "
-                f"Resulting head_v_dim would be {head_dim * expand_v}, which is invalid for FusedRMSNormGated."
+                f"Resulting head_v_dim would be {head_dim * expand_v}, which is invalid for FusedRMSNormGated.",
             )
         assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
 
@@ -135,44 +121,46 @@ def __init__(
                 hidden_size=self.key_dim,
                 kernel_size=conv_size,
                 bias=conv_bias,
-                activation='silu'
+                activation='silu',
             )
             self.k_conv1d = ShortConvolution(
                 hidden_size=self.key_dim,
                 kernel_size=conv_size,
                 bias=conv_bias,
-                activation='silu'
+                activation='silu',
             )
             self.v_conv1d = ShortConvolution(
                 hidden_size=self.value_dim,
                 kernel_size=conv_size,
                 bias=conv_bias,
-                activation='silu'
+                activation='silu',
             )
 
-        self.A = nn.Parameter(torch.log(torch.empty(self.num_heads, dtype=torch.float32).uniform_(1, 16)))
         self.f_proj = nn.Sequential(
             nn.Linear(hidden_size, self.head_v_dim, bias=False),
-            nn.Linear(self.head_v_dim, self.key_dim, bias=True)
+            nn.Linear(self.head_v_dim, self.key_dim, bias=False),
         )
         self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
 
+        self.A_log = nn.Parameter(torch.log(torch.empty(self.num_heads, dtype=torch.float32).uniform_(1, 16)))
+        self.dt_bias = nn.Parameter(torch.zeros(self.key_dim, dtype=torch.float32))
+
         self.g_proj = nn.Sequential(
             nn.Linear(hidden_size, self.head_v_dim, bias=False),
-            nn.Linear(self.head_v_dim, self.value_dim, bias=True)
+            nn.Linear(self.head_v_dim, self.value_dim, bias=True),
         )
         self.o_norm = FusedRMSNormGated(self.head_v_dim, activation='sigmoid', eps=norm_eps)
         self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
 
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        past_key_values: Optional[Cache] = None,
-        use_cache: Optional[bool] = False,
-        output_attentions: Optional[bool] = False,
-        **kwargs: Unpack[Dict]
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+        attention_mask: torch.Tensor | None = None,
+        past_key_values: Cache | None = None,
+        use_cache: bool | None = False,
+        output_attentions: bool | None = False,
+        **kwargs: Unpack[dict],
+    ) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]:
         if attention_mask is not None:
             assert len(attention_mask.shape) == 2, (
                 "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
@@ -190,7 +178,7 @@ def forward(
         if past_key_values is not None and len(past_key_values) > self.layer_idx:
             last_state = past_key_values[self.layer_idx]
 
-        cu_seqlens = kwargs.get('cu_seqlens', None)
+        cu_seqlens = kwargs.get('cu_seqlens')
         if attention_mask is not None:
             indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])
             hidden_states = index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0)
@@ -203,34 +191,34 @@ def forward(
                 x=self.q_proj(hidden_states),
                 cache=conv_state_q,
                 output_final_state=use_cache,
-                cu_seqlens=cu_seqlens
+                cu_seqlens=cu_seqlens,
             )
             k, conv_state_k = self.k_conv1d(
                 x=self.k_proj(hidden_states),
                 cache=conv_state_k,
                 output_final_state=use_cache,
-                cu_seqlens=cu_seqlens
+                cu_seqlens=cu_seqlens,
             )
             v, conv_state_v = self.v_conv1d(
                 x=self.v_proj(hidden_states),
                 cache=conv_state_v,
                 output_final_state=use_cache,
-                cu_seqlens=cu_seqlens
+                cu_seqlens=cu_seqlens,
             )
         else:
             q = F.silu(self.q_proj(hidden_states))
             k = F.silu(self.k_proj(hidden_states))
             v = F.silu(self.v_proj(hidden_states))
 
         g = self.f_proj(hidden_states)
-        g = fused_kda_gate(g, self.A, self.head_k_dim)
+        g = fused_kda_gate(g, self.A_log, self.head_k_dim, g_bias=self.dt_bias)
         beta = self.b_proj(hidden_states).sigmoid()
 
-        q, k = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_k_dim), (q, k))
+        q, k = (rearrange(x, '... (h d) -> ... h d', d=self.head_k_dim) for x in (q, k))
         v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim)
 
         if self.num_v_heads > self.num_heads:
-            q, k = map(lambda x: repeat(x, '... h d -> ... (h g) d', g=self.num_v_heads // self.num_heads), (q, k))
+            q, k = (repeat(x, '... h d -> ... (h g) d', g=self.num_v_heads // self.num_heads) for x in (q, k))
 
         if self.allow_neg_eigval:
             beta = beta * 2.
@@ -268,7 +256,7 @@ def forward(
                 recurrent_state=recurrent_state,
                 conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
                 layer_idx=self.layer_idx,
-                offset=q_len
+                offset=q_len,
             )
 
         o = self.o_norm(o, rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim))

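For context, here is a minimal usage sketch of the layer touched by this commit, reflecting the renamed `A_log` decay parameter, the new `dt_bias` parameter, and the updated `fused_kda_gate(g, self.A_log, self.head_k_dim, g_bias=self.dt_bias)` call shown in the diff above. The constructor arguments, tensor sizes, and device setup below are illustrative assumptions (only `hidden_size=2048` and `mode='chunk'` are documented defaults in the docstring), and running it requires `fla` installed on a CUDA device with Triton available for the fused kernels.

```python
# Illustrative smoke test (not part of this commit); sizes and dtype are assumptions.
import torch
from fla.layers.kda import KimiDeltaAttention

# num_heads=16 is an arbitrary choice; hidden_size=2048 and mode='chunk' match the docstring defaults.
layer = KimiDeltaAttention(hidden_size=2048, num_heads=16, mode='chunk').cuda().bfloat16()

# forward() expects [batch_size, seq_len, hidden_size] and returns (output, attentions, past_key_values).
x = torch.randn(2, 128, 2048, device='cuda', dtype=torch.bfloat16)
o, _, _ = layer(x)

print(o.shape)              # torch.Size([2, 128, 2048]); o_proj maps value_dim back to hidden_size
print(layer.A_log.shape)    # per-head log-decay parameter, renamed from `A` in this commit
print(layer.dt_bias.shape)  # new bias parameter, now passed to fused_kda_gate as g_bias
```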