
Commit f807a2a

kylesayrs and dsikka authored
Replace LayerCompressor with HooksMixin (#1038)
## Purpose ##

* Remove layer compressor to decouple modifiers from data pipelines (a sketch of the hook-based pattern follows this description)
* Reduce abstractions
* Support VLMs with SparseGPT and Wanda

## Prerequisites ##

* #1021
* #1023
* #1068
* #1030

## Changes ##

### Interface / Features ###

* SparseGPT and Wanda now both support VLM architectures
* Added `sequential_targets` to match GPTQ and made `targets` an alias
* Support hessian offloading for `SparseGPT`
* Add customized `_LinAlgError` for `SparseGPT`

### Implementations ###

* Changed the implementation styles of `SparseGPTModifier` and `WandaPruningModifier` to match `GPTQModifier`
* Removed `LayerCompressor`, `ModuleCompressionWrapper`, `SparseGptWrapper`, and `WandaWrapper`
* Shared implementations between SparseGPT and Wanda now live in the `SparsityModifierMixin`
* Removed lines blocking `allow_tf32`
  * Maybe @rahul-tuli knows why this was originally implemented, potentially to avoid hardware issues?
  * This change was only present for Wanda. Given that no other modifier does this, I see no reason why it should stay
* Updated sparsegpt tests to reflect the new implementation

### Tests ###

* Updated obcq tests to reflect the new implementations
* Removed `test_sgpt_defaults.py`, since it doesn't test anything new or novel about this modifier

## Testing ##

* `grep -r "LayerCompressor\|ModuleCompressionWrapper\|SparseGptWrapper\|WandaWrapper" src/ examples/ tests/` confirms no remaining references
* The modified `test_invalid_layerwise_recipes_raise_exceptions` and `test_successful_layerwise_recipe` pass
* `llama3_8b_2of4.py` passes and was evaluated with both SparseGPT and Wanda

## Potential Follow-ups ##

* Add module `targets` and `ignore` to SparseGPT and Wanda

## Regression Testing ##

The hessian, row scalar, and compressed weight values were confirmed to be unchanged in the case of a single calibration sample. The final evaluations differ, likely due to numerical imprecision (dividing by int vs torch.int) and different pipelines (different subgraph partitions => different imprecision from CPU offloading, potentially different module arguments).

### Evaluation ###

Models were compressed using `examples/sparse_2of4_quantization_fp8/llama3_8b_2of4.py`.

<details><summary>sparsegpt</summary>

Main

```
hf (pretrained=/home/ksayers/llm-compressor/old_Llama-3.2-1B-Instruct2of4-sparse,dtype=bfloat16,add_bos_token=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: 1
|  Tasks   |Version|Filter|n-shot|Metric|   |Value |   |Stderr|
|----------|------:|------|-----:|------|---|-----:|---|-----:|
|winogrande|      1|none  |     5|acc   |↑  |0.5391|±  | 0.014|
```

Branch

```
hf (pretrained=/home/ksayers/llm-compressor/new_Llama-3.2-1B-Instruct2of4-sparse,dtype=bfloat16,add_bos_token=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: 1
|  Tasks   |Version|Filter|n-shot|Metric|   |Value|   |Stderr|
|----------|------:|------|-----:|------|---|----:|---|-----:|
|winogrande|      1|none  |     5|acc   |↑  |0.547|±  | 0.014|
```

</details>

To test Wanda, the `SparseGPTModifier` was replaced with the `WandaPruningModifier`.

<details><summary>wanda</summary>

Main

```
hf (pretrained=/home/kyle/old_llm-compressor/Llama-3.2-1B-Instruct2of4-sparse,dtype=bfloat16,add_bos_token=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: 1
|  Tasks   |Version|Filter|n-shot|Metric|   |Value|   |Stderr|
|----------|------:|------|-----:|------|---|----:|---|-----:|
|winogrande|      1|none  |     5|acc   |↑  |0.532|±  | 0.014|
```

Branch

```
hf (pretrained=/home/kyle/llm-compressor/Llama-3.2-1B-Instruct2of4-sparse,dtype=bfloat16,add_bos_token=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: 1
|  Tasks   |Version|Filter|n-shot|Metric|   |Value |   |Stderr|
|----------|------:|------|-----:|------|---|-----:|---|-----:|
|winogrande|      1|none  |     5|acc   |↑  |0.5414|±  | 0.014|
```

</details>

---------

Signed-off-by: Kyle Sayers <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
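For readers unfamiliar with the hook-based style referenced above, here is a minimal, self-contained sketch of the pattern using plain PyTorch hooks rather than the actual `HooksMixin` API; the toy model and the `make_calibration_hook`/`stats` names are illustrative only. The idea is that, instead of wrapping each layer in a `LayerCompressor`, the modifier attaches forward hooks to prunable modules, so whichever data pipeline runs the model also feeds calibration.

```python
import torch

# toy stand-in for a transformer; real targets are the model's prunable Linears
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
stats: dict[str, float] = {}


def make_calibration_hook(name: str):
    def hook(module, args, output):
        # accumulate an input statistic for this module during calibration
        inp = args[0]
        stats[name] = stats.get(name, 0.0) + inp.norm(dim=0).mean().item()

    return hook


handles = [
    module.register_forward_hook(make_calibration_hook(name))
    for name, module in model.named_modules()
    if isinstance(module, torch.nn.Linear)
]

with torch.no_grad():
    model(torch.randn(4, 8))  # stand-in for one calibration batch

for handle in handles:
    handle.remove()  # analogous to HooksMixin.remove_hooks
print(stats)
```

This is what decouples the modifiers from the pipelines: SparseGPT and Wanda no longer own the iteration order, so the sequential, layer-sequential, and basic pipelines can all drive the same hooks.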
1 parent 1094c38 · commit f807a2a

File tree: 22 files changed (+869, -1287 lines)


src/llmcompressor/modifiers/obcq/base.py (+115, -274)
@@ -0,0 +1,254 @@

```python
import warnings
from collections import defaultdict
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy
import torch
from loguru import logger
from pydantic import Field, field_validator, model_validator

from llmcompressor.core import State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.utils.hooks import HooksMixin
from llmcompressor.pipelines.basic import run_pipeline as run_basic
from llmcompressor.pipelines.layer_sequential import (
    run_pipeline as run_layer_sequential,
)
from llmcompressor.pipelines.sequential import run_pipeline as run_sequential
from llmcompressor.utils.pytorch.module import (
    get_layers,
    get_no_split_params,
    get_prunable_layers,
)


class SparsityModifierMixin(HooksMixin):
    # modifier arguments
    sparsity: Optional[Union[float, List[float]]] = None
    sparsity_profile: Optional[str] = None
    mask_structure: str = "0:0"
    owl_m: Optional[int] = None
    owl_lmbda: Optional[float] = None

    # data pipeline arguments
    sequential_update: Optional[bool] = False  # deprecated
    sequential_targets: Union[str, List[str], None] = None
    targets: Union[str, List[str], None] = None  # alias for sequential_targets
    ignore: List[str] = Field(default_factory=list)

    @field_validator("sequential_update", mode="before")
    def validate_sequential_update(cls, value: bool) -> bool:
        if not value:
            warnings.warn(
                "`sequential_update=False` is no longer supported, setting "
                "sequential_update=True",
                DeprecationWarning,
            )

        return True

    @field_validator("sparsity_profile", mode="before")
    def validate_sparsity_profile(cls, value: Optional[str]) -> Optional[str]:
        if value is None:
            return value

        value = value.lower()

        profile_options = ["owl"]
        if value not in profile_options:
            raise ValueError(f"Please choose profile from {profile_options}")

        return value

    @model_validator(mode="after")
    def validate_model_after(model: "Modifier") -> "Modifier":
        sparsity = model.sparsity
        profile = model.sparsity_profile
        owl_m = model.owl_m
        owl_lmbda = model.owl_lmbda
        mask_structure = model.mask_structure
        targets = model.targets
        sequential_targets = model.sequential_targets

        if profile == "owl" and ((owl_m is not None) ^ (owl_lmbda is not None)):
            raise ValueError("Must provide both `owl_m` and `owl_lmbda` or neither")

        if profile != "owl" and (owl_m is not None or owl_lmbda is not None):
            raise ValueError(
                "Must set `sparsity_profile='owl'` to use `owl_m` and `owl_lmbda`"
            )

        if owl_m is not None and sparsity is not None:
            raise ValueError("Cannot provide both sparsity and owl parameters")

        if targets is not None:
            if sequential_targets is not None:
                raise ValueError("Cannot use both `targets` and `sequential_targets`")
            model.sequential_targets = targets
            model.targets = None

        model._prune_n, model._prune_m = model._split_mask_structure(mask_structure)

        return model

    def on_initialize(self, state: "State", **kwargs) -> bool:
        """
        Initialize and run the OBCQ algorithm on the current state

        :param state: session state storing input model and calibration data
        """
        model = state.model
        dataloader = state.data.calib

        # infer module and sequential targets
        self.sequential_targets = self._infer_sequential_targets(model)

        # infer layer sparsities
        if self.sparsity_profile == "owl":
            logger.info(
                "Using OWL to infer target layer-wise sparsities from "
                f"{len(dataloader) if dataloader else 0} calibration samples..."
            )
            activations = self._get_activations(model, dataloader)
            self.sparsity = self._infer_owl_layer_sparsity(activations)

        # get layers and validate sparsity
        layers = get_layers(self.sequential_targets, model)
        if isinstance(self.sparsity, (list, dict)) and len(layers) != len(
            self.sparsity
        ):
            raise ValueError(
                f"{self.__repr_name__()} was initialized with {len(self.sparsity)} "
                f"sparsity values, but model only has {len(layers)} layers"
            )

        # register a calibration hook on each prunable module
        for index, (layer_name, layer) in enumerate(layers.items()):
            if isinstance(self.sparsity, dict):
                layer_sparsity = self.sparsity[layer_name]
            elif isinstance(self.sparsity, list):
                layer_sparsity = self.sparsity[index]
            else:
                layer_sparsity = self.sparsity

            for name, module in get_prunable_layers(layer).items():
                self._module_names[module] = name
                self._module_sparsities[module] = layer_sparsity
                self.register_hook(module, self.calibrate_module, "forward")

        # infer and run pipeline, falling back to simpler pipelines on failure
        model_name = state.model.__class__.__name__
        input_names = dataloader.dataset.column_names
        unfixable_errors = (torch.OutOfMemoryError, torch._C._LinAlgError)
        try:
            run_sequential(
                state.model,
                state.data.calib,
                self.sequential_targets,
                self.ignore,
                self,
            )
            return True

        except Exception as exception:
            if isinstance(exception, torch.fx.proxy.TraceError):
                warnings.warn(f"Failed to trace {model_name} with inputs {input_names}")
            if isinstance(exception, unfixable_errors):
                raise exception

        warnings.warn("Falling back to layer_sequential pipeline")
        try:
            run_layer_sequential(
                state.model,
                state.data.calib,
                self.sequential_targets,
                self,
            )
            return True

        except Exception as exception:
            if isinstance(exception, TypeError):
                warnings.warn(f"{model_name} fails layer-wise assumptions")
            if isinstance(exception, unfixable_errors):
                raise exception

        warnings.warn(
            "Falling back to basic pipeline, which requires extra memory and "
            "may result in decreased accuracy"
        )
        run_basic(state.model, state.data.calib, self)
        return True

    def _infer_sequential_targets(
        self, model: torch.nn.Module
    ) -> Union[str, List[str]]:
        if self.sequential_targets is None:
            return get_no_split_params(model)
        if isinstance(self.sequential_targets, str):
            return [self.sequential_targets]
        return self.sequential_targets

    def _infer_owl_layer_sparsity(self, activations):
        # score each weight by |W| * input activation norm, then measure the
        # fraction of outliers (scores above owl_m * mean) per layer
        groups = {}
        for name, layer in self.compressible_layers_.items():
            prunable_layers = get_prunable_layers(layer)
            z = [
                m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0)
                for n, m in prunable_layers.items()
            ]
            groups[name] = torch.cat([item.flatten().cpu() for item in z])

        del activations

        outlier_ratios = {}
        for group in groups:
            threshold = torch.mean(groups[group]) * self.owl_m
            outlier_ratios[group] = (
                100 * (groups[group] > threshold).sum().item() / groups[group].numel()
            )
        outlier_ratios_arr = numpy.array([outlier_ratios[k] for k in outlier_ratios])

        # rescale outlier ratios into [0, 2 * owl_lmbda]
        for k in outlier_ratios:
            outlier_ratios[k] = (outlier_ratios[k] - outlier_ratios_arr.min()) * (
                1
                / (outlier_ratios_arr.max() - outlier_ratios_arr.min())
                * self.owl_lmbda
                * 2
            )
        outlier_ratios_arr = numpy.array([outlier_ratios[k] for k in outlier_ratios])

        # layers with more outliers receive lower sparsity
        sparsities = {
            k: 1
            - (
                outlier_ratios[k]
                - numpy.mean(outlier_ratios_arr)
                + (1 - float(self.sparsity))
            )
            for k in outlier_ratios
        }
        logger.info(f"OWL sparsities for sp={self.sparsity} are:")
        for k in sparsities:
            logger.info(f"Sparsity for {k}: {sparsities[k]}")
        return sparsities

    def _get_activations(
        self, model, dataloader, nsamples=128
    ) -> Dict[str, torch.Tensor]:
        # accumulate per-channel input norms for every non-head Linear using
        # temporary forward-pre hooks
        acts = defaultdict(int)

        def save_acts(_module, input: Union[Tuple[Any, ...], torch.Tensor], name: str):
            nonlocal acts
            if isinstance(input, tuple):
                input = input[0]
            acts[name] += 1.0 / nsamples * input.pow(2).sum(dim=(0, 1)).sqrt()

        hooks = set(
            self.register_hook(mod, partial(save_acts, name=name), "forward_pre")
            for name, mod in model.named_modules()
            if isinstance(mod, torch.nn.Linear) and "lm_head" not in name
        )
        with HooksMixin.disable_hooks(keep=hooks):
            run_basic(model, dataloader)
        self.remove_hooks(hooks)

        return acts

    def _split_mask_structure(self, mask_structure: str) -> Tuple[int, int]:
        # "N:M" means N nonzero weights per group of M (e.g. "2:4");
        # "0:0" denotes unstructured sparsity
        n, m = mask_structure.split(":")
        return int(n), int(m)
```
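For context, a hedged sketch of how the reworked modifier might be invoked after this change. The `oneshot` entrypoint, import paths, and registered dataset name follow common llm-compressor conventions but may differ by version; the model id and parameters are placeholders rather than the PR's exact configuration, which lives in `examples/sparse_2of4_quantization_fp8/llama3_8b_2of4.py`.

```python
# Hypothetical usage sketch, not the PR's test script: applies SparseGPT with
# 2:4 sparsity using the new `sequential_targets` argument added in this PR.
from llmcompressor.modifiers.obcq import SparseGPTModifier
from llmcompressor.transformers import oneshot

recipe = SparseGPTModifier(
    sparsity=0.5,
    mask_structure="2:4",  # N:M semi-structured sparsity
    sequential_targets=["LlamaDecoderLayer"],  # `targets` remains as an alias
    ignore=["lm_head"],
)

oneshot(
    model="meta-llama/Llama-3.2-1B-Instruct",  # placeholder model id
    dataset="open_platypus",  # placeholder registered dataset
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
)
```

Because calibration now happens through hooks, the same recipe works whether the sequential, layer-sequential, or basic pipeline ends up driving the forward passes.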
