@@ -271,7 +271,7 @@ def chunk_kda(
         beta (torch.Tensor):
             betas of shape `[B, T, H]`.
         scale (Optional[float]):
-            Scale factor for the RetNet attention scores.
+            Scale factor for the KDA attention scores.
             If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
         initial_state (Optional[torch.Tensor]):
             Initial state of shape `[N, H, K, V]` for `N` input sequences.
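Not part of the diff: a minimal sketch of the documented default, mirroring the `scale = k.shape[-1] ** -0.5` fallback later in the function; sizes are illustrative.

import torch

k = torch.randn(2, 64, 4, 128)   # [B, T, H, K], illustrative sizes
scale = k.shape[-1] ** -0.5      # the 1 / sqrt(K) default used when scale is None
assert abs(scale - 1 / (128 ** 0.5)) < 1e-12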
@@ -302,7 +302,7 @@ def chunk_kda(
         >>> k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
         >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
         >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
-        >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
+        >>> g = F.logsigmoid(torch.rand(B, T, H, K, dtype=torch.bfloat16, device='cuda'))
         >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
         >>> o, ht = chunk_kda(
         ...     q, k, v, g, beta,
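For reference outside the doctest: the corrected example gives `g` the shape `[B, T, H, K]`, matching `q` and `k` (the assert added below enforces exactly this), instead of the per-head `[B, T, H]` the old line used. A standalone sketch with illustrative sizes:

import torch
import torch.nn.functional as F

B, T, H, K = 2, 64, 4, 128
g = F.logsigmoid(torch.rand(B, T, H, K))  # log-space gate, one value per key dimension
assert g.shape == (B, T, H, K)
assert (g <= 0).all()                     # logsigmoid outputs are non-positive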
@@ -334,6 +334,11 @@ def chunk_kda(
             f"The number of initial states is expected to be equal to the number of input sequences, "
             f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.",
         )
+    if initial_state is not None:
+        assert initial_state.dtype == torch.float32, "initial_state must be in float32."
+    assert q.shape == k.shape == g.shape, "q, k and g must have the same shape."
+    assert beta.shape == (q.shape[0], q.shape[1], q.shape[2]), "beta must be of shape (batch_size, seq_len, num_heads)."
+    assert v.shape == (q.shape[0], q.shape[1], q.shape[2], v.shape[-1]), "v must be of shape (batch_size, seq_len, num_heads, head_dim)."
     if scale is None:
         scale = k.shape[-1] ** -0.5
     o, final_state = ChunkKDAFunction.apply(
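Not part of the diff: a hedged sketch of what the new checks buy in practice, using only the tensor shapes from the example above. A `g` built with the old per-head shape now fails the `q.shape == k.shape == g.shape` assert at the API boundary, rather than erroring deep inside the kernel:

import torch
import torch.nn.functional as F

B, T, H, K = 2, 64, 4, 128                 # illustrative sizes
q = torch.randn(B, T, H, K)
k = torch.randn(B, T, H, K)
g_bad = F.logsigmoid(torch.rand(B, T, H))  # old [B, T, H] shape, missing the K dim

try:
    assert q.shape == k.shape == g_bad.shape, "q, k and g must have the same shape."
except AssertionError as e:
    print(e)                               # caught before any kernel launch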