@@ -5,14 +5,18 @@
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    ActivationOrdering)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
-    marlin_permute_scales, replace_tensor, verify_marlin_supported,
+    marlin_permute_scales, marlin_repeat_scales_on_all_ranks,
+    marlin_sort_g_idx, replace_tensor, verify_marlin_supported,
     verify_marlin_supports_shape)
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            ChannelQuantScaleParameter,
                                            GroupQuantScaleParameter,
-                                           PackedvLLMParameter)
+                                           PackedvLLMParameter,
+                                           RowvLLMParameter)
 from vllm.scalar_type import scalar_types
 
 __all__ = ["CompressedTensorsWNA16"]
@@ -28,11 +32,13 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
     def __init__(self,
                  strategy: str,
                  num_bits: int,
-                 group_size: Optional[int] = None):
+                 group_size: Optional[int] = None,
+                 actorder: Optional[ActivationOrdering] = None):
 
         self.pack_factor = 32 // num_bits
         self.strategy = strategy
         self.group_size = -1 if group_size is None else group_size
+        self.has_g_idx = actorder == ActivationOrdering.GROUP
 
         if self.group_size == -1 and self.strategy != "channel":
             raise ValueError("Marlin kernels require group quantization or "
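
For context, the new actorder argument is compared against ActivationOrdering.GROUP to decide whether a g_idx tensor is needed. A minimal, hypothetical stand-in for that enum (the real definition lives in vllm.model_executor.layers.quantization.compressed_tensors.utils and may carry more members):

    from enum import Enum

    class ActivationOrdering(str, Enum):
        # Hypothetical minimal stand-in; only GROUP matters for this scheme.
        GROUP = "group"
        WEIGHT = "weight"

    # Mirrors the __init__ check above: only "group" ordering turns on g_idx.
    assert ActivationOrdering("group") is ActivationOrdering.GROUP
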
@@ -64,12 +70,10 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
         output_size_per_partition = sum(output_partition_sizes)
 
         # If group_size is -1, we are in channelwise case.
-        channelwise = (self.group_size == -1)
         group_size = self.group_size if self.group_size != -1 else input_size
         row_parallel = (input_size != input_size_per_partition)
-        # In the case of channelwise quantization, we need to replicate the
-        # scales across all gpus.
-        partition_scales = (row_parallel and not channelwise)
+        partition_scales = not marlin_repeat_scales_on_all_ranks(
+            self.has_g_idx, self.group_size, row_parallel)
 
         verify_marlin_supports_shape(
             output_size_per_partition=output_size_per_partition,
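
The helper replaces the removed comment's logic and extends it: scales stay replicated on every rank for channelwise weights under row-parallel sharding, and now also whenever group activation reordering is enabled, since the g_idx permutation spans the full input dimension. A sketch of that decision (an assumption-level illustration of what marlin_repeat_scales_on_all_ranks encodes, not its verbatim source):

    def repeat_scales_on_all_ranks(act_order: bool, group_size: int,
                                   is_row_parallel: bool) -> bool:
        # Replicate scales when activation reordering is on (every rank
        # touches every input group) or when channelwise scales meet a
        # row-parallel shard of the input dimension.
        is_channelwise = group_size == -1
        return act_order or (is_channelwise and is_row_parallel)
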
@@ -123,6 +127,16 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
         layer.register_parameter("weight_scale", weight_scale)
         layer.register_parameter("weight_shape", weight_shape)
 
+        # group index (for activation reordering)
+        if self.has_g_idx:
+            weight_g_idx = RowvLLMParameter(data=torch.empty(
+                input_size_per_partition,
+                dtype=torch.int32,
+            ),
+                                            input_dim=0,
+                                            weight_loader=weight_loader)
+            layer.register_parameter("weight_g_idx", weight_g_idx)
+
         layer.input_size_per_partition = input_size_per_partition
         layer.output_size_per_partition = output_size_per_partition
         layer.input_size = input_size
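
The group index is registered through the same vLLM parameter machinery as the other weights. A simplified, hypothetical stand-in for what RowvLLMParameter provides here (a buffer plus the dimension along which the checkpoint tensor is sharded across tensor-parallel ranks), not the real class:

    import torch

    class RowParamSketch:
        def __init__(self, data: torch.Tensor, input_dim: int, weight_loader):
            self.data = data              # preallocated int32 buffer for g_idx
            self.input_dim = input_dim    # dimension sharded across TP ranks
            self.weight_loader = weight_loader

        def load(self, ckpt: torch.Tensor, rank: int, world_size: int) -> None:
            # Copy this rank's slice of the checkpoint g_idx into the buffer.
            self.data.copy_(ckpt.chunk(world_size, dim=self.input_dim)[rank])
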
@@ -137,9 +151,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.workspace = marlin_make_workspace(
             layer.output_size_per_partition, device)
 
-        # Act-order not supported in compressed-tensors yet, so set to empty.
-        layer.g_idx = marlin_make_empty_g_idx(device)
-        layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
+        # Handle sorting for activation reordering if needed.
+        if self.has_g_idx:
+            g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.weight_g_idx)
+            layer.g_idx_sort_indices = g_idx_sort_indices
+            replace_tensor(layer, "weight_g_idx", g_idx)
+        else:
+            layer.weight_g_idx = marlin_make_empty_g_idx(device)
+            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
 
         # No zero-point
         layer.weight_zp = marlin_make_empty_g_idx(device)
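
The sorting step is what lets the Marlin kernel consume an act-ordered checkpoint: the per-column group assignments are sorted, and the permutation that sorts them is kept so the kernel can reorder activations to match. A rough sketch of the idea behind marlin_sort_g_idx (assumed behavior, not its verbatim source):

    import torch

    def sort_g_idx_sketch(g_idx: torch.Tensor):
        # Sort group assignments; keep the permutation for the kernel.
        order = torch.argsort(g_idx).to(torch.int)
        return g_idx[order], order
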
@@ -159,9 +178,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         replace_tensor(layer, "weight_packed", marlin_qweight)
 
         # Permute scales from compressed-tensors format to marlin format.
+        # Scales are required on all partitions when act-order is enabled.
         marlin_scales = marlin_permute_scales(
             layer.weight_scale,
-            size_k=layer.input_size_per_partition,
+            size_k=(layer.input_size
+                    if self.has_g_idx else layer.input_size_per_partition),
             size_n=layer.output_size_per_partition,
             group_size=layer.group_size)
         replace_tensor(layer, "weight_scale", marlin_scales)
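
The size_k switch follows from the scale-replication decision above: with group activation reordering the loaded weight_scale covers the full input dimension on every rank, so the permutation must be told the full K; otherwise it only sees this rank's shard. A small illustration (names assumed):

    def scale_rows_before_permute(input_size: int,
                                  input_size_per_partition: int,
                                  group_size: int, has_g_idx: bool) -> int:
        # Number of scale rows that marlin_permute_scales will walk; size_k
        # must match what was actually allocated at load time.
        size_k = input_size if has_g_idx else input_size_per_partition
        return size_k // group_size
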
@@ -174,7 +195,7 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
             weight=layer.weight_packed,
             weight_scale=layer.weight_scale,
             weight_zp=layer.weight_zp,
-            g_idx=layer.g_idx,
+            g_idx=layer.weight_g_idx,
             g_idx_sort_indices=layer.g_idx_sort_indices,
             workspace=layer.workspace,
             wtype=self.quant_type,
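
Putting the pieces together, a hypothetical construction of the scheme with activation reordering enabled (values are illustrative; the mapping from a checkpoint's quantization config onto these arguments is not shown):

    scheme = CompressedTensorsWNA16(strategy="group",
                                    num_bits=4,
                                    group_size=128,
                                    actorder=ActivationOrdering.GROUP)
    assert scheme.has_g_idx  # g_idx is created, sorted, and passed to Marlin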