2 changes: 1 addition & 1 deletion lmdeploy/cli/utils.py
@@ -367,7 +367,7 @@ def calib_search_scale(parser):
@staticmethod
def device(parser,
default: str = 'cuda',
choices: List[str] = ['cuda', 'ascend', 'maca']):
choices: List[str] = ['cuda', 'ascend', 'maca', 'camb']):
"""Add argument device to parser."""

return parser.add_argument('--device',
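With 'camb' added to the choices, the shared --device flag now accepts Cambricon MLU as a target. A minimal sketch of the resulting argparse behavior (the parser wiring below is illustrative, not lmdeploy's actual CLI setup):

import argparse

parser = argparse.ArgumentParser()
# mirrors ArgumentHelper.device: restrict --device to supported backends
parser.add_argument('--device',
                    type=str,
                    default='cuda',
                    choices=['cuda', 'ascend', 'maca', 'camb'])

args = parser.parse_args(['--device', 'camb'])
assert args.device == 'camb'  # 'camb' now passes the choices check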
2 changes: 1 addition & 1 deletion lmdeploy/messages.py
@@ -291,7 +291,7 @@ def __post_init__(self):
assert self.num_gpu_blocks >= 0, 'invalid num_gpu_blocks'
assert self.quant_policy in (0, 4, 8), 'invalid quant_policy'
assert self.device_type in [
'cuda', 'ascend', 'maca'
'cuda', 'ascend', 'maca', 'camb'
], (f'invalid device_type: {self.device_type}')
if self.quant_policy > 0 and self.device_type != 'cuda':
assert False, 'kv cache quantization only works for CUDA.'
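The same whitelist now guards engine configuration, and the follow-up assert keeps kv cache quantization CUDA-only, so quant_policy must stay 0 on camb. A hedged usage sketch, assuming the public PytorchEngineConfig dataclass exported by lmdeploy:

from lmdeploy import PytorchEngineConfig

# passes __post_init__ validation now that 'camb' is a known device_type
config = PytorchEngineConfig(device_type='camb')

# would trip the assert: kv cache quantization only works for CUDA
# config = PytorchEngineConfig(device_type='camb', quant_policy=4)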
1 change: 1 addition & 0 deletions lmdeploy/pytorch/backends/dlinfer/__init__.py
@@ -1,3 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ascend import AscendOpsBackend  # noqa: F401
from .camb import CambOpsBackend  # noqa: F401
from .maca import MacaOpsBackend  # noqa: F401
1 change: 0 additions & 1 deletion lmdeploy/pytorch/backends/dlinfer/apply_rotary_emb.py
@@ -6,7 +6,6 @@

from ..apply_rotary_emb import ApplyRotaryEmbBuilder, ApplyRotaryEmbImpl


class DlinferApplyRotaryEmbImpl(ApplyRotaryEmbImpl):
"""Apply rotary embedding implementation."""

13 changes: 12 additions & 1 deletion lmdeploy/pytorch/backends/dlinfer/attention.py
@@ -15,7 +15,9 @@ class DlinferAttentionMetadata(AttentionMetadata):
is_unpaged_prefill: Optional[bool] = None
max_q_seq_len: int = 1
max_kv_seq_len: int = 1

cu_seqlens: Optional[Tensor] = None
is_flash_attn_support_inplace: bool = True
is_mock_q_start_loc: bool = False

class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]):
"""dlinfer attention implementation."""
@@ -74,11 +76,20 @@ def forward(
is_unpaged_prefill = attn_metadata.is_unpaged_prefill
max_q_seq_len = attn_metadata.max_q_seq_len
max_kv_seq_len = attn_metadata.max_kv_seq_len
cu_seqlens = attn_metadata.cu_seqlens
is_mock_q_start_loc = attn_metadata.is_mock_q_start_loc

# fill kv cache
k_cache, v_cache = self.fill_kv_cache(key, value, k_cache, v_cache,
kv_start_indices)

        if is_unpaged_prefill:
            inplace = (inplace
                       and attn_metadata.is_flash_attn_support_inplace)

        if is_mock_q_start_loc:
            # substitute cumulative sequence lengths for the start offsets
            q_start_loc = cu_seqlens

if inplace:
attn_output = query[..., :self.v_head_size]
else:
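When is_mock_q_start_loc is set, the metadata's cu_seqlens stands in for q_start_loc, i.e. cumulative sequence lengths replace per-request start offsets. A small illustration of the relationship between the two layouts (plain tensors, independent of lmdeploy; which layout a kernel expects is backend-specific):

import torch

q_seqlens = torch.tensor([3, 5, 2], dtype=torch.int32)
# per-request start offsets: tensor([0, 3, 8])
q_start_loc = q_seqlens.cumsum(0) - q_seqlens
# cumulative lengths with a leading zero: tensor([0, 3, 8, 10])
cu_seqlens = torch.nn.functional.pad(q_seqlens.cumsum(0), (1, 0))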
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/camb/__init__.py
@@ -0,0 +1,2 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .op_backend import CambOpsBackend # noqa: F401
313 changes: 313 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/camb/graph_runner.py
@@ -0,0 +1,313 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, List, Tuple

import torch
import torch_mlu  # noqa: F401
# side-effect import: expected to route cuda-style calls to the MLU backend
from torch_mlu.utils.model_transfer import transfer  # noqa: F401
from torch import Tensor

from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
from lmdeploy.pytorch.model_inputs import StepContext
from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta
from lmdeploy.utils import get_logger

from ...graph_runner import GraphRunner

logger = get_logger('lmdeploy')

BuffType = Dict[str, Tensor]

def round_up_to_multiple_of_8(n: int) -> int:
    """Round n up to the nearest multiple of 8."""
    return (n + 7) // 8 * 8


def _false(*args, **kwargs):
    """Fallback that reports cuda graph as unsupported."""
    return False


class CAMBSingleGraphRunner:
"""camb single graph runner."""

def __init__(
self,
model: torch.nn.Module,
max_batches: int,
max_tokens: int,
num_blocks: int,
is_decoding: bool,
pool: Tuple[int, int],
device: torch.device,
):
self.model = model
self.ctx_mgr = model.ctx_mgr
self.meta = CudaGraphMeta(
max_batchs=max_batches,
max_tokens=max_tokens,
num_blocks=num_blocks,
is_decoding=is_decoding,
device=device,
input_buffers=dict(),
output_buffers=dict(),
)
self.device = device
self.max_batches = max_batches
self.max_tokens = max_tokens
self.num_blocks = num_blocks
self.is_decoding = is_decoding
self.pool = pool
self._graph: torch.mlu.CUDAGraph = None

def capture(self, **kwargs):
"""capture graph."""
self.meta.input_buffers = self.make_camb_buffers(
self.meta, **kwargs)
padded_kwargs = self.update_camb_buffer(self.meta, **kwargs)

context = self.ctx_mgr.current_context()
self.update_camb_context(self.meta, context)
current_stream = torch.mlu.current_stream()

# warmup
output = self.model(**padded_kwargs)

self._graph = torch.mlu.CUDAGraph()
        # an unsafe kernel call in another thread might invalidate
        # the capture, so we use the thread_local capture error mode here.
with torch.mlu.graph(self._graph,
pool=self.pool,
stream=current_stream,
capture_error_mode='thread_local'):
output = self.model(**padded_kwargs)

output_buffers = dict(logits=output)
self.meta.output_buffers = output_buffers
return output

def make_camb_buffers(self, graph_meta: CudaGraphMeta, *args,
**kwargs) -> BuffType:
"""make cudagraph buffers from forward inputs."""
max_batches = graph_meta.max_batchs
max_tokens = graph_meta.max_tokens
num_blocks = graph_meta.num_blocks
device = graph_meta.device

input_buffers: BuffType = dict()
input_buffers['input_ids'] = torch.zeros(1,
max_tokens,
dtype=torch.int32,
device=device)
input_buffers['position_ids'] = torch.ones((1, max_tokens),
dtype=torch.int32,
device=device)

input_buffers['block_offsets'] = torch.zeros((max_batches, num_blocks),
dtype=torch.int32,
device=device)

input_buffers['q_start_loc'] = torch.arange(max_batches,
dtype=torch.int32,
device=device)

input_buffers['q_seqlens'] = torch.ones(max_batches,
dtype=torch.int32,
device=device)

input_buffers['kv_seqlens'] = torch.ones(max_batches,
dtype=torch.int32,
device=device)

        input_buffers['kv_start_indices'] = -torch.ones(
            max_batches * max_tokens, dtype=torch.int32, device=device)

input_buffers['local_adapter_ids'] = torch.zeros(max_batches,
dtype=torch.int32,
device=device)
return input_buffers

def update_camb_buffer(self, graph_meta: CudaGraphMeta,
input_ids: Tensor, position_ids: Tensor,
past_key_values: List, attn_metadata: Any,
inputs_embeds: Tensor,
**kwargs) -> Dict[str, Tensor]:
"""fill cudagraph buffers from forward inputs."""
is_decoding = graph_meta.is_decoding
block_offsets: Tensor = attn_metadata.block_offsets
q_start_loc: Tensor = attn_metadata.q_start_loc
q_seqlens: Tensor = attn_metadata.q_seqlens
kv_seqlens: Tensor = attn_metadata.kv_seqlens
kv_start_indices: Tensor = attn_metadata.kv_start_indices

input_buffers: BuffType = graph_meta.input_buffers

batch_size, num_blocks = block_offsets.size()
num_tokens = input_ids.size(-1)
# fill buffer
input_buffers['input_ids'][:, :num_tokens] = input_ids
input_buffers['position_ids'][:, :num_tokens] = position_ids
input_buffers[
'block_offsets'][:batch_size, :num_blocks] = block_offsets
        input_buffers['q_seqlens'][:batch_size] = q_seqlens
        input_buffers['kv_seqlens'][:batch_size] = kv_seqlens
        input_buffers['q_start_loc'][:batch_size] = q_start_loc
        input_buffers['kv_start_indices'][:num_tokens] = \
            kv_start_indices[:num_tokens]

if inputs_embeds is not None:
emb_size = inputs_embeds.size(-1)
if 'inputs_embeds' not in input_buffers:
max_num_tokens = input_buffers['input_ids'].size(-1)
input_buffers['inputs_embeds'] = inputs_embeds.new_zeros(
1, max_num_tokens, emb_size)
input_buffers['inputs_embeds'][:, :num_tokens] = inputs_embeds

# create inputs
new_batch_size = round_up_to_multiple_of_8(batch_size)
new_num_tokens = round_up_to_multiple_of_8(num_tokens)

attn_metadata.block_offsets = input_buffers[
'block_offsets'][:new_batch_size]
attn_metadata.q_start_loc = input_buffers[
'q_start_loc'][:new_batch_size]
attn_metadata.q_seqlens = input_buffers['q_seqlens'][:new_batch_size]
attn_metadata.kv_seqlens = input_buffers['kv_seqlens'][:new_batch_size]

        attn_metadata.kv_start_indices = input_buffers[
            'kv_start_indices'][:new_num_tokens]
new_inputs = dict(
past_key_values=past_key_values,
attn_metadata=attn_metadata,
)

if is_decoding:
new_inputs['input_ids'] = input_buffers[
'input_ids'][:, :new_batch_size]
new_inputs['position_ids'] = input_buffers[
'position_ids'][:, :new_batch_size]
else:
new_inputs['input_ids'] = input_buffers['input_ids']
new_inputs['position_ids'] = input_buffers['position_ids']

if inputs_embeds is not None:
if is_decoding:
new_inputs['inputs_embeds'] = input_buffers[
'inputs_embeds'][:, :new_batch_size]
else:
new_inputs['inputs_embeds'] = input_buffers['inputs_embeds']

new_inputs.update(kwargs)
return new_inputs

def update_camb_context(self, graph_meta, context):
"""update step context with input buffers."""
input_buffers = graph_meta.input_buffers
local_adapter_ids = context.local_adapter_ids
if local_adapter_ids is not None:
if input_buffers['local_adapter_ids'].data_ptr(
) != local_adapter_ids.data_ptr():
input_buffers['local_adapter_ids'].fill_(0)
batch_size = local_adapter_ids.size(0)
input_buffers['local_adapter_ids'][:batch_size] = local_adapter_ids
context.local_adapter_ids = input_buffers['local_adapter_ids']
context.q_seqlens = input_buffers['q_seqlens']
context.kv_seqlens = input_buffers['kv_seqlens']
context.q_start_loc = input_buffers['q_start_loc']
context.kv_start_indices = input_buffers['kv_start_indices']

def forward(self, **kwargs):
"""forward."""
num_tokens = kwargs['input_ids'].size(-1)
assert self._graph is not None
self.update_camb_buffer(self.meta, **kwargs)
context = self.ctx_mgr.current_context()
        self.update_camb_context(self.meta, context)

self._graph.replay()

output = self.meta.output_buffers['logits'][:, :num_tokens]
return output

def __del__(self):
"""del."""
del self._graph


class CAMBGraphRunner(GraphRunner):
"""CAMB graph runner."""

def __init__(self, model: torch.nn.Module, model_config: ModelConfig,
cache_config: CacheConfig, backend_config: BackendConfig,
device: torch.device):
super().__init__(model, model_config, cache_config, backend_config,
device)
self.max_batches = cache_config.max_batches
self.max_tokens = cache_config.max_prefill_token_num
self.num_blocks = cache_config.num_gpu_blocks

self.enable_graph = self.check_enable_graph()

self.graph_pool_handle = torch.mlu.graph_pool_handle()
self._runner_map: Dict[Any, CAMBSingleGraphRunner] = dict()

def check_enable_graph(self):
"""check enable graph."""
if self.backend_config.eager_mode:
return _false

return getattr(self.model, 'support_cuda_graph', _false)

def get_graph_key(self, input_ids: torch.Tensor,
position_ids: torch.Tensor, past_key_values: List,
attn_metadata: Any, inputs_embeds: torch.Tensor,
**kwargs):
"""get graph key."""
context = self.ctx_mgr.current_context()
is_decoding = context.is_decoding
num_tokens = input_ids.numel()
new_num_tokens = round_up_to_multiple_of_8(num_tokens)
return (new_num_tokens, is_decoding)

def __call__(self, **kwargs):
"""call."""
        enable_graph = self.enable_graph(**kwargs)
        graph_key = self.get_graph_key(**kwargs)
        max_tokens, is_decoding = graph_key

if (not enable_graph) or (not is_decoding):
return self.model(**kwargs)

if graph_key not in self._runner_map:
max_batches = max_tokens if is_decoding else self.max_batches
runner = CAMBSingleGraphRunner(self.model,
max_batches=max_batches,
max_tokens=max_tokens,
num_blocks=self.num_blocks,
is_decoding=is_decoding,
pool=self.graph_pool_handle,
device=self.device)
runner.capture(**kwargs)
self._runner_map[graph_key] = runner
else:
runner = self._runner_map[graph_key]

output = runner.forward(**kwargs)
return output

def prepare_inputs_for_generation(
self,
past_key_values: List[List[torch.Tensor]],
inputs_embeds: torch.Tensor = None,
context: StepContext = None,
):
"""prepare inputs."""
return self.model.prepare_inputs_for_generation(
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
context=context,
)
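CAMBGraphRunner caches one captured graph per (padded token count, is_decoding) key in self._runner_map, so rounding decode batch sizes up to multiples of 8 bounds how many graphs get captured and replayed. A quick check of the bucketing arithmetic used by get_graph_key:

def round_up_to_multiple_of_8(n: int) -> int:
    return (n + 7) // 8 * 8

# decode batches of 1..8 replay one graph, 9..16 the next, and so on
assert [round_up_to_multiple_of_8(n) for n in (1, 8, 9, 16)] == [8, 8, 16, 16]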