support ascend w8a8 graph_mode #3267

Draft · wants to merge 2 commits into base: main (changes from all commits)
1 change: 1 addition & 0 deletions lmdeploy/cli/lite.py
@@ -92,6 +92,7 @@ def add_parser_smooth_quant():
                             type=str,
                             default='./work_dir',
                             help='The working directory for outputs. defaults to "./work_dir"')
+        parser.add_argument('--device', type=str, default='cuda', help='Device for weight quantization (cuda or npu)')
         ArgumentHelper.calib_dataset(parser)
         ArgumentHelper.calib_samples(parser)
         ArgumentHelper.calib_seqlen(parser)
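
For reference, a hedged sketch of how the new flag would be used from the command line; the model id and work dir are placeholders, only the --device flag itself comes from this diff.

# Hypothetical CLI invocation exercising the new --device flag.
import subprocess

subprocess.run(
    [
        'lmdeploy', 'lite', 'smooth_quant', 'internlm/internlm2-chat-7b',  # placeholder model
        '--work-dir', './work_dir',
        '--device', 'npu',  # new in this PR; defaults to 'cuda'
    ],
    check=True)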
2 changes: 2 additions & 0 deletions lmdeploy/lite/apis/smooth_quant.py
@@ -10,6 +10,7 @@
 from lmdeploy.lite.apis.calibrate import LAYER_TYPE_MAP, NORM_TYPE_MAP, calibrate
 from lmdeploy.lite.quantization.awq import FC_FCS_MAP, NORM_FCS_MAP, awq_layers, skipped_module, smooth_layers
 from lmdeploy.lite.utils import collect_target_modules
+from lmdeploy.pytorch.check_env import try_import_deeplink
 from lmdeploy.pytorch.models import QLinear, QRMSNorm


@@ -26,6 +27,7 @@ def smooth_quant(model: str,
                  quant_dtype: Literal['int8', 'fp8', 'float8_e4m3fn', 'float8_e5m2'] = 'int8',
                  revision: str = None,
                  download_dir: str = None):
+    try_import_deeplink(device)
     if quant_dtype == 'fp8':
         quant_dtype = 'float8_e4m3fn'

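
The same path via the Python API, as a minimal sketch; it assumes smooth_quant already exposes a device parameter (the diff shows only the new try_import_deeplink(device) call, not the full signature).

# Sketch of the API path above; argument names other than `device` and
# `quant_dtype` are assumptions based on the surrounding code, not this diff.
from lmdeploy.lite.apis.smooth_quant import smooth_quant

smooth_quant('internlm/internlm2-chat-7b',  # placeholder model id
             work_dir='./work_dir',
             quant_dtype='int8',
             device='npu')  # runs try_import_deeplink('npu') before quantizing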
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/kernels/default/__init__.py
@@ -1,6 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .multinomial_sampling import multinomial_sampling
+from .w8a8_kernels import per_channel_quant

 __all__ = [
     'multinomial_sampling',
+    'per_channel_quant',
 ]
30 changes: 30 additions & 0 deletions lmdeploy/pytorch/kernels/default/w8a8_kernels.py
@@ -0,0 +1,30 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+
+def per_channel_quant(x: torch.Tensor, dtype: torch.dtype):
+    """Quantize the input tensor 'x' channel-wise to the given data type,
+    scaling each channel by its maximum absolute value.
+
+    Args:
+        x (torch.Tensor): The input tensor to be quantized. Must be a
+            2-dimensional tensor.
+        dtype (torch.dtype): The data type to which the quantized tensor should
+            be converted.
+
+    Returns:
+        tuple: A tuple containing two items -- the quantized tensor and
+            the scale used for quantization.
+    """
+    assert x.ndim == 2
+    x = x.to(torch.float32)
+    x_absmax = x.view(x.shape[0], -1).abs().max(dim=1, keepdim=True)[0]
+    qtype_info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
+    q_max = qtype_info.max
+    q_min = qtype_info.min
+    scale = x_absmax / q_max
+    x_q = x / scale
+    if not dtype.is_floating_point:
+        x_q = torch.round(x_q)
+    x_q = x_q.clamp(q_min, q_max).to(dtype)
+    return x_q, scale
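
Since per_channel_quant is self-contained, its behaviour is easy to sanity-check; a small round-trip sketch (shapes are illustrative):

# Round-trip check for per_channel_quant: quantize, dequantize, measure error.
import torch

from lmdeploy.pytorch.kernels.default import per_channel_quant

w = torch.randn(128, 256)                      # [out_channels, in_features]
w_q, scale = per_channel_quant(w, torch.int8)  # one scale per output channel
assert w_q.dtype == torch.int8
assert scale.shape == (128, 1)

w_dq = w_q.to(torch.float32) * scale           # dequantize
print((w - w_dq).abs().max())                  # bounded by scale.max() / 2 for int8 rounding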
10 changes: 9 additions & 1 deletion lmdeploy/pytorch/kernels/dispatcher.py
@@ -3,6 +3,8 @@
 import inspect
 from typing import Callable

+from torch import Tensor
+
 from lmdeploy.utils import get_logger

 from ..devices import DeviceContext, get_device_manager
@@ -64,6 +66,8 @@ def __init__(self, func_name: str):
         self.func_name = func_name
         self.dispatched_func = self.load_and_call
         self.device_manager.register_context_callback(self.device_callback)
+        self.device_type = None
+        self.device_map = {'cuda': 'cuda', 'npu': 'dlinfer', 'maca': 'dlinfer', 'camb': 'dlinfer'}

     def device_callback(self, context: DeviceContext):
         """device context callback."""
@@ -88,7 +92,11 @@ def load_func(self, device: str):

     def load_and_call(self, *args, **kwargs):
         """load and call."""
-        device = self.device_manager.current_context().device_type
+        if self.device_type is None:
+            device_type = self.device_manager.current_context().device_type
+            self.device_type = next(
+                (arg.device.type for arg in args if isinstance(arg, Tensor) and arg.device.type != 'cpu'), device_type)
+        device = self.device_map[self.device_type]
         if device not in self.impl_map:
             self.load_func(device)
         self.dispatched_func = self.impl_map[device]
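
The dispatch change has two parts: the backend is now resolved from the first non-CPU tensor argument (falling back to the device context), and the raw device type is translated through device_map so that npu/maca/camb all share the dlinfer kernels. A standalone sketch of that resolution rule:

# Standalone sketch of the new resolution rule in load_and_call; DEVICE_MAP
# mirrors self.device_map from the diff, resolve_backend mirrors the probe.
import torch
from torch import Tensor

DEVICE_MAP = {'cuda': 'cuda', 'npu': 'dlinfer', 'maca': 'dlinfer', 'camb': 'dlinfer'}

def resolve_backend(*args, context_default: str = 'cuda') -> str:
    device_type = next(
        (arg.device.type for arg in args if isinstance(arg, Tensor) and arg.device.type != 'cpu'),
        context_default)
    return DEVICE_MAP[device_type]

x = torch.randn(4, 4)      # a CPU tensor is skipped by the probe
print(resolve_backend(x))  # -> 'cuda' (the context default)

Note that in the actual dispatcher the result is cached on first use (self.device_type), so the tensor probe runs only once per dispatched function.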
3 changes: 2 additions & 1 deletion lmdeploy/pytorch/kernels/dlinfer/__init__.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from ..default import multinomial_sampling
+from ..default import multinomial_sampling, per_channel_quant
 from .apply_rotary_pos_emb import apply_rotary_pos_emb
 from .awq_kernels import awq_linear
 from .fill_kv_cache import fill_kv_cache
@@ -21,4 +21,5 @@
     'linear',
     'moe_gating_topk_softmax',
     'multinomial_sampling',
+    'per_channel_quant',
 ]
10 changes: 6 additions & 4 deletions lmdeploy/pytorch/models/q_modules.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.

-from dataclasses import dataclass
+from dataclasses import dataclass, fields

 import torch
 import torch.nn as nn
@@ -19,13 +19,15 @@ class QTensor:
     scale: torch.Tensor
     zero_point: torch.Tensor = None

+    def __post_init__(self):
+        self.fields = [field.name for field in fields(self)]
+
     def __getattr__(self, name: str):
         """Allows attribute access to be forwarded to the wrapped tensor when
         the attribute doesn't exist in QTensor."""
-        try:
+        if name in self.fields:
             return super().__getattr__(name)
-        except AttributeError:
-            return getattr(self.tensor, name)
+        return getattr(self.tensor, name)


 class QRMSNorm(nn.Module):
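
A short sketch of the resulting behaviour: dataclass fields resolve normally, while unknown attributes are forwarded to the wrapped tensor. The constructor call assumes a tensor field, inferred from the self.tensor accesses visible above.

# Attribute forwarding after this change; the QTensor field set is assumed to
# be (tensor, scale, zero_point) based on the diff context.
import torch

from lmdeploy.pytorch.models.q_modules import QTensor

qt = QTensor(tensor=torch.zeros(2, 8, dtype=torch.int8), scale=torch.ones(2, 1))
print(qt.scale.shape)  # torch.Size([2, 1]) -- a real QTensor field
print(qt.shape)        # torch.Size([2, 8]) -- forwarded to qt.tensor
print(qt.numel())      # 16 -- bound methods forward as well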