Skip to content

Commit f4082b3

Browse files
committed
Add input size assertions; fix KDA docstring
1 parent feb153a commit f4082b3

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

fla/ops/gla/chunk.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1316,5 +1316,9 @@ def chunk_gla(
13161316
)
13171317
if scale is None:
13181318
scale = q.shape[-1] ** -0.5
1319+
if initial_state is not None:
1320+
assert initial_state.dtype == torch.float32, "initial_state must be in float32."
1321+
assert q.shape == k.shape == g.shape, "q, k, g must have the same shape."
1322+
assert v.shape == (q.shape[0], q.shape[1], q.shape[2], v.shape[-1]), "v must be of shape (batch size, seq len, num of head, head dim)."
13191323
o, final_state = ChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state, cu_seqlens)
13201324
return o, final_state

fla/ops/kda/chunk.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def chunk_kda(
271271
beta (torch.Tensor):
272272
betas of shape `[B, T, H]`.
273273
scale (Optional[float]):
274-
Scale factor for the RetNet attention scores.
274+
Scale factor for the KDA attention scores.
275275
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
276276
initial_state (Optional[torch.Tensor]):
277277
Initial state of shape `[N, H, K, V]` for `N` input sequences.
@@ -302,7 +302,7 @@ def chunk_kda(
302302
>>> k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
303303
>>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
304304
>>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
305-
>>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
305+
>>> g = F.logsigmoid(torch.rand(B, T, H, K, dtype=torch.bfloat16, device='cuda'))
306306
>>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
307307
>>> o, ht = chunk_kda(
308308
q, k, v, g, beta,
@@ -334,6 +334,11 @@ def chunk_kda(
334334
f"The number of initial states is expected to be equal to the number of input sequences, "
335335
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.",
336336
)
337+
if initial_state is not None:
338+
assert initial_state.dtype == torch.float32, "initial_state must be in float32."
339+
assert q.shape == k.shape == g.shape, "q, k, g must have the same shape."
340+
assert beta.shape == (q.shape[0], q.shape[1], q.shape[2]), "beta must be of shape (batch size, seq len, num of head)."
341+
assert v.shape == (q.shape[0], q.shape[1], q.shape[2], v.shape[-1]), "v must be of shape (batch size, seq len, num of head, head dim)."
337342
if scale is None:
338343
scale = k.shape[-1] ** -0.5
339344
o, final_state = ChunkKDAFunction.apply(

0 commit comments

Comments
 (0)