-
Notifications
You must be signed in to change notification settings - Fork 111
/
Copy pathbase.py
353 lines (295 loc) · 13.7 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from compressed_tensors.utils.offload import is_module_offloaded
from loguru import logger
from torch.nn import Module
from llmcompressor.core import State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.smoothquant.utils import (
get_layer_mappings_from_architecture,
handle_mapping_resolution_errors,
)
from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward
from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
from llmcompressor.utils.pytorch.module import (
get_layers,
get_matching_layer,
match_targets,
)
__all__ = ["SmoothQuantScale", "SmoothQuantMapping", "SmoothQuantModifier"]

# Floor applied to every computed smoothing scale before it is used to divide the
# smoothed layer's weights, so no channel is divided by zero or a vanishing value.
MINIMUM_SMOOTHING_SCALE = 1e-5
@dataclass
class SmoothQuantScale:
    """
    Running per-channel extrema of a smoothed layer's output, accumulated across
    calibration forward passes. A "channel" here is an index along the last
    dimension of the layer output.

    :param min_channel_vals: smallest output value observed so far for each channel
    :param max_channel_vals: largest output value observed so far for each channel
    """

    # lower envelope of the observed outputs, one entry per channel
    min_channel_vals: torch.Tensor
    # upper envelope of the observed outputs, one entry per channel
    max_channel_vals: torch.Tensor
@dataclass
class SmoothQuantMapping:
    """
    Pairs one layer whose output is smoothed with the layers consuming that output,
    which must absorb the inverse of the smoothing.

    :param smooth_name: name of the layer being smoothed; also the key under which
        its calibration statistics are stored
    :param smooth_layer: the module being smoothed
    :param balance_layers: modules that smooth_layer feeds into; their weights are
        rescaled to compensate for the smoothing applied to smooth_layer
    """

    smooth_name: str
    smooth_layer: Module
    # several consumers may share one smoothed input, hence a list
    balance_layers: List[Module]
class SmoothQuantModifier(Modifier):
    """
    Implements the SmoothQuant algorithm from https://arxiv.org/abs/2211.10438. This
    modifier performs a channel-wise smoothing of outliers in activations, making them
    easier to quantize by reducing the dynamic range. The smoothing is offset by
    applying the inverse operation to the next layer of weights, making the weights
    slightly more difficult to quantize.

    Because this modifier manipulates the weights of the model, it can only be used
    in one-shot and not during training. Activation ranges are determined by running
    a small set of calibration data through the model.

    example recipe:
     ```yaml
     SmoothQuantModifier:
       smoothing_strength: 0.5
       mappings: [
         [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*self_attn_layer_norm"],
         [["re:.*fc1"], "re:.*final_layer_norm"]
       ]
       ignore: ["model.decoder.final_layer_norm"]
     ```

    :param smoothing_strength: alpha, intensity of smoothing to perform (0-1 range)
    :param mappings: list of activation layers to smooth, and which layers to
        scale the output such that activations are smoothed.
        Each entry of the mapping list should be a list itself, in which the first
        entry is a list of layers who share the same input activation (the one to be
        smoothed) and the second entry is the layer whose output is scaled to
        achieve the smoothing. If regex is used, it matches layers with the largest
        overlap in module name. If not supplied the argument will be inferred from
        the model architecture.
    :param ignore: list of layers to ignore, even if they match a regex in mappings.
        It should match the name of layers whose outputs are scaled to achieve
        smoothing (the second entry of the mappings list).
    :param num_calibration_steps: number of samples to use for calibration, or None to
        use the whole dataset
    :param calibration_function: optional function to use for the forward pass, or None
        to use the default tensor_module_forward
    """

    smoothing_strength: float = 0.5
    mappings: Optional[List[Union[Tuple, List]]] = None
    ignore: Optional[List[str]] = None
    num_calibration_steps: Optional[int] = None
    calibration_function: Optional[Callable] = None

    # internal state, populated by on_initialize and cleared by on_finalize
    resolved_mappings_: Optional[List] = None
    scales_: Optional[Dict] = None

    def on_initialize(self, state: State, **kwargs) -> bool:
        """
        Initialize and run SmoothQuant on the given state: resolve the layer
        mappings, calibrate activation ranges, then smooth the weights in-place.

        :param state: state to run SmoothQuant on
        :return: True; failures are raised as exceptions rather than reported via
            the return value
        :raises ValueError: if start or end is set (truthy) to a value other than -1,
            since this modifier only supports one-shot application
        """
        if self.end and self.end != -1:
            raise ValueError(
                f"{self.__class__.__name__} can only be applied during one-shot. "
                f"Expected end to be None or -1, got {self.end}"
            )
        if self.start and self.start != -1:
            raise ValueError(
                f"{self.__class__.__name__} can only be applied during one-shot. "
                f"Expected start to be None or -1, got {self.start}"
            )

        self.ignore = [] if not self.ignore else self.ignore
        self.mappings = self._infer_mappings_from_model(state.model)
        self.resolved_mappings_ = self._resolve_mappings(state.model)
        self.scales_ = {}

        calibration_dataloader = state.data.calib

        self._setup_scale_hooks()
        self._calibrate(state.model, calibration_dataloader)
        self._apply_smoothing(state.model)

        return True

    def on_finalize(self, state: State, **kwargs) -> bool:
        """
        Clean up by clearing the scale and mapping data

        :param state: unused
        :return: True
        """
        if self.scales_ is not None:
            self.scales_.clear()
        if self.resolved_mappings_ is not None:
            self.resolved_mappings_.clear()

        return True

    def _infer_mappings_from_model(
        self,
        model: Module,
    ) -> List[Tuple]:
        """
        Return the user-supplied mappings if any, otherwise look up the default
        mappings for the model's class name.

        :param model: model whose class name selects the default mappings
        :return: list of (balance layer patterns, smooth layer pattern) entries
        """
        if self.mappings is not None:
            return self.mappings

        logger.info("No SmoothQuantModifier.mappings provided, inferring from model...")
        return get_layer_mappings_from_architecture(
            architecture=model.__class__.__name__
        )

    @handle_mapping_resolution_errors
    def _resolve_mappings(self, model: Module) -> List:
        """
        Transforms the list of activations to smooth and their corresponding weights
        into SmoothQuantMapping objects, resolving regular expressions.

        For each activation in the mapping list, we find the corresponding weight to
        balance by searching for the longest substring. For instance, if our balance
        weight is "re:.*q_proj" and the activation is "re:.*self_attn_layer_norm" we
        would match model.layer.0.q_proj to model.layer.0.self_attn_layer_norm and
        repeat for model.layer.1 and so on

        :param model: model to search for the mapped layers
        :return: list of SmoothQuantMapping, one per non-ignored smooth layer
        """
        resolved_mappings = []
        for to_balance, to_smooth in self.mappings:
            to_smooth_layers = get_layers(to_smooth, model)
            for layer_name, smooth_layer in to_smooth_layers.items():
                if not match_targets(layer_name, self.ignore)[0]:
                    balance_layers = []
                    for balance_suffix in to_balance:
                        # find the submodule that matches the activation layer
                        _, balance_layer = get_matching_layer(
                            balance_suffix, layer_name, model
                        )
                        if balance_layer:
                            balance_layers.append(balance_layer)
                    # each mapping can contain multiple layers to balance, but only
                    # one layer to smooth
                    mapping = SmoothQuantMapping(
                        layer_name, smooth_layer, balance_layers
                    )
                    resolved_mappings.append(mapping)
        return resolved_mappings

    def _setup_scale_hooks(self):
        """
        Attach a forward hook to each activation we want to smooth. This allows us to
        calculate the dynamic range during calibration
        """

        def create_hook_fn(layer_name):
            def hook_fn(module, inp, out):
                # update the per-channel min/max output values seen during calibration
                if isinstance(out, tuple):
                    out = out[0]

                hidden_dim = out.shape[-1]
                # flatten all leading dims so each row is one observation of every
                # channel; reshape (unlike view) also accepts non-contiguous outputs
                out = out.reshape(-1, hidden_dim)
                latest_mins = torch.min(out, dim=0)[0]
                latest_maxes = torch.max(out, dim=0)[0]

                if layer_name in self.scales_:
                    self.scales_[layer_name].min_channel_vals = torch.minimum(
                        self.scales_[layer_name].min_channel_vals, latest_mins
                    )
                    self.scales_[layer_name].max_channel_vals = torch.maximum(
                        self.scales_[layer_name].max_channel_vals, latest_maxes
                    )
                else:
                    self.scales_[layer_name] = SmoothQuantScale(
                        min_channel_vals=latest_mins, max_channel_vals=latest_maxes
                    )

            return hook_fn

        for mapping in self.resolved_mappings_:
            name = mapping.smooth_name
            layer = mapping.smooth_layer
            self.register_hook(layer, create_hook_fn(name), "forward")

    @torch.no_grad()
    def _calibrate(self, model: Module, calibration_dataloader: List):
        """
        Catch the output dynamic ranges of each layer that will be smoothed by running
        forward passes with calibration_dataloader. The hooks registered by
        _setup_scale_hooks are removed afterwards, even if calibration fails.

        :param model: model to run the calibration forward passes on
        :param calibration_dataloader: calibration data to feed through the model
        :raises ValueError: if calibration_dataloader is None or empty
        """
        # validate before calling len(), which would raise TypeError on None
        if not calibration_dataloader:
            raise ValueError(
                "Calibration data loader not set, must populate the calib_data field of"
                " CompressionSession to run the SmoothQuant modifier"
            )

        class_name = self.__class__.__name__.replace("PyTorch", "")
        logger.info(
            f"Running {class_name} calibration with "
            f"{len(calibration_dataloader)} samples..."
        )

        try:
            run_calibration_forward(
                model,
                calibration_dataloader,
                self.num_calibration_steps,
                self.calibration_function,
            )
        finally:
            # remove the hooks now that we are done calibrating; doing this in a
            # finally block keeps a failed run from leaving stale hooks attached
            self.remove_hooks()

    @torch.no_grad()
    def _apply_smoothing(self, model: Module):
        """
        After calibration, apply smoothing to the activations and push the transform
        into the following weights by applying the inverse to each balance weight.
        Y = (X diag(scales)^(-1)) (diag(scales) W) where W is the to_balance weights
        and X is the output of the to_smooth layer.

        The smooth layer's weight (and bias, if present) is divided by the scales and
        each balance layer's weight is multiplied by the scales along dim 1.
        This modifies the weights of the model in-place.

        :param model: model containing the mapped layers; used to locate an FSDP
            parent wrapper, if any
        """
        logger.info("Smoothing activation scales...")
        for mapping in self.resolved_mappings_:
            channel_stats = self.scales_[mapping.smooth_name]
            # get dynamic range for each activation channel
            activation_scales = (
                channel_stats.max_channel_vals - channel_stats.min_channel_vals
            )
            smooth_layer = mapping.smooth_layer
            balance_layers = mapping.balance_layers

            scales = self._calculate_smoothing_scales(balance_layers, activation_scales)
            # floor the scales so the division below never hits zero
            scales = torch.maximum(
                scales, torch.Tensor([MINIMUM_SMOOTHING_SCALE]).to(scales.device)
            )

            @torch.no_grad()
            def smooth(module):
                offloaded = is_module_offloaded(module)
                if offloaded:
                    module._hf_hook.pre_forward(module)

                if module in balance_layers:
                    module.weight.mul_(scales.view(1, -1))
                elif module == smooth_layer:
                    if module.weight.ndim == 1:
                        module.weight.div_(scales)
                    else:
                        module.weight.div_(scales.view(-1, 1))
                    if hasattr(module, "bias") and module.bias is not None:
                        module.bias.div_(scales)

                if offloaded:
                    module._hf_hook.post_forward(module, None)

            parent = get_fsdp_parent(mapping.smooth_name, model)
            if parent is not None:
                parent.apply(smooth)
            else:
                # if we're not running with FSDP we can apply smoothing directly
                for layer in balance_layers:
                    smooth(layer)
                smooth(smooth_layer)

            # clear out allocated smoothing scales
            torch.cuda.empty_cache()

    def _calculate_smoothing_scales(
        self, balance_layers: List[Module], activation_scales: torch.Tensor
    ) -> torch.Tensor:
        """
        Calculate how much smoothing to apply to each channel based on the dynamic
        range of the activation and the following weights

        :param balance_layers: layers to offset activation smoothing to
        :param activation_scales: channel-wise dynamic range of activations to smooth
        :return: channel-wise scales to use for smoothing activations
        """
        # get the channel-wise dynamic range for each layer to be balanced
        weight_scales = []
        for layer in balance_layers:
            offloaded = is_module_offloaded(layer)
            if offloaded:
                layer._hf_hook.pre_forward(layer)

            # max |W| over dim 0, i.e. one value per column of the weight
            scale = layer.weight.abs().max(dim=0, keepdim=True)[0]
            weight_scales.append(scale)

            if offloaded:
                layer._hf_hook.post_forward(layer, None)

        # NOTE(review): the 2.0 factor presumably converts max|W| into a symmetric
        # range (2 * max|W|) comparable to the activation range (max - min); confirm
        weight_scales = 2.0 * torch.cat(weight_scales, dim=0).max(dim=0)[0]

        # calculate the amount of smoothing to apply
        # s_j = max(|X_j|)^alpha / max(|W_j|)^(1-alpha)
        # where j is the input channel, alpha is smoothing strength
        scales = activation_scales.pow(self.smoothing_strength) / weight_scales.pow(
            1 - self.smoothing_strength
        )
        # channels whose balance weights are all zero fall back to the raw
        # activation range instead of the (inf/nan) ratio above
        scales = torch.where(weight_scales > 0.0, scales, activation_scales)
        return scales