import warnings
from collections import defaultdict
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy
import torch
from loguru import logger
from pydantic import Field, field_validator, model_validator

from llmcompressor.core import State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.utils.hooks import HooksMixin
from llmcompressor.pipelines.basic import run_pipeline as run_basic
from llmcompressor.pipelines.layer_sequential import (
    run_pipeline as run_layer_sequential,
)
from llmcompressor.pipelines.sequential import run_pipeline as run_sequential
from llmcompressor.utils.pytorch.module import (
    get_layers,
    get_no_split_params,
    get_prunable_layers,
)


class SparsityModifierMixin(HooksMixin):
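    """
    Shared logic for weight-sparsification modifiers: validates sparsity
    arguments, optionally infers layer-wise sparsities via OWL, registers
    calibration hooks on prunable modules, and dispatches calibration to the
    sequential, layer-sequential, or basic data pipeline. Concrete modifiers
    are expected to provide `calibrate_module` and the compression step.
    """
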
    # modifier arguments
    sparsity: Optional[Union[float, List[float]]] = None
    sparsity_profile: Optional[str] = None
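    # "N:M" mask structures apply semi-structured sparsity (e.g. "2:4");
    # the default "0:0" applies unstructured sparsity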
    mask_structure: str = "0:0"
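    # OWL (Outlier Weighed Layerwise sparsity) hyperparameters: `owl_m` scales
    # the outlier threshold and `owl_lmbda` bounds how far each layer's
    # sparsity may deviate from the uniform target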
    owl_m: Optional[int] = None
    owl_lmbda: Optional[float] = None

    # data pipeline arguments
    sequential_update: Optional[bool] = False  # deprecated
    sequential_targets: Union[str, List[str], None] = None
    targets: Union[str, List[str], None] = None  # alias for sequential_targets
    ignore: List[str] = Field(default_factory=list)

    @field_validator("sequential_update", mode="before")
    def validate_sequential_update(cls, value: bool) -> bool:
        if not value:
            warnings.warn(
                "`sequential_update=False` is no longer supported, setting "
                "sequential_update=True",
                DeprecationWarning,
            )

        return True

    @field_validator("sparsity_profile", mode="before")
    def validate_sparsity_profile(cls, value: Optional[str]) -> Optional[str]:
        if value is None:
            return value

        value = value.lower()

        profile_options = ["owl"]
        if value not in profile_options:
            raise ValueError(f"Please choose profile from {profile_options}")

        return value

    @model_validator(mode="after")
    def validate_model_after(model: "Modifier") -> "Modifier":
        sparsity = model.sparsity
        profile = model.sparsity_profile
        owl_m = model.owl_m
        owl_lmbda = model.owl_lmbda
        mask_structure = model.mask_structure
        targets = model.targets
        sequential_targets = model.sequential_targets

        if profile == "owl" and ((owl_m is not None) ^ (owl_lmbda is not None)):
            raise ValueError("Must provide both `owl_m` and `owl_lmbda` or neither")

        if profile != "owl" and (owl_m is not None or owl_lmbda is not None):
            raise ValueError(
                "`owl_m` and `owl_lmbda` require `sparsity_profile='owl'`"
            )

        if owl_m is not None and sparsity is not None:
            raise ValueError("Cannot provide both sparsity and owl parameters")

        if targets is not None:
            if sequential_targets is not None:
                raise ValueError("Cannot use both `targets` and `sequential_targets`")
            model.sequential_targets = targets
            model.targets = None

        model._prune_n, model._prune_m = model._split_mask_structure(mask_structure)

        return model

    def on_initialize(self, state: "State", **kwargs) -> bool:
        """
        Initialize and run the OBCQ algorithm on the current state

        :param state: session state storing input model and calibration data
        :return: True on a successful run
        """
        model = state.model
        dataloader = state.data.calib

        # infer module and sequential targets
        self.sequential_targets = self._infer_sequential_targets(model)

        # infer layer sparsities
        if self.sparsity_profile == "owl":
            logger.info(
                "Using OWL to infer target layer-wise sparsities from "
                f"{len(dataloader) if dataloader else 0} calibration samples..."
            )
            activations = self._get_activations(model, dataloader)
            self.sparsity = self._infer_owl_layer_sparsity(activations)

        # get layers and validate sparsity
        layers = get_layers(self.sequential_targets, model)
        if isinstance(self.sparsity, (list, dict)) and len(layers) != len(
            self.sparsity
        ):
            raise ValueError(
                f"{self.__repr_name__()} was initialized with {len(self.sparsity)} "
                f"sparsity values, but the model only has {len(layers)} layers"
            )

        # register hooks
        for index, (name, layer) in enumerate(layers.items()):
            if isinstance(self.sparsity, dict):
                layer_sparsity = self.sparsity[name]
            elif isinstance(self.sparsity, list):
                layer_sparsity = self.sparsity[index]
            else:
                layer_sparsity = self.sparsity

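            # hook every prunable submodule so that the concrete modifier's
            # `calibrate_module` runs on each forward pass during calibration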
            for module_name, module in get_prunable_layers(layer).items():
                self._module_names[module] = module_name
                self._module_sparsities[module] = layer_sparsity
                self.register_hook(module, self.calibrate_module, "forward")

        # infer and run pipeline
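        # preferred order: sequential (requires torch.fx tracing), then
        # layer_sequential (assumes each layer only consumes the previous
        # layer's output), then basic (no offloading, highest memory use)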
        model_name = state.model.__class__.__name__
        input_names = dataloader.dataset.column_names
        unfixable_errors = (torch.OutOfMemoryError, torch._C._LinAlgError)
        try:
            run_sequential(
                state.model,
                state.data.calib,
                self.sequential_targets,
                self.ignore,
                self,
            )
            return True

        except Exception as exception:
            if isinstance(exception, torch.fx.proxy.TraceError):
                warnings.warn(
                    f"Failed to trace {model_name} with inputs {input_names}"
                )
            if isinstance(exception, unfixable_errors):
                raise exception

            warnings.warn("Falling back to layer_sequential pipeline")
            try:
                run_layer_sequential(
                    state.model,
                    state.data.calib,
                    self.sequential_targets,
                    self,
                )
                return True

            except Exception as exception:
                if isinstance(exception, TypeError):
                    warnings.warn(f"{model_name} fails layer-wise assumptions")
                if isinstance(exception, unfixable_errors):
                    raise exception

                warnings.warn(
                    "Falling back to basic pipeline, which requires extra memory and "
                    "may result in decreased accuracy"
                )
                run_basic(state.model, state.data.calib, self)
                return True

    def _infer_sequential_targets(
        self, model: torch.nn.Module
    ) -> Union[str, List[str]]:
        if self.sequential_targets is None:
            return get_no_split_params(model)
        if isinstance(self.sequential_targets, str):
            return [self.sequential_targets]
        return self.sequential_targets

    def _infer_owl_layer_sparsity(
        self, activations: Dict[str, torch.Tensor]
    ) -> Dict[str, float]:
        # score each weight by |weight| * input activation norm, per layer
        groups = {}
        for name, layer in self.compressible_layers_.items():
            prunable_layers = get_prunable_layers(layer)
            z = [
                m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0)
                for n, m in prunable_layers.items()
            ]
            groups[name] = torch.cat([item.flatten().cpu() for item in z])

        del activations

        # outlier ratio: percentage of scores above `owl_m` times the layer mean
        outlier_ratios = {}
        for group in groups:
            threshold = torch.mean(groups[group]) * self.owl_m
            outlier_ratios[group] = (
                100 * (groups[group] > threshold).sum().item() / groups[group].numel()
            )
        outlier_ratios_arr = numpy.array([outlier_ratios[k] for k in outlier_ratios])

        # rescale outlier ratios to the range [0, 2 * owl_lmbda]
        for k in outlier_ratios:
            outlier_ratios[k] = (outlier_ratios[k] - outlier_ratios_arr.min()) * (
                1
                / (outlier_ratios_arr.max() - outlier_ratios_arr.min())
                * self.owl_lmbda
                * 2
            )
        outlier_ratios_arr = numpy.array([outlier_ratios[k] for k in outlier_ratios])

        # layers with more outliers receive lower sparsity (are pruned less),
        # centered so the mean layer sparsity matches the target sparsity
        sparsities = {
            k: 1
            - (
                outlier_ratios[k]
                - numpy.mean(outlier_ratios_arr)
                + (1 - float(self.sparsity))
            )
            for k in outlier_ratios
        }
        logger.info(f"OWL sparsities for sp={self.sparsity} are:")
        for k in sparsities:
            logger.info(f"Sparsity for {k}: {sparsities[k]}")
        return sparsities

    def _get_activations(
        self, model: torch.nn.Module, dataloader, nsamples: int = 128
    ) -> Dict[str, torch.Tensor]:
        acts = defaultdict(int)

        def save_acts(_module, input: Union[Tuple[Any, ...], torch.Tensor], name: str):
            nonlocal acts
            if isinstance(input, tuple):
                input = input[0]
            # accumulate the per-channel l2 norm of inputs, averaged over samples
            acts[name] += 1.0 / nsamples * input.pow(2).sum(dim=(0, 1)).sqrt()

        # hook every Linear layer except the LM head, then run calibration
        # forward passes with all other hooks disabled
        hooks = set(
            self.register_hook(mod, partial(save_acts, name=name), "forward_pre")
            for name, mod in model.named_modules()
            if isinstance(mod, torch.nn.Linear) and "lm_head" not in name
        )
        with HooksMixin.disable_hooks(keep=hooks):
            run_basic(model, dataloader)
        self.remove_hooks(hooks)

        return acts

    def _split_mask_structure(self, mask_structure: str) -> Tuple[int, int]:
        n, m = mask_structure.split(":")
        return int(n), int(m)
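

# Example usage (hypothetical sketch; concrete modifiers such as a SparseGPT
# implementation combine this mixin with `Modifier` and provide the
# `calibrate_module` forward hook that gathers statistics during calibration):
#
#     class ExampleSparsityModifier(Modifier, SparsityModifierMixin):
#         def calibrate_module(self, module, args, output):
#             ...  # accumulate calibration statistics for `module`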