Commit 0a64be7

Author: root
Message: support camb graph
1 parent 0c08f88

File tree

lmdeploy/pytorch/backends/dlinfer/camb/graph_runner.py
lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py

2 files changed: +337 −0 lines changed
lmdeploy/pytorch/backends/dlinfer/camb/graph_runner.py (new file)
@@ -0,0 +1,326 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, List, Tuple

import torch
# imported for their side effects: torch_mlu and its model_transfer shim
# let the torch.cuda graph APIs used below run on Cambricon MLU devices
import torch_mlu
from torch_mlu.utils.model_transfer import transfer
from torch import Tensor

from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
from lmdeploy.pytorch.model_inputs import StepContext
from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta
from lmdeploy.utils import get_logger

from ...graph_runner import GraphRunner

logger = get_logger('lmdeploy')

BuffType = Dict[str, Tensor]

def next_power_of_2(n: int):
    """Return the smallest power of 2 greater than or equal to n."""
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n |= n >> 32
    n += 1
    return n
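A quick sanity check of the rounding behavior (illustrative, not part of the commit); rounding sizes up this way is what lets a small, fixed set of captured graphs serve arbitrary batch and token counts:

assert next_power_of_2(1) == 1
assert next_power_of_2(5) == 8
assert next_power_of_2(8) == 8
assert next_power_of_2(100) == 128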

def _false(*args, **kwargs):
    """Default for models that do not support cuda graph."""
    return False
class CAMBSingleGraphRunner:
    """camb single graph runner."""

    def __init__(
        self,
        model: torch.nn.Module,
        max_batches: int,
        max_tokens: int,
        num_blocks: int,
        is_decoding: bool,
        pool: Tuple[int, int],
        device: torch.device,
    ):
        self.model = model
        self.ctx_mgr = model.ctx_mgr
        self.meta = CudaGraphMeta(
            max_batchs=max_batches,
            max_tokens=max_tokens,
            num_blocks=num_blocks,
            is_decoding=is_decoding,
            device=device,
            input_buffers=dict(),
            output_buffers=dict(),
        )
        self.device = device
        self.max_batches = max_batches
        self.max_tokens = max_tokens
        self.num_blocks = num_blocks
        self.is_decoding = is_decoding
        self.pool = pool
        self._graph: torch.cuda.CUDAGraph = None

    def capture(self, **kwargs):
        """capture graph."""
        self.meta.input_buffers = self.make_Camb_buffers(self.meta, **kwargs)
        # padded_kwargs = self.model.fill_buffers_cudagraph(self.meta, **kwargs)
        padded_kwargs = self.update_Camb_buffer(self.meta, **kwargs)

        context = self.ctx_mgr.current_context()
        # self.model.update_context_cudagraph(self.meta, context)
        self.update_Camb_context(self.meta, context)
        current_stream = torch.cuda.current_stream()

        # warmup
        self.model(**padded_kwargs)

        self._graph = torch.cuda.CUDAGraph()
        # an unsafe kernel call in another thread might invalidate the
        # capture, so we set thread_local capture mode here.
        with torch.cuda.graph(self._graph,
                              pool=self.pool,
                              stream=current_stream,
                              capture_error_mode='thread_local'):
            output = self.model(**padded_kwargs)

        output_buffers = dict(logits=output)
        self.meta.output_buffers = output_buffers
        return output
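capture() above follows the usual two-phase CUDA graph recipe: one eager warmup pass so lazy initialization and workspace allocation happen outside capture, then a recording pass whose tensors become fixed replay buffers. A minimal self-contained sketch of that recipe, per the PyTorch torch.cuda.graph docs (assuming a CUDA device; on Cambricon hardware the torch_mlu model_transfer import above redirects these torch.cuda calls to the MLU):

import torch

model = torch.nn.Linear(16, 16).cuda()
static_in = torch.zeros(4, 16, device='cuda')

# warmup on a side stream, as the PyTorch docs recommend before capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out = model(static_in)
torch.cuda.current_stream().wait_stream(s)

# record one forward pass; tensors used here become the fixed replay buffers
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = model(static_in)

# replay: copy fresh data into the captured input buffer, then replay
static_in.copy_(torch.randn(4, 16, device='cuda'))
graph.replay()
# static_out now holds the outputs for the new input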

    def make_Camb_buffers(self, graph_meta: CudaGraphMeta, *args,
                          **kwargs) -> BuffType:
        """make cudagraph buffers from forward inputs."""
        max_batches = graph_meta.max_batchs
        max_tokens = graph_meta.max_tokens
        num_blocks = graph_meta.num_blocks
        device = graph_meta.device

        input_buffers: BuffType = dict()
        input_buffers['input_ids'] = torch.zeros(1,
                                                 max_tokens,
                                                 dtype=torch.int32,
                                                 device=device)
        input_buffers['position_ids'] = torch.zeros((1, max_tokens),
                                                    dtype=torch.int32,
                                                    device=device)

        input_buffers['block_offsets'] = torch.zeros((max_batches, num_blocks),
                                                     dtype=torch.int32,
                                                     device=device)
        input_buffers['q_start_loc'] = torch.zeros(max_batches,
                                                   dtype=torch.int32,
                                                   device=device)
        input_buffers['q_seqlens'] = torch.zeros(max_batches,
                                                 dtype=torch.int32,
                                                 device=device)
        input_buffers['kv_seqlens'] = torch.zeros(max_batches,
                                                  dtype=torch.int32,
                                                  device=device)

        input_buffers['cu_seqlens'] = torch.zeros(max_batches + 1,
                                                  dtype=torch.int32,
                                                  device=device)
        input_buffers['kv_start_indices'] = torch.ones(
            max_batches * max_tokens, dtype=torch.int32, device=device) * 512

        input_buffers['local_adapter_ids'] = torch.zeros(max_batches,
                                                         dtype=torch.int32,
                                                         device=device)
        return input_buffers

    def update_Camb_buffer(self, graph_meta: CudaGraphMeta,
                           input_ids: Tensor, position_ids: Tensor,
                           past_key_values: List, attn_metadata: Any,
                           inputs_embeds: Tensor,
                           **kwargs) -> Dict[str, Tensor]:
        """fill cudagraph buffers from forward inputs."""
        is_decoding = graph_meta.is_decoding
        block_offsets: Tensor = attn_metadata.block_offsets
        q_start_loc: Tensor = attn_metadata.q_start_loc
        q_seqlens: Tensor = attn_metadata.q_seqlens
        kv_seqlens: Tensor = attn_metadata.kv_seqlens

        cu_seqlens: Tensor = attn_metadata.cu_seqlens
        kv_start_indices: Tensor = attn_metadata.kv_start_indices

        input_buffers: BuffType = graph_meta.input_buffers

        batch_size, num_blocks = block_offsets.size()
        num_tokens = input_ids.size(-1)
        # fill buffer
        input_buffers['input_ids'][:, :num_tokens] = input_ids
        input_buffers['position_ids'][:, :num_tokens] = position_ids
        input_buffers[
            'block_offsets'][:batch_size, :num_blocks] = block_offsets
        if q_seqlens.data_ptr() != input_buffers['q_seqlens'].data_ptr():
            input_buffers['q_seqlens'].zero_()
        input_buffers['q_seqlens'][:batch_size] = q_seqlens
        if kv_seqlens.data_ptr() != input_buffers['kv_seqlens'].data_ptr():
            input_buffers['kv_seqlens'].zero_()
        input_buffers['kv_seqlens'][:batch_size] = kv_seqlens
        input_buffers['q_start_loc'][:batch_size] = q_start_loc

        input_buffers['cu_seqlens'][:batch_size + 1] = cu_seqlens
        input_buffers['kv_start_indices'][:num_tokens] = kv_start_indices[:num_tokens]

        if inputs_embeds is not None:
            emb_size = inputs_embeds.size(-1)
            if 'inputs_embeds' not in input_buffers:
                max_num_tokens = input_buffers['input_ids'].size(-1)
                input_buffers['inputs_embeds'] = inputs_embeds.new_zeros(
                    1, max_num_tokens, emb_size)
            input_buffers['inputs_embeds'][:, :num_tokens] = inputs_embeds

        # create inputs
        new_batch_size = next_power_of_2(batch_size)
        new_num_tokens = next_power_of_2(num_tokens)

        attn_metadata.block_offsets = input_buffers[
            'block_offsets'][:new_batch_size]
        attn_metadata.q_start_loc = input_buffers[
            'q_start_loc'][:new_batch_size]
        attn_metadata.q_seqlens = input_buffers['q_seqlens'][:new_batch_size]
        attn_metadata.kv_seqlens = input_buffers['kv_seqlens'][:new_batch_size]

        attn_metadata.cu_seqlens = input_buffers['cu_seqlens'][:batch_size + 1]
        attn_metadata.kv_start_indices = input_buffers['kv_start_indices'][:new_num_tokens]
        new_inputs = dict(
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
        )

        if is_decoding:
            new_inputs['input_ids'] = input_buffers[
                'input_ids'][:, :new_batch_size]
            new_inputs['position_ids'] = input_buffers[
                'position_ids'][:, :new_batch_size]
        else:
            new_inputs['input_ids'] = input_buffers['input_ids']
            new_inputs['position_ids'] = input_buffers['position_ids']

        if inputs_embeds is not None:
            if is_decoding:
                new_inputs['inputs_embeds'] = input_buffers[
                    'inputs_embeds'][:, :new_batch_size]
            else:
                new_inputs['inputs_embeds'] = input_buffers['inputs_embeds']

        new_inputs.update(kwargs)
        return new_inputs
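Everything handed to the model here is a slice of a persistent buffer rather than a fresh tensor: a replayed graph reads and writes fixed device addresses, so the views must alias the storage that was captured. That is also why the data_ptr() checks above can skip redundant copies. A small illustration of the aliasing property (hypothetical sizes, CPU for brevity):

import torch

buf = torch.zeros(8, dtype=torch.int32)
view = buf[:5]
assert view.data_ptr() == buf.data_ptr()  # slice shares the same storage
buf[:5] = torch.arange(5, dtype=torch.int32)
assert view.tolist() == [0, 1, 2, 3, 4]  # copy-in is visible through the view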

    def update_Camb_context(self, graph_meta, context):
        """update step context with input buffers."""
        input_buffers = graph_meta.input_buffers
        local_adapter_ids = context.local_adapter_ids
        if local_adapter_ids is not None:
            if input_buffers['local_adapter_ids'].data_ptr(
            ) != local_adapter_ids.data_ptr():
                input_buffers['local_adapter_ids'].fill_(0)
            batch_size = local_adapter_ids.size(0)
            input_buffers['local_adapter_ids'][:batch_size] = local_adapter_ids
            context.local_adapter_ids = input_buffers['local_adapter_ids']
        context.q_seqlens = input_buffers['q_seqlens']
        context.kv_seqlens = input_buffers['kv_seqlens']
        context.q_start_loc = input_buffers['q_start_loc']
        context.cu_seqlens = input_buffers['cu_seqlens']
        context.kv_start_indices = input_buffers['kv_start_indices']

    def forward(self, **kwargs):
        """forward."""
        num_tokens = kwargs['input_ids'].size(-1)
        assert self._graph is not None
        self.update_Camb_buffer(self.meta, **kwargs)
        context = self.ctx_mgr.current_context()
        self.update_Camb_context(self.meta, context)

        self._graph.replay()

        output = self.meta.output_buffers['logits'][:, :num_tokens]
        return output

    def __del__(self):
        """del."""
        del self._graph

class CAMBGraphRunner(GraphRunner):
    """CAMB graph runner."""

    def __init__(self, model: torch.nn.Module, model_config: ModelConfig,
                 cache_config: CacheConfig, backend_config: BackendConfig,
                 device: torch.device):
        super().__init__(model, model_config, cache_config, backend_config,
                         device)
        self.max_batches = cache_config.max_batches
        self.max_tokens = cache_config.max_prefill_token_num
        self.num_blocks = cache_config.num_gpu_blocks

        self.enable_graph = self.check_enable_graph()

        self.graph_pool_handle = torch.cuda.graph_pool_handle()
        self._runner_map: Dict[Any, CAMBSingleGraphRunner] = dict()

    def check_enable_graph(self):
        """check enable graph."""
        if self.backend_config.eager_mode:
            return _false

        return getattr(self.model, 'support_cuda_graph', _false)
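Note that check_enable_graph returns a callable rather than a bool: _false under eager mode, otherwise the model's own support_cuda_graph attribute, which __call__ below invokes with the step's forward kwargs. A model can therefore opt in per input; a hypothetical sketch of such a predicate:

class MyModel(torch.nn.Module):
    def support_cuda_graph(self, input_ids=None, **kwargs):
        # hypothetical policy: only graph short, decode-style inputs
        return input_ids is not None and input_ids.size(-1) <= 256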

    def get_graph_key(self, input_ids: torch.Tensor,
                      position_ids: torch.Tensor, past_key_values: List,
                      attn_metadata: Any, inputs_embeds: torch.Tensor,
                      **kwargs):
        """get graph key."""
        context = self.ctx_mgr.current_context()
        is_decoding = context.is_decoding
        num_tokens = input_ids.numel()
        new_num_tokens = next_power_of_2(num_tokens)
        return (new_num_tokens, is_decoding)
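Because the key is (next_power_of_2(num_tokens), is_decoding), many nearby sizes collapse onto one captured graph; during decoding num_tokens equals the batch size (one token per sequence), so decode steps with 5 to 8 running sequences all replay the graph captured for size 8:

keys = {(next_power_of_2(n), True) for n in range(5, 9)}
assert keys == {(8, True)}  # one cached CAMBSingleGraphRunner serves all four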

    def __call__(self, **kwargs):
        """call."""
        enable_graph = self.enable_graph(**kwargs)

        if not enable_graph:
            return self.model(**kwargs)

        graph_key = self.get_graph_key(**kwargs)
        max_tokens = graph_key[0]
        is_decoding = graph_key[1]
        if graph_key not in self._runner_map:
            max_batches = max_tokens if is_decoding else self.max_batches
            runner = CAMBSingleGraphRunner(self.model,
                                           max_batches=max_batches,
                                           max_tokens=max_tokens,
                                           num_blocks=self.num_blocks,
                                           is_decoding=is_decoding,
                                           pool=self.graph_pool_handle,
                                           device=self.device)
            runner.capture(**kwargs)
            self._runner_map[graph_key] = runner
        else:
            runner = self._runner_map[graph_key]
        output = runner.forward(**kwargs)
        return output

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: torch.Tensor = None,
        context: StepContext = None,
    ):
        """prepare inputs."""
        return self.model.prepare_inputs_for_generation(
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            context=context,
        )

lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py

+11 −0 lines changed
@@ -3,8 +3,10 @@
 
 import torch
 
+from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
 from lmdeploy.utils import get_logger
 
+from ...base import OpType
 from ..op_backend import DlinferOpsBackend
 
 logger = get_logger('lmdeploy')
@@ -105,3 +107,12 @@ def update_step_context(cls, step_context):
         step_context.attn_metadata = attn_metadata
         return step_context
 
+    @staticmethod
+    def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig,
+                           cache_config: CacheConfig,
+                           backend_config: BackendConfig,
+                           device: torch.device):
+        """build graph runner."""
+        from .graph_runner import CAMBGraphRunner
+        return CAMBGraphRunner(model, model_config, cache_config,
+                               backend_config, device)
