Commit bce6863

Add flash mla (InternLM#3218)

* Add flash mla
* fix
* Add comment
* get_mla_metadata outside graph
* remove [:]
* refine
* refine
* remove useless and update attn meta inside cuda backend
* SM90 check
1 parent: 20770be

File tree: 15 files changed, +286 −14 lines

lmdeploy/pytorch/backends/attention.py (+3)
@@ -35,6 +35,7 @@ def __init__(
         sliding_window: int = None,
         logit_softcapping: float = None,
         causal: bool = True,
+        use_flash_mla: bool = False,
         **kwargs,
     ) -> None:
         if scale is None:
@@ -55,6 +56,7 @@ def __init__(
         self.sliding_window = sliding_window
         self.logit_softcapping = logit_softcapping
         self.causal = causal
+        self.use_flash_mla = use_flash_mla

     @abstractmethod
     def forward(
@@ -85,6 +87,7 @@ def build(
         sliding_window: int = None,
         logical_softcapping: float = None,
         causal: bool = True,
+        use_flash_mla: bool = False,
         **kwargs,
     ) -> AttentionImpl[T]:
         """build."""

lmdeploy/pytorch/backends/cuda/attention.py (+153)
@@ -22,6 +22,9 @@ class TritonAttentionMetadata(AttentionMetadata):
     fill_seqlens: torch.Tensor = None
     quant_policy: Literal[0, 4, 8] = 0
     kv_flatten_size: int = None
+    # flash mla
+    tile_scheduler_metadata: torch.Tensor = None
+    num_splits: torch.Tensor = None


 def _cdiv(a, b):
@@ -196,6 +199,144 @@ def forward(
         return attn_output


+class FlashMLAImpl(TritonAttentionImpl):
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float = None,
+        num_kv_heads: int = None,
+        v_head_size: int = None,
+        alibi: bool = False,
+        sliding_window: int = None,
+        logit_softcapping: float = None,
+        causal: bool = True,
+        **kwargs,
+    ):
+        assert sliding_window is None, 'sliding window not supported for FlashMLA'
+        assert alibi is False, 'alibi not supported for FlashMLA'
+        assert logit_softcapping is None, 'logit_softcapping not supported for FlashMLA'
+        super().__init__(
+            num_heads=num_heads,
+            head_size=head_size,
+            scale=scale,
+            num_kv_heads=num_kv_heads,
+            v_head_size=v_head_size,
+            alibi=alibi,
+            sliding_window=sliding_window,
+            logit_softcapping=logit_softcapping,
+            causal=causal,
+            **kwargs,
+        )
+
+        from lmdeploy.pytorch.kernels.cuda import flash_mla_fwd
+        self.flash_mla_fwd = flash_mla_fwd
+        assert num_kv_heads == 1, 'MLA requires num kv heads equal to 1'
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        attn_metadata: TritonAttentionMetadata,
+        k_scales_zeros: torch.Tensor = None,
+        v_scales_zeros: torch.Tensor = None,
+        inplace: bool = True,
+    ) -> torch.Tensor:
+        """forward."""
+
+        block_offsets = attn_metadata.block_offsets
+        q_start_loc = attn_metadata.q_start_loc
+        fill_q_start_loc = q_start_loc
+        q_seqlens = attn_metadata.q_seqlens
+        fill_seqlens = q_seqlens
+        kv_start_loc = attn_metadata.kv_start_loc
+        kv_seqlens = attn_metadata.kv_seqlens
+        kv_flatten_size = attn_metadata.kv_flatten_size
+        quant_policy = attn_metadata.quant_policy
+        if attn_metadata.is_decoding:
+            max_q_seqlen = 1
+        else:
+            max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2))
+        fill_max_q_seqlen = max_q_seqlen
+        if attn_metadata.fill_seqlens is not None:
+            fill_seqlens = attn_metadata.fill_seqlens
+            fill_max_q_seqlen = key.numel() // (key.size(-1) * key.size(-2))
+            fill_q_start_loc = fill_seqlens.cumsum(0) - fill_seqlens

+        # fill kv cache
+        if key is not None and value is not None:
+            self.fill_kv_cache(
+                key,
+                value,
+                k_cache,
+                v_cache,
+                fill_q_start_loc,
+                fill_seqlens,
+                kv_seq_length=kv_seqlens,
+                max_q_seq_length=fill_max_q_seqlen,
+                block_offsets=block_offsets,
+                k_scales_zeros=k_scales_zeros,
+                v_scales_zeros=v_scales_zeros,
+                quant_policy=quant_policy,
+            )
+
+        q_shape = query.shape
+        o_shape = q_shape[:-1] + (self.v_head_size, )
+        attn_output = query.new_empty(o_shape)
+
+        is_decoding = attn_metadata.is_decoding
+        if is_decoding:
+            query = query.unsqueeze(1)
+            if kv_seqlens.dtype == torch.int64:
+                kv_seqlens = kv_seqlens.to(torch.int32)
+            attn_output = self.flash_mla_fwd(query,
+                                             k_cache=k_cache,
+                                             block_table=block_offsets,
+                                             cache_seqlens=kv_seqlens,
+                                             head_dim_v=self.v_head_size,
+                                             softmax_scale=self.scale,
+                                             tile_scheduler_metadata=attn_metadata.tile_scheduler_metadata,
+                                             num_splits=attn_metadata.num_splits,
+                                             causal=True)
+
+        else:
+            BLOCK_BS = k_cache.size(1)
+            # pad one more block to avoid invalid kv visit
+            out_size = (_cdiv(kv_flatten_size, BLOCK_BS) * BLOCK_BS + BLOCK_BS)
+            flatten_k, flatten_v = self.flatten_kv_cache(
+                k_cache,
+                v_cache,
+                kv_seqlens,
+                block_offsets,
+                start_loc=kv_start_loc,
+                out_size=out_size,
+                out_dtype=query.dtype,
+                k_scales_zeros=k_scales_zeros,
+                v_scales_zeros=v_scales_zeros,
+                quant_policy=quant_policy,
+            )
+            self.flash_attention_fwd(
+                query,
+                flatten_k,
+                flatten_v,
+                attn_output,
+                q_start_loc=q_start_loc,
+                q_seqlens=q_seqlens,
+                kv_start_loc=kv_start_loc,
+                kv_seqlens=kv_seqlens,
+                max_seqlen=max_q_seqlen,
+                window_size=self.sliding_window,
+                sm_scale=self.scale,
+                logit_softcapping=self.logit_softcapping,
+                causal=self.causal,
+            )
+        return attn_output
+
+
 class TritonAttentionBuilder(AttentionBuilder[TritonAttentionMetadata]):
     """triton attention builder."""

@@ -210,9 +351,21 @@ def build(
         sliding_window: int = None,
         logical_softcapping: float = None,
         causal: bool = True,
+        use_flash_mla: bool = False,
         **kwargs,
     ) -> TritonAttentionImpl:
         """build."""
+        if use_flash_mla is True:
+            return FlashMLAImpl(num_heads,
+                                head_size,
+                                scale=scale,
+                                num_kv_heads=num_kv_heads,
+                                v_head_size=v_head_size,
+                                alibi=alibi,
+                                sliding_window=sliding_window,
+                                logical_softcapping=logical_softcapping,
+                                causal=causal,
+                                **kwargs)
         return TritonAttentionImpl(num_heads,
                                    head_size,
                                    scale=scale,
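
For orientation, here is a standalone sketch of the decode path FlashMLAImpl.forward takes above: it asks flash_mla_cuda for scheduler metadata and then calls the flash_mla_fwd wrapper on the paged latent-KV cache. This is not part of the commit; it assumes FlashMLA is installed on an SM90+ GPU, and all sizes below (head count, 576 = 512 latent + 64 rope dims, block size 64) are placeholders chosen to resemble a DeepSeek-V2-style MLA layout.

# Hedged sketch, not part of this commit: exercising the decode path in isolation.
import torch
import flash_mla_cuda
from lmdeploy.pytorch.kernels.cuda import flash_mla_fwd

batch, num_heads = 2, 128                       # placeholder head count
head_dim, head_dim_v = 576, 512                 # placeholder MLA dims: 512 latent + 64 rope
block_size, num_blocks, blocks_per_seq = 64, 16, 8

# one query token per sequence, matching the is_decoding branch (query.unsqueeze(1))
query = torch.randn(batch, 1, num_heads, head_dim, dtype=torch.bfloat16, device='cuda')
k_cache = torch.randn(num_blocks, block_size, 1, head_dim, dtype=torch.bfloat16, device='cuda')
block_table = torch.arange(batch * blocks_per_seq, dtype=torch.int32,
                           device='cuda').view(batch, blocks_per_seq)
cache_seqlens = torch.tensor([70, 200], dtype=torch.int32, device='cuda')  # int32, as required above

# normally produced once per decode step in op_backend.update_step_context (next file)
tile_scheduler_metadata, num_splits = flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads, 1)

out = flash_mla_fwd(query,
                    k_cache=k_cache,
                    block_table=block_table,
                    cache_seqlens=cache_seqlens,
                    head_dim_v=head_dim_v,
                    softmax_scale=head_dim**-0.5,  # placeholder scale
                    tile_scheduler_metadata=tile_scheduler_metadata,
                    num_splits=num_splits,
                    causal=True)
print(out.shape)  # last dimension is head_dim_v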

lmdeploy/pytorch/backends/cuda/op_backend.py (+7)
@@ -125,6 +125,13 @@ def update_step_context(cls, step_context):
             kv_flatten_size=kv_flatten_size,
             quant_policy=step_context.kv_quant_policy,
         )
+        if getattr(step_context.model_config, 'use_flash_mla', False) is True:
+            if step_context.is_decoding is True:
+                import flash_mla_cuda
+                tile_scheduler_metadata, num_splits = flash_mla_cuda.get_mla_metadata(
+                    attn_metadata.kv_seqlens.to(torch.int32), step_context.model_config.num_attention_heads, 1)
+                attn_metadata.tile_scheduler_metadata = tile_scheduler_metadata
+                attn_metadata.num_splits = num_splits

         cross_seqlens = step_context.cross_seqlens
         cross_kv_seqlens = step_context.cross_kv_seqlens

lmdeploy/pytorch/config.py (+1)
@@ -108,6 +108,7 @@ class ModelConfig:
     hf_config: Any = None
     cogvlm_style: bool = False
     custom_module_map: Dict[str, setattr] = None
+    use_flash_mla: bool = False

     def get_head_size(self):
         """get head size."""

lmdeploy/pytorch/configurations/deepseek_v2.py (+4 −1)
@@ -2,6 +2,7 @@
 from lmdeploy.pytorch.config import ModelConfig

 from .builder import AutoModelConfigBuilder
+from .utils import flash_mla_available


 class DeepseekV2ModelConfigBuilder(AutoModelConfigBuilder):
@@ -23,6 +24,7 @@ def build(cls, hf_config, model_path: str = None, **kwargs):
         tp = kwargs.get('tp', 1)
         # update num_kv_heads for tp mode
         num_key_value_heads = cls.update_num_kv_heads(hf_config, tp, num_key_value_heads)
+        hf_config.use_flash_mla = flash_mla_available()

         return ModelConfig(hidden_size=hf_config.hidden_size,
                            num_layers=hf_config.num_hidden_layers,
@@ -33,4 +35,5 @@ def build(cls, hf_config, model_path: str = None, **kwargs):
                            head_dim=head_dim,
                            k_head_dim=k_head_dim,
                            v_head_dim=v_head_dim,
-                           vocab_size=hf_config.vocab_size)
+                           vocab_size=hf_config.vocab_size,
+                           use_flash_mla=hf_config.use_flash_mla)
lmdeploy/pytorch/configurations/utils.py (new file, +19)

@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from lmdeploy.utils import get_logger
+
+logger = get_logger('lmdeploy')
+
+
+def flash_mla_available():
+    """Check if flash mla is available."""
+    # use flash_mla by default if it is installed
+    use_flash_mla = False
+    try:
+        import flash_mla_cuda  # noqa
+        if torch.cuda.get_device_properties(0).major >= 9:
+            use_flash_mla = True
+    except ImportError:
+        logger.warning('For higher performance, please install flash_mla https://github.com/deepseek-ai/FlashMLA')
+    return use_flash_mla
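
A possible usage sketch (not part of the commit), assuming the helper above lives in lmdeploy/pytorch/configurations/utils.py next to the deepseek_v2.py builder that imports it:

from lmdeploy.pytorch.configurations.utils import flash_mla_available

# True only when flash_mla_cuda imports successfully and device 0 is SM90+ (Hopper);
# otherwise the helper logs an install hint and the Triton attention path is used.
if flash_mla_available():
    print('FlashMLA decode kernels will be used for MLA models')
else:
    print('Falling back to the Triton attention implementation')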

lmdeploy/pytorch/engine/engine.py (+10 −3)
@@ -40,10 +40,10 @@ class InferOutput:
     logits: torch.Tensor = None


-def _tensorlize_block_offsets(block_offsets):
+def _tensorlize_block_offsets(block_offsets, dtype=torch.int32):
     """tensorlize block_offsets."""
     from torch.nn.utils.rnn import pad_sequence
-    block_offsets = [torch.from_numpy(off) for off in block_offsets]
+    block_offsets = [torch.from_numpy(off).to(dtype) for off in block_offsets]
     block_offsets = pad_sequence(block_offsets, batch_first=True)
     return block_offsets

@@ -371,6 +371,13 @@ def model_config(self) -> ModelConfig:
     def gpu_count(self):
         return self.tp * self.dp

+    @property
+    def torch_int_dtype(self):
+        """return int32 for cuda, int64 for others."""
+        if self.executor.device_type == 'cuda':
+            return torch.int32
+        return torch.int64
+
     @logging_timer('CreateModelInputs', logger)
     def create_model_inputs(self, messages: SeqList, is_prefill: bool):
         """create model inputs from messages.
@@ -398,7 +405,7 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool):
         max_q_seq_length = seq_length.max().item()

         block_offsets = self.scheduler.get_block_tables(messages)
-        block_offsets = _tensorlize_block_offsets(block_offsets)
+        block_offsets = _tensorlize_block_offsets(block_offsets, dtype=self.torch_int_dtype)

         local_adapter_ids = None
         if self.adapter_manager.num_adapters() > 1:
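
A small illustration (not part of the commit) of what the updated _tensorlize_block_offsets produces: ragged per-sequence block tables are cast to the requested integer dtype and right-padded into one tensor, which is how CUDA backends end up with the int32 block tables FlashMLA expects.

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

block_offsets = [np.array([0, 1, 2]), np.array([3])]   # placeholder per-sequence block tables
tensors = [torch.from_numpy(off).to(torch.int32) for off in block_offsets]
print(pad_sequence(tensors, batch_first=True))
# tensor([[0, 1, 2],
#         [3, 0, 0]], dtype=torch.int32)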

lmdeploy/pytorch/engine/executor/base.py (+10 −5)
@@ -101,11 +101,16 @@ def _get_runtime_size(self, num_free_gpu_mem: int, cache_block_size: int, vocal_

     def _adjust_block_size(self):
         """adjust block_size."""
+        if self.model_config.use_flash_mla is True:
+            if self.cache_config.block_size != 64:
+                raise ValueError('Please set block_size to 64 for flash_mla.')
+            return
         # TODO: support kernel with both large head dim and large block size.
         if self.model_config.k_head_dim >= 512 and self.cache_config.block_size > 32:
             self.cache_config.block_size = 32
-            logger.warning(f'Update `block_size={self.cache_config.block_size}`'
-                           f' for large `head_dim={self.model_config.k_head_dim}`.')
+            logger.warning(
+                f'Update `block_size={self.cache_config.block_size}` for large `head_dim={self.model_config.k_head_dim}`.'  # noqa
+            )

     def update_configs(self):
         """update cache config."""
@@ -114,7 +119,7 @@ def update_configs(self):
         model_config = self.model_config
         free_mems = self.gather_free_mem()
         free_mem = min(free_mems)
-        logger.debug(f'minimal free gpu memory: {free_mem>>20} mb')
+        logger.debug(f'minimal free gpu memory: {free_mem >> 20} mb')
         vocal_size = self.model_config.vocab_size

         cache_block_size = CacheEngine.get_cache_block_size(cache_config.block_size, model_config, self.tp,
@@ -126,7 +131,7 @@ def update_configs(self):
             cache_config.max_prefill_token_num = max_prefill_token_num
             logger.warning(f'No enough memory. Update max_prefill_token_num={max_prefill_token_num}')
         free_mem -= runtime_mem
-        logger.debug(f'estimated max runtime memory: {runtime_mem>>20} mb')
+        logger.debug(f'estimated max runtime memory: {runtime_mem >> 20} mb')
         available_mem = free_mem * cache_config.cache_max_entry_count

         if cache_config.num_gpu_blocks == 0:
@@ -144,5 +149,5 @@ def init(self):
         self.update_configs()
         logger.info('Building GraphRunner.')
         self.build_graph_runner()
-        logger.info(f'Building CacheEngine with config:\n{self.cache_config}.')
+        logger.info(f'Building CacheEngine with config: \n{self.cache_config}.')
         self.build_cache_engine()
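
Since _adjust_block_size now refuses to silently change the block size when FlashMLA is active, the paged-KV block size has to be 64 up front. A hedged sketch of how that might look when constructing the engine; PytorchEngineConfig and its block_size field are assumed from lmdeploy's public PyTorch engine options, and the model path is a placeholder.

# Hedged sketch, not part of this commit: keep block_size at 64 so the FlashMLA
# check above passes; any other value raises the ValueError shown in the diff.
from lmdeploy import PytorchEngineConfig, pipeline

backend_config = PytorchEngineConfig(block_size=64)
pipe = pipeline('deepseek-ai/DeepSeek-V2-Lite',       # placeholder MLA model
                backend_config=backend_config)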

lmdeploy/pytorch/kernels/cuda/__init__.py (+2)
@@ -3,6 +3,7 @@
 from .alibi_pagedattention import alibi_paged_attention_fwd
 from .apply_rotary_pos_emb import apply_rotary_pos_emb
 from .fill_kv_cache import fill_kv_cache
+from .flash_mla import flash_mla_fwd
 from .flashattention import flash_attention_fwd
 from .flatten_kv_cache import flatten_kv_cache
 from .fused_moe import fused_moe
@@ -30,4 +31,5 @@
     'flash_attention_fwd',
     'flatten_kv_cache',
     'fused_moe_w8a8',
+    'flash_mla_fwd',
 ]
