# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, List, Tuple

import torch
# Imported for their side effects: `torch_mlu` registers the Cambricon MLU
# backend, and `transfer` patches CUDA APIs so that `torch.cuda.*` calls
# dispatch to MLU devices.
import torch_mlu  # noqa: F401
from torch_mlu.utils.model_transfer import transfer  # noqa: F401
from torch import Tensor
from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
from lmdeploy.pytorch.model_inputs import StepContext
from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta
from lmdeploy.utils import get_logger

from ...graph_runner import GraphRunner

logger = get_logger('lmdeploy')

BuffType = Dict[str, Tensor]

def next_power_of_2(n: int):
    """Return the smallest power of 2 greater than or equal to n."""
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n |= n >> 32
    n += 1
    return n
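# Example: next_power_of_2(48) == 64. Bucketing batch/token counts to powers
# of two bounds how many distinct graphs need to be captured, at the cost of
# some padding at replay time.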


def _false(*args, **kwargs):
    """Fallback used when the model does not support CUDA graph."""
    return False


class CAMBSingleGraphRunner:
    """CAMB single-graph runner.

    Captures one CUDA graph for a fixed padded input shape and replays it
    with refreshed buffer contents.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        max_batches: int,
        max_tokens: int,
        num_blocks: int,
        is_decoding: bool,
        pool: Tuple[int, int],
        device: torch.device,
    ):
        self.model = model
        self.ctx_mgr = model.ctx_mgr
        self.meta = CudaGraphMeta(
            max_batchs=max_batches,
            max_tokens=max_tokens,
            num_blocks=num_blocks,
            is_decoding=is_decoding,
            device=device,
            input_buffers=dict(),
            output_buffers=dict(),
        )
        self.device = device
        self.max_batches = max_batches
        self.max_tokens = max_tokens
        self.num_blocks = num_blocks
        self.is_decoding = is_decoding
        self.pool = pool
        self._graph: torch.cuda.CUDAGraph = None
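
    # Note: `pool` is shared across all runners via CAMBGraphRunner's
    # graph_pool_handle, letting separately captured graphs reuse one
    # another's intermediate memory.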

    def capture(self, **kwargs):
        """Capture graph."""
        self.meta.input_buffers = self.make_Camb_buffers(self.meta, **kwargs)
        padded_kwargs = self.update_Camb_buffer(self.meta, **kwargs)

        context = self.ctx_mgr.current_context()
        self.update_Camb_context(self.meta, context)
        current_stream = torch.cuda.current_stream()

        # Warmup run: triggers lazy initialization and workspace allocation
        # before capture begins.
        self.model(**padded_kwargs)

        self._graph = torch.cuda.CUDAGraph()
        # An unsafe kernel launched from another thread might invalidate the
        # capture, so use the 'thread_local' capture error mode here.
        with torch.cuda.graph(self._graph,
                              pool=self.pool,
                              stream=current_stream,
                              capture_error_mode='thread_local'):
            output = self.model(**padded_kwargs)

        output_buffers = dict(logits=output)
        self.meta.output_buffers = output_buffers
        return output
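
    # Replay re-executes the recorded kernels against the exact buffer
    # addresses that were live during capture; that is why `capture` copies
    # every input into `meta.input_buffers` and `forward` only refreshes
    # their contents before `replay`.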

    def make_Camb_buffers(self, graph_meta: CudaGraphMeta, *args,
                          **kwargs) -> BuffType:
        """Make cudagraph buffers from forward inputs."""
        max_batches = graph_meta.max_batchs
        max_tokens = graph_meta.max_tokens
        num_blocks = graph_meta.num_blocks
        device = graph_meta.device

        input_buffers: BuffType = dict()
        input_buffers['input_ids'] = torch.zeros(1,
                                                 max_tokens,
                                                 dtype=torch.int32,
                                                 device=device)
        input_buffers['position_ids'] = torch.zeros((1, max_tokens),
                                                    dtype=torch.int32,
                                                    device=device)
        input_buffers['block_offsets'] = torch.zeros((max_batches, num_blocks),
                                                     dtype=torch.int32,
                                                     device=device)
        input_buffers['q_start_loc'] = torch.zeros(max_batches,
                                                   dtype=torch.int32,
                                                   device=device)
        input_buffers['q_seqlens'] = torch.zeros(max_batches,
                                                 dtype=torch.int32,
                                                 device=device)
        input_buffers['kv_seqlens'] = torch.zeros(max_batches,
                                                  dtype=torch.int32,
                                                  device=device)
        input_buffers['cu_seqlens'] = torch.zeros(max_batches + 1,
                                                  dtype=torch.int32,
                                                  device=device)
        # Initialized to a fixed placeholder offset rather than zeros.
        input_buffers['kv_start_indices'] = torch.ones(
            max_batches * max_tokens, dtype=torch.int32, device=device) * 512
        input_buffers['local_adapter_ids'] = torch.zeros(max_batches,
                                                         dtype=torch.int32,
                                                         device=device)
        return input_buffers
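
    # Buffers are allocated once at worst-case sizes (max_batches,
    # max_tokens, num_blocks); each step writes into a prefix of these
    # tensors instead of allocating fresh storage.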

    def update_Camb_buffer(self, graph_meta: CudaGraphMeta,
                           input_ids: Tensor, position_ids: Tensor,
                           past_key_values: List, attn_metadata: Any,
                           inputs_embeds: Tensor,
                           **kwargs) -> Dict[str, Tensor]:
        """Fill cudagraph buffers from forward inputs."""
        is_decoding = graph_meta.is_decoding
        block_offsets: Tensor = attn_metadata.block_offsets
        q_start_loc: Tensor = attn_metadata.q_start_loc
        q_seqlens: Tensor = attn_metadata.q_seqlens
        kv_seqlens: Tensor = attn_metadata.kv_seqlens
        cu_seqlens: Tensor = attn_metadata.cu_seqlens
        kv_start_indices: Tensor = attn_metadata.kv_start_indices

        input_buffers: BuffType = graph_meta.input_buffers

        batch_size, num_blocks = block_offsets.size()
        num_tokens = input_ids.size(-1)

        # Fill buffers in place; the captured graph reads from these fixed
        # addresses on every replay.
        input_buffers['input_ids'][:, :num_tokens] = input_ids
        input_buffers['position_ids'][:, :num_tokens] = position_ids
        input_buffers[
            'block_offsets'][:batch_size, :num_blocks] = block_offsets
        if q_seqlens.data_ptr() != input_buffers['q_seqlens'].data_ptr():
            input_buffers['q_seqlens'].zero_()
        input_buffers['q_seqlens'][:batch_size] = q_seqlens
        if kv_seqlens.data_ptr() != input_buffers['kv_seqlens'].data_ptr():
            input_buffers['kv_seqlens'].zero_()
        input_buffers['kv_seqlens'][:batch_size] = kv_seqlens
        input_buffers['q_start_loc'][:batch_size] = q_start_loc
        input_buffers['cu_seqlens'][:batch_size + 1] = cu_seqlens
        input_buffers[
            'kv_start_indices'][:num_tokens] = kv_start_indices[:num_tokens]

        if inputs_embeds is not None:
            emb_size = inputs_embeds.size(-1)
            if 'inputs_embeds' not in input_buffers:
                max_num_tokens = input_buffers['input_ids'].size(-1)
                input_buffers['inputs_embeds'] = inputs_embeds.new_zeros(
                    1, max_num_tokens, emb_size)
            input_buffers['inputs_embeds'][:, :num_tokens] = inputs_embeds

        # Create padded inputs whose shapes match the captured graph.
        new_batch_size = next_power_of_2(batch_size)
        new_num_tokens = next_power_of_2(num_tokens)

        attn_metadata.block_offsets = input_buffers[
            'block_offsets'][:new_batch_size]
        attn_metadata.q_start_loc = input_buffers[
            'q_start_loc'][:new_batch_size]
        attn_metadata.q_seqlens = input_buffers['q_seqlens'][:new_batch_size]
        attn_metadata.kv_seqlens = input_buffers['kv_seqlens'][:new_batch_size]
        attn_metadata.cu_seqlens = input_buffers[
            'cu_seqlens'][:batch_size + 1]
        attn_metadata.kv_start_indices = input_buffers[
            'kv_start_indices'][:new_num_tokens]
        new_inputs = dict(
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
        )

        if is_decoding:
            new_inputs['input_ids'] = input_buffers[
                'input_ids'][:, :new_batch_size]
            new_inputs['position_ids'] = input_buffers[
                'position_ids'][:, :new_batch_size]
        else:
            new_inputs['input_ids'] = input_buffers['input_ids']
            new_inputs['position_ids'] = input_buffers['position_ids']

        if inputs_embeds is not None:
            if is_decoding:
                new_inputs['inputs_embeds'] = input_buffers[
                    'inputs_embeds'][:, :new_batch_size]
            else:
                new_inputs['inputs_embeds'] = input_buffers['inputs_embeds']

        new_inputs.update(kwargs)
        return new_inputs
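
    # Note: in decoding mode every sequence contributes exactly one token,
    # so the padded batch size doubles as the padded token count when
    # slicing `input_ids` and `position_ids` above.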

    def update_Camb_context(self, graph_meta, context):
        """Update step context with input buffers."""
        input_buffers = graph_meta.input_buffers
        local_adapter_ids = context.local_adapter_ids
        if local_adapter_ids is not None:
            if input_buffers['local_adapter_ids'].data_ptr(
            ) != local_adapter_ids.data_ptr():
                input_buffers['local_adapter_ids'].fill_(0)
            batch_size = local_adapter_ids.size(0)
            input_buffers['local_adapter_ids'][:batch_size] = local_adapter_ids
            context.local_adapter_ids = input_buffers['local_adapter_ids']
        context.q_seqlens = input_buffers['q_seqlens']
        context.kv_seqlens = input_buffers['kv_seqlens']
        context.q_start_loc = input_buffers['q_start_loc']
        context.cu_seqlens = input_buffers['cu_seqlens']
        context.kv_start_indices = input_buffers['kv_start_indices']

    def forward(self, **kwargs):
        """Forward."""
        num_tokens = kwargs['input_ids'].size(-1)
        assert self._graph is not None
        # Refresh the persistent buffers, then replay the captured kernels.
        self.update_Camb_buffer(self.meta, **kwargs)
        context = self.ctx_mgr.current_context()
        self.update_Camb_context(self.meta, context)

        self._graph.replay()

        # Drop the rows produced by power-of-2 padding.
        output = self.meta.output_buffers['logits'][:, :num_tokens]
        return output
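
    # Minimal lifecycle sketch (argument values are illustrative only):
    #   runner = CAMBSingleGraphRunner(model, max_batches=64, max_tokens=64,
    #                                  num_blocks=num_blocks,
    #                                  is_decoding=True, pool=pool,
    #                                  device=device)
    #   runner.capture(**step_kwargs)        # first step: record the graph
    #   logits = runner.forward(**step_kwargs)  # later steps: replay only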

    def __del__(self):
        """Release the captured graph."""
        del self._graph


class CAMBGraphRunner(GraphRunner):
    """CAMB graph runner."""

    def __init__(self, model: torch.nn.Module, model_config: ModelConfig,
                 cache_config: CacheConfig, backend_config: BackendConfig,
                 device: torch.device):
        super().__init__(model, model_config, cache_config, backend_config,
                         device)
        self.max_batches = cache_config.max_batches
        self.max_tokens = cache_config.max_prefill_token_num
        self.num_blocks = cache_config.num_gpu_blocks

        self.enable_graph = self.check_enable_graph()

        self.graph_pool_handle = torch.cuda.graph_pool_handle()
        self._runner_map: Dict[Any, CAMBSingleGraphRunner] = dict()

    def check_enable_graph(self):
        """Return a predicate that reports whether inputs support graph mode."""
        if self.backend_config.eager_mode:
            return _false

        return getattr(self.model, 'support_cuda_graph', _false)

    def get_graph_key(self, input_ids: torch.Tensor,
                      position_ids: torch.Tensor, past_key_values: List,
                      attn_metadata: Any, inputs_embeds: torch.Tensor,
                      **kwargs):
        """Get graph key."""
        context = self.ctx_mgr.current_context()
        is_decoding = context.is_decoding
        num_tokens = input_ids.numel()
        new_num_tokens = next_power_of_2(num_tokens)
        return (new_num_tokens, is_decoding)
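
    # Graphs are keyed on (padded token count, is_decoding), so prefill and
    # decoding steps never share a captured graph.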

    def __call__(self, **kwargs):
        """Call."""
        enable_graph = self.enable_graph(**kwargs)

        if not enable_graph:
            return self.model(**kwargs)

        graph_key = self.get_graph_key(**kwargs)
        max_tokens = graph_key[0]
        is_decoding = graph_key[1]
        if graph_key not in self._runner_map:
            max_batches = max_tokens if is_decoding else self.max_batches
            runner = CAMBSingleGraphRunner(self.model,
                                           max_batches=max_batches,
                                           max_tokens=max_tokens,
                                           num_blocks=self.num_blocks,
                                           is_decoding=is_decoding,
                                           pool=self.graph_pool_handle,
                                           device=self.device)
            runner.capture(**kwargs)
            self._runner_map[graph_key] = runner
        else:
            runner = self._runner_map[graph_key]
        output = runner.forward(**kwargs)
        return output
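
    # The first call with a new key pays the capture cost; later calls with
    # the same key only refresh buffers and replay.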

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: torch.Tensor = None,
        context: StepContext = None,
    ):
        """Prepare inputs."""
        return self.model.prepare_inputs_for_generation(
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            context=context,
        )