@@ -22,6 +22,9 @@ class TritonAttentionMetadata(AttentionMetadata):
     fill_seqlens: torch.Tensor = None
     quant_policy: Literal[0, 4, 8] = 0
     kv_flatten_size: int = None
+    # flash mla
+    tile_scheduler_metadata: torch.Tensor = None
+    num_splits: torch.Tensor = None


 def _cdiv(a, b):
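The two new fields carry FlashMLA's tile-scheduler outputs; this hunk does not show where they get filled in. A rough sketch of how they are typically produced with the upstream flash_mla package (the helper and its arguments are assumptions taken from FlashMLA's public API, not part of this change):

    import torch
    from flash_mla import get_mla_metadata  # assumed upstream helper, not shown in this diff

    # int32 kv lengths, one entry per decoding sequence
    cache_seqlens = torch.tensor([17, 33], dtype=torch.int32, device='cuda')
    num_q_heads, num_kv_heads = 128, 1  # MLA requires a single kv head
    tile_scheduler_metadata, num_splits = get_mla_metadata(
        cache_seqlens,
        num_q_heads // num_kv_heads,  # query heads per kv head (times q_seqlen, which is 1 while decoding)
        num_kv_heads)
    # the attention metadata would then carry these two tensors into forward()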
@@ -196,6 +199,144 @@ def forward(
         return attn_output


+class FlashMLAImpl(TritonAttentionImpl):
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float = None,
+        num_kv_heads: int = None,
+        v_head_size: int = None,
+        alibi: bool = False,
+        sliding_window: int = None,
+        logit_softcapping: float = None,
+        causal: bool = True,
+        **kwargs,
+    ):
+        assert sliding_window is None, 'sliding window not supported for FlashMLA'
+        assert alibi is False, 'alibi not supported for FlashMLA'
+        assert logit_softcapping is None, 'logit_softcapping not supported for FlashMLA'
+        super().__init__(
+            num_heads=num_heads,
+            head_size=head_size,
+            scale=scale,
+            num_kv_heads=num_kv_heads,
+            v_head_size=v_head_size,
+            alibi=alibi,
+            sliding_window=sliding_window,
+            logit_softcapping=logit_softcapping,
+            causal=causal,
+            **kwargs,
+        )
+
+        from lmdeploy.pytorch.kernels.cuda import flash_mla_fwd
+        self.flash_mla_fwd = flash_mla_fwd
+        assert num_kv_heads == 1, 'MLA requires num kv heads equal to 1'
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        attn_metadata: TritonAttentionMetadata,
+        k_scales_zeros: torch.Tensor = None,
+        v_scales_zeros: torch.Tensor = None,
+        inplace: bool = True,
+    ) -> torch.Tensor:
+        """forward."""
+
+        block_offsets = attn_metadata.block_offsets
+        q_start_loc = attn_metadata.q_start_loc
+        fill_q_start_loc = q_start_loc
+        q_seqlens = attn_metadata.q_seqlens
+        fill_seqlens = q_seqlens
+        kv_start_loc = attn_metadata.kv_start_loc
+        kv_seqlens = attn_metadata.kv_seqlens
+        kv_flatten_size = attn_metadata.kv_flatten_size
+        quant_policy = attn_metadata.quant_policy
+        if attn_metadata.is_decoding:
+            max_q_seqlen = 1
+        else:
+            max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2))
+        fill_max_q_seqlen = max_q_seqlen
+        if attn_metadata.fill_seqlens is not None:
+            fill_seqlens = attn_metadata.fill_seqlens
+            fill_max_q_seqlen = key.numel() // (key.size(-1) * key.size(-2))
+            fill_q_start_loc = fill_seqlens.cumsum(0) - fill_seqlens
+
+        # fill kv cache
+        if key is not None and value is not None:
+            self.fill_kv_cache(
+                key,
+                value,
+                k_cache,
+                v_cache,
+                fill_q_start_loc,
+                fill_seqlens,
+                kv_seq_length=kv_seqlens,
+                max_q_seq_length=fill_max_q_seqlen,
+                block_offsets=block_offsets,
+                k_scales_zeros=k_scales_zeros,
+                v_scales_zeros=v_scales_zeros,
+                quant_policy=quant_policy,
+            )
+
+        q_shape = query.shape
+        o_shape = q_shape[:-1] + (self.v_head_size, )
+        attn_output = query.new_empty(o_shape)
+
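+        # decoding goes through flash_mla_fwd on the paged kv cache;
+        # prefill flattens the cache and reuses the Triton flash attention kernel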
+        is_decoding = attn_metadata.is_decoding
+        if is_decoding:
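+            # add a q_seqlen dim of 1 and cast kv seqlens to int32 before calling the kernel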
+            query = query.unsqueeze(1)
+            if kv_seqlens.dtype == torch.int64:
+                kv_seqlens = kv_seqlens.to(torch.int32)
+            attn_output = self.flash_mla_fwd(query,
+                                             k_cache=k_cache,
+                                             block_table=block_offsets,
+                                             cache_seqlens=kv_seqlens,
+                                             head_dim_v=self.v_head_size,
+                                             softmax_scale=self.scale,
+                                             tile_scheduler_metadata=attn_metadata.tile_scheduler_metadata,
+                                             num_splits=attn_metadata.num_splits,
+                                             causal=True)
+
+        else:
+            BLOCK_BS = k_cache.size(1)
+            # pad one more block to avoid invalid kv visit
+            out_size = (_cdiv(kv_flatten_size, BLOCK_BS) * BLOCK_BS + BLOCK_BS)
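+            # gather the paged cache into contiguous k/v (dequantized when quant_policy != 0)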
+            flatten_k, flatten_v = self.flatten_kv_cache(
+                k_cache,
+                v_cache,
+                kv_seqlens,
+                block_offsets,
+                start_loc=kv_start_loc,
+                out_size=out_size,
+                out_dtype=query.dtype,
+                k_scales_zeros=k_scales_zeros,
+                v_scales_zeros=v_scales_zeros,
+                quant_policy=quant_policy,
+            )
+            self.flash_attention_fwd(
+                query,
+                flatten_k,
+                flatten_v,
+                attn_output,
+                q_start_loc=q_start_loc,
+                q_seqlens=q_seqlens,
+                kv_start_loc=kv_start_loc,
+                kv_seqlens=kv_seqlens,
+                max_seqlen=max_q_seqlen,
+                window_size=self.sliding_window,
+                sm_scale=self.scale,
+                logit_softcapping=self.logit_softcapping,
+                causal=self.causal,
+            )
+        return attn_output
+
+
 class TritonAttentionBuilder(AttentionBuilder[TritonAttentionMetadata]):
     """triton attention builder."""

@@ -210,9 +351,21 @@ def build(
         sliding_window: int = None,
         logical_softcapping: float = None,
         causal: bool = True,
+        use_flash_mla: bool = False,
         **kwargs,
     ) -> TritonAttentionImpl:
         """build."""
+        if use_flash_mla is True:
+            return FlashMLAImpl(num_heads,
+                                head_size,
+                                scale=scale,
+                                num_kv_heads=num_kv_heads,
+                                v_head_size=v_head_size,
+                                alibi=alibi,
+                                sliding_window=sliding_window,
+                                logical_softcapping=logical_softcapping,
+                                causal=causal,
+                                **kwargs)
         return TritonAttentionImpl(num_heads,
                                    head_size,
                                    scale=scale,
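The hunk above adds a use_flash_mla switch to the builder. A minimal usage sketch, assuming this file is lmdeploy/pytorch/backends/cuda/attention.py, that build() can be called as a static method, and that the FlashMLA kernels are installed; the head sizes below are illustrative only:

    from lmdeploy.pytorch.backends.cuda.attention import TritonAttentionBuilder

    # illustrative DeepSeek-style MLA shapes; FlashMLAImpl asserts num_kv_heads == 1
    attn_impl = TritonAttentionBuilder.build(
        num_heads=128,
        head_size=576,       # head dim of the cached compressed kv + rope part
        num_kv_heads=1,
        v_head_size=512,
        use_flash_mla=True)  # returns FlashMLAImpl instead of TritonAttentionImpl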