Commit c7cb5c3

[Misc] GPTQ Activation Ordering (vllm-project#8135)
1 parent f9b4a2d commit c7cb5c3

File tree

4 files changed: +64 −15 lines changed

tests/weight_loading/models.txt (+1)

@@ -21,6 +21,7 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
 compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
+compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main
 awq, casperhansen/mixtral-instruct-awq, main
 awq_marlin, casperhansen/mixtral-instruct-awq, main
 fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py (+2 −1)

@@ -232,7 +232,8 @@ def _get_scheme_from_parts(
             return CompressedTensorsWNA16(
                 num_bits=weight_quant.num_bits,
                 strategy=weight_quant.strategy,
-                group_size=weight_quant.group_size)
+                group_size=weight_quant.group_size,
+                actorder=weight_quant.actorder)

         # Detect If Activation Quantization.
         # TODO @dsikka: clean-up conditions
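
With this change, the actorder value parsed from a checkpoint's quantization config flows straight into the WNA16 scheme. A minimal sketch (field values are illustrative assumptions, not taken from this commit) of the weights block this path consumes:

# Sketch: parsing a "weights" block that carries the new actorder field.
# The numeric values below are assumptions for illustration only.
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    QuantizationArgs)

weight_quant = QuantizationArgs.model_validate({
    "num_bits": 4,
    "strategy": "group",
    "group_size": 128,
    "actorder": "group",
})
# _get_scheme_from_parts now forwards weight_quant.actorder to
# CompressedTensorsWNA16 along with num_bits, strategy, and group_size.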

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py (+33 −12)

@@ -5,14 +5,18 @@
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    ActivationOrdering)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
-    marlin_permute_scales, replace_tensor, verify_marlin_supported,
+    marlin_permute_scales, marlin_repeat_scales_on_all_ranks,
+    marlin_sort_g_idx, replace_tensor, verify_marlin_supported,
     verify_marlin_supports_shape)
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            ChannelQuantScaleParameter,
                                            GroupQuantScaleParameter,
-                                           PackedvLLMParameter)
+                                           PackedvLLMParameter,
+                                           RowvLLMParameter)
 from vllm.scalar_type import scalar_types

 __all__ = ["CompressedTensorsWNA16"]

@@ -28,11 +32,13 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
     def __init__(self,
                  strategy: str,
                  num_bits: int,
-                 group_size: Optional[int] = None):
+                 group_size: Optional[int] = None,
+                 actorder: Optional[ActivationOrdering] = None):

         self.pack_factor = 32 // num_bits
         self.strategy = strategy
         self.group_size = -1 if group_size is None else group_size
+        self.has_g_idx = actorder == ActivationOrdering.GROUP

         if self.group_size == -1 and self.strategy != "channel":
             raise ValueError("Marlin kernels require group quantization or "

@@ -64,12 +70,10 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
         output_size_per_partition = sum(output_partition_sizes)

         # If group_size is -1, we are in channelwise case.
-        channelwise = (self.group_size == -1)
         group_size = self.group_size if self.group_size != -1 else input_size
         row_parallel = (input_size != input_size_per_partition)
-        # In the case of channelwise quantization, we need to replicate the
-        # scales across all gpus.
-        partition_scales = (row_parallel and not channelwise)
+        partition_scales = not marlin_repeat_scales_on_all_ranks(
+            self.has_g_idx, self.group_size, row_parallel)

         verify_marlin_supports_shape(
             output_size_per_partition=output_size_per_partition,

@@ -123,6 +127,16 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
         layer.register_parameter("weight_scale", weight_scale)
         layer.register_parameter("weight_shape", weight_shape)

+        # group index (for activation reordering)
+        if self.has_g_idx:
+            weight_g_idx = RowvLLMParameter(data=torch.empty(
+                input_size_per_partition,
+                dtype=torch.int32,
+            ),
+                                            input_dim=0,
+                                            weight_loader=weight_loader)
+            layer.register_parameter("weight_g_idx", weight_g_idx)
+
         layer.input_size_per_partition = input_size_per_partition
         layer.output_size_per_partition = output_size_per_partition
         layer.input_size = input_size

@@ -137,9 +151,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.workspace = marlin_make_workspace(
             layer.output_size_per_partition, device)

-        # Act-order not supported in compressed-tensors yet, so set to empty.
-        layer.g_idx = marlin_make_empty_g_idx(device)
-        layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
+        # Handle sorting for activation reordering if needed.
+        if self.has_g_idx:
+            g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.weight_g_idx)
+            layer.g_idx_sort_indices = g_idx_sort_indices
+            replace_tensor(layer, "weight_g_idx", g_idx)
+        else:
+            layer.weight_g_idx = marlin_make_empty_g_idx(device)
+            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

         # No zero-point
         layer.weight_zp = marlin_make_empty_g_idx(device)
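
When group act-ordering is active, the checkpoint's g_idx maps each input channel to its quantization group in activation order, so the Marlin path sorts it once at load time. A minimal sketch of the sorting step, assuming marlin_sort_g_idx is argsort-based (its implementation is not shown in this diff):

import torch

def sort_g_idx_sketch(g_idx: torch.Tensor):
    # Make channels of the same group contiguous, and keep the permutation
    # so the kernel can gather activations into that order at runtime.
    sort_indices = torch.argsort(g_idx)
    return g_idx[sort_indices].contiguous(), sort_indices.to(torch.int32)

g_idx = torch.tensor([1, 0, 2, 0, 1, 2], dtype=torch.int32)
sorted_g_idx, perm = sort_g_idx_sketch(g_idx)
print(sorted_g_idx.tolist())  # [0, 0, 1, 1, 2, 2]
print(perm.tolist())          # a permutation grouping like channels, e.g. [1, 3, 0, 4, 2, 5]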
@@ -159,9 +178,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         replace_tensor(layer, "weight_packed", marlin_qweight)

         # Permute scales from compressed-tensors format to marlin format.
+        # scale is required on all partitions if activation reordering
         marlin_scales = marlin_permute_scales(
             layer.weight_scale,
-            size_k=layer.input_size_per_partition,
+            size_k=(layer.input_size
+                    if self.has_g_idx else layer.input_size_per_partition),
             size_n=layer.output_size_per_partition,
             group_size=layer.group_size)
         replace_tensor(layer, "weight_scale", marlin_scales)
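
This pairs with the marlin_repeat_scales_on_all_ranks change above: once channels are reordered, a quantization group can draw from anywhere in the K dimension, so a row-parallel rank can no longer keep only its slice of the scales. A worked example with assumed (illustrative) sizes:

# Illustrative sizes, not from this commit.
input_size, input_size_per_partition, group_size = 4096, 2048, 128

# Without act-order: each row-parallel rank holds scales for its K-slice.
print(input_size_per_partition // group_size)  # 16 scale rows per rank

# With group act-order (has_g_idx): every rank holds the full scale set,
# so marlin_permute_scales is told size_k = input_size.
print(input_size // group_size)                # 32 scale rows on every rank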
@@ -174,7 +195,7 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
             weight=layer.weight_packed,
             weight_scale=layer.weight_scale,
             weight_zp=layer.weight_zp,
-            g_idx=layer.g_idx,
+            g_idx=layer.weight_g_idx,
             g_idx_sort_indices=layer.g_idx_sort_indices,
             workspace=layer.workspace,
             wtype=self.quant_type,

vllm/model_executor/layers/quantization/compressed_tensors/utils.py (+28 −2)

@@ -1,8 +1,8 @@
 import re
 from enum import Enum
-from typing import Any, Dict, Iterable, Optional
+from typing import Any, Dict, Iterable, Optional, Union

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator
 from torch.nn import Module

 from vllm.model_executor.layers.quantization.utils.quant_utils import (

@@ -40,6 +40,19 @@ class QuantizationStrategy(str, Enum):
     TOKEN = "token"


+class ActivationOrdering(str, Enum):
+    """
+    Enum storing strategies for activation ordering
+
+    Group: reorder groups and weight\n
+    Weight: only reorder weight, not groups. Slightly lower latency and
+    accuracy compared to group actorder\n
+    """
+
+    GROUP = "group"
+    WEIGHT = "weight"
+
+
 class QuantizationArgs(BaseModel):
     """
     User facing arguments used to define a quantization config

@@ -58,6 +71,8 @@ class QuantizationArgs(BaseModel):
         observed with every sample. Defaults to False for static
         quantization. Note that enabling dynamic quantization
         will change the default observer to a memoryless one
+    :param actorder: whether to apply group quantization in decreasing order of
+        activation. Defaults to None for arbitrary ordering
     """

     num_bits: int = 8

@@ -67,6 +82,7 @@ class QuantizationArgs(BaseModel):
     strategy: Optional[QuantizationStrategy] = None
     block_structure: Optional[str] = None
     dynamic: bool = False
+    actorder: Union[ActivationOrdering, bool, None] = None
     observer: str = Field(
         default="minmax",
         description=("The class to use to compute the quantization param - "

@@ -79,6 +95,16 @@ class QuantizationArgs(BaseModel):
         "Observers constructor excluding quantization range or symmetry"),
     )

+    @field_validator("actorder", mode="before")
+    def validate_actorder(cls, value) -> Optional[ActivationOrdering]:
+        if isinstance(value, bool):
+            return ActivationOrdering.GROUP if value else None
+
+        if isinstance(value, str):
+            return ActivationOrdering(value.lower())
+
+        return value
+

 def is_activation_quantization_format(format: str) -> bool:
     _ACTIVATION_QUANTIZATION_FORMATS = [
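
Because the validator runs in "before" mode, configs that encode actorder as a boolean or as a mixed-case string all normalize to the same enum. A quick usage sketch of that behavior (assuming QuantizationArgs can be constructed from keyword args alone, as its defaults suggest):

from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    ActivationOrdering, QuantizationArgs)

# Booleans map to GROUP ordering (True) or no ordering (None).
assert QuantizationArgs(actorder=True).actorder is ActivationOrdering.GROUP
assert QuantizationArgs(actorder=False).actorder is None

# Strings are lowercased, so "Weight" and "weight" are equivalent.
assert QuantizationArgs(actorder="Weight").actorder is ActivationOrdering.WEIGHT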
