
Commit e823c48

support dlinfer smooth_quant

1 parent 9bacc1c

6 files changed (+46 −2 lines)

6 files changed

+46
-2
lines changed

lmdeploy/cli/lite.py (+1)

@@ -92,6 +92,7 @@ def add_parser_smooth_quant():
                             type=str,
                             default='./work_dir',
                             help='The working directory for outputs. defaults to "./work_dir"')
+        parser.add_argument('--device', type=str, default='cuda', help='Device for weight quantization (cuda or npu)')
         ArgumentHelper.calib_dataset(parser)
         ArgumentHelper.calib_samples(parser)
         ArgumentHelper.calib_seqlen(parser)
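
The new flag plugs into the existing lmdeploy lite smooth_quant subcommand. An example invocation (the model path is a placeholder; the other option names follow this parser):

  lmdeploy lite smooth_quant internlm/internlm2-chat-7b --work-dir ./work_dir --device npu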

lmdeploy/lite/apis/smooth_quant.py (+2)

@@ -10,6 +10,7 @@
 from lmdeploy.lite.apis.calibrate import LAYER_TYPE_MAP, NORM_TYPE_MAP, calibrate
 from lmdeploy.lite.quantization.awq import FC_FCS_MAP, NORM_FCS_MAP, awq_layers, skipped_module, smooth_layers
 from lmdeploy.lite.utils import collect_target_modules
+from lmdeploy.pytorch.check_env import try_import_deeplink
 from lmdeploy.pytorch.models import QLinear, QRMSNorm


@@ -26,6 +27,7 @@ def smooth_quant(model: str,
                  quant_dtype: Literal['int8', 'fp8', 'float8_e4m3fn', 'float8_e5m2'] = 'int8',
                  revision: str = None,
                  download_dir: str = None):
+    try_import_deeplink(device)
     if quant_dtype == 'fp8':
         quant_dtype = 'float8_e4m3fn'
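
On the Python side the same device argument reaches smooth_quant, which now calls try_import_deeplink(device) up front, so a missing dlinfer backend fails fast instead of midway through calibration. A minimal sketch, assuming the keyword names match the CLI options above (the model path is a placeholder):

  from lmdeploy.lite.apis.smooth_quant import smooth_quant

  # try_import_deeplink('npu') runs first and raises early if the dlinfer
  # backend for the target device is not installed
  smooth_quant('internlm/internlm2-chat-7b', work_dir='./work_dir', device='npu')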

lmdeploy/pytorch/kernels/default/__init__.py (+2)

@@ -1,6 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .multinomial_sampling import multinomial_sampling
+from .w8a8_kernels import per_channel_quant

 __all__ = [
     'multinomial_sampling',
+    'per_channel_quant',
 ]
lmdeploy/pytorch/kernels/default/w8a8_kernels.py (new file, +30)

@@ -0,0 +1,30 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+
+def per_channel_quant(x: torch.Tensor, dtype: torch.dtype):
+    """Quantize the input tensor 'x' channel-wise, converting values to
+    the given dtype.
+
+    Args:
+        x (torch.Tensor): The input tensor to be quantized. Must be a
+            2-dimensional tensor.
+        dtype (torch.dtype): The data type to which the quantized tensor should
+            be converted.
+
+    Returns:
+        tuple: A tuple containing two items -- the quantized tensor and
+            the scale used for quantization.
+    """
+    assert x.ndim == 2
+    x = x.to(torch.float32)
+    x_absmax = x.view(x.shape[0], -1).abs().max(dim=1, keepdim=True)[0]
+    qtype_info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
+    q_max = qtype_info.max
+    q_min = qtype_info.min
+    scale = x_absmax / q_max
+    x_q = x / scale
+    if not dtype.is_floating_point:
+        x_q = torch.round(x_q)
+    x_q = x_q.clamp(q_min, q_max).to(dtype)
+    return x_q, scale
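
The kernel is plain PyTorch, so it can be exercised directly. A round-trip sketch (the import path assumes the re-export from lmdeploy.pytorch.kernels.default shown above):

  import torch

  from lmdeploy.pytorch.kernels.default import per_channel_quant

  w = torch.randn(4096, 11008)                   # weights: one output channel per row
  w_q, scale = per_channel_quant(w, torch.int8)  # symmetric: scale = row absmax / 127
  w_hat = w_q.to(torch.float32) * scale          # dequantize
  # rounding error per element is bounded by half of that row's scale
  assert (w - w_hat).abs().max() <= scale.max() / 2 + 1e-6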

lmdeploy/pytorch/kernels/dispatcher.py (+9 −1)

@@ -3,6 +3,8 @@
 import inspect
 from typing import Callable

+from torch import Tensor
+
 from lmdeploy.utils import get_logger

 from ..devices import DeviceContext, get_device_manager
@@ -64,6 +66,8 @@ def __init__(self, func_name: str):
         self.func_name = func_name
         self.dispatched_func = self.load_and_call
         self.device_manager.register_context_callback(self.device_callback)
+        self.device_type = None
+        self.device_map = {'cuda': 'cuda', 'npu': 'dlinfer', 'maca': 'dlinfer', 'camb': 'dlinfer'}

     def device_callback(self, context: DeviceContext):
         """device context callback."""
@@ -88,7 +92,11 @@ def load_func(self, device: str):

     def load_and_call(self, *args, **kwargs):
         """load and call."""
-        device = self.device_manager.current_context().device_type
+        if self.device_type is None:
+            device_type = self.device_manager.current_context().device_type
+            self.device_type = next(
+                (arg.device.type for arg in args if isinstance(arg, Tensor) and arg.device.type != 'cpu'), device_type)
+        device = self.device_map[self.device_type]
         if device not in self.impl_map:
             self.load_func(device)
         self.dispatched_func = self.impl_map[device]
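
Two things change here: the kernel backend is now resolved once and cached (self.device_type), and several device types (npu, maca, camb) are folded into the shared dlinfer backend. The device type is probed from the first non-CPU tensor argument, falling back to the current device context. A standalone restatement of that selection rule, not the dispatcher itself:

  import torch
  from torch import Tensor

  # same table as the diff above: several device types share the dlinfer backend
  device_map = {'cuda': 'cuda', 'npu': 'dlinfer', 'maca': 'dlinfer', 'camb': 'dlinfer'}

  def pick_backend(*args, context_device: str = 'cuda') -> str:
      # the first non-CPU tensor argument decides the device type;
      # otherwise fall back to the device type of the current context
      device_type = next((a.device.type for a in args
                          if isinstance(a, Tensor) and a.device.type != 'cpu'), context_device)
      return device_map[device_type]

  print(pick_backend(torch.randn(2, 2)))  # 'cuda': CPU tensors are skipped, so the fallback applies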

lmdeploy/pytorch/kernels/dlinfer/__init__.py (+2 −1)

@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from ..default import multinomial_sampling
+from ..default import multinomial_sampling, per_channel_quant
 from .apply_rotary_pos_emb import apply_rotary_pos_emb
 from .awq_kernels import awq_linear
 from .fill_kv_cache import fill_kv_cache
@@ -21,4 +21,5 @@
     'linear',
     'moe_gating_topk_softmax',
     'multinomial_sampling',
+    'per_channel_quant',
 ]
