
Commit 16e6435

[Decompression] Update Decompression Lifecycle (#285)
* update decompression
* update
* update
* update
* update
* remove extra ToDos
* remove breakpoints
* clean-up; PR comments
* PR comments
1 parent ed3ac7c commit 16e6435

File tree

12 files changed: +217 -62 lines changed


src/compressed_tensors/compressors/base.py

Lines changed: 6 additions & 1 deletion
@@ -19,6 +19,7 @@
 from compressed_tensors.config import SparsityCompressionConfig
 from compressed_tensors.quantization import QuantizationArgs, QuantizationConfig
 from compressed_tensors.registry import RegistryMixin
+from compressed_tensors.utils import has_offloaded_params
 from torch import Tensor
 from torch.nn import Module

@@ -169,6 +170,10 @@ def decompress_module(self, module: Module):
         :param module: PyTorch module to decompress
         :return: tensor of the decompressed weight, or None if module is not quantized
         """
+
+        params_device = next(module.parameters()).device
+        device = "cpu" if has_offloaded_params(module) else params_device
+
         if not hasattr(module, "quantization_scheme"):
             return None  # module is not quantized
         quantization_scheme = module.quantization_scheme
@@ -182,7 +187,7 @@ def decompress_module(self, module: Module):

         return self.decompress_weight(
             compressed_data=compressed_data, quantization_args=quantization_args
-        )
+        ).to(device)

     def decompress_weight(
         self, compressed_data: Dict[str, Tensor], **kwargs
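
This change makes decompress_module offload-aware: when a module's parameters are managed by offloading hooks, the decompressed weight is materialized on CPU rather than on the parameter device. A minimal sketch of that device-selection pattern in isolation; target_device_for is a hypothetical helper (not in this PR), while has_offloaded_params is the real util imported above:

    import torch
    from torch.nn import Module

    from compressed_tensors.utils import has_offloaded_params


    def target_device_for(module: Module) -> torch.device:
        # Hypothetical helper mirroring the logic added to decompress_module:
        # offloaded modules should receive decompressed weights on CPU so the
        # offloading hooks can manage placement; otherwise follow the params.
        if has_offloaded_params(module):
            return torch.device("cpu")
        return next(module.parameters()).device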

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 89 additions & 7 deletions
@@ -31,13 +31,14 @@
     SPARSITY_CONFIG_NAME,
 )
 from compressed_tensors.compressors.base import BaseCompressor
+from compressed_tensors.compressors.sparse_compressors import DenseCompressor
 from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
 from compressed_tensors.quantization import (
     DEFAULT_QUANTIZATION_METHOD,
     QuantizationConfig,
     QuantizationStatus,
     apply_quantization_config,
-    load_pretrained_quantization,
+    load_pretrained_quantization_parameters,
 )
 from compressed_tensors.quantization.lifecycle import expand_target_names
 from compressed_tensors.quantization.quant_args import QuantizationArgs
@@ -47,7 +48,9 @@
 )
 from compressed_tensors.utils import (
     get_safetensors_folder,
+    has_offloaded_params,
     merge_names,
+    register_offload_parameter,
     update_parameter_data,
 )
 from compressed_tensors.utils.helpers import (
@@ -412,6 +415,13 @@ def decompress(self, model_path: str, model: Module):

         :param model_path: path to compressed weights
         :param model: pytorch model to load decompressed weights into
+
+        Note: decompress makes use of both _replace_sparsity_weights and
+        _replace_weights. The variations in these methods are a result of the
+        subtle differences between the sparsity and quantization compressors.
+        Specifically, quantization compressors return not just the decompressed
+        weight but also the quantization parameters (e.g. scales, zero_point),
+        whereas sparsity compressors only return the decompressed weight.
+
         """
         model_path = get_safetensors_folder(model_path)
         sparse_decompressed = False
@@ -420,9 +430,16 @@ def decompress(self, model_path: str, model: Module):
             self.sparsity_compressor is not None
             and self.sparsity_config.format != CompressionFormat.dense.value
         ):
+            params_to_ignore = None
+            if self.quantization_compressor is not None:
+                params_to_ignore = self.quantization_compressor.compression_param_names
             # Sparse decompression is applied on the model_path
-            dense_gen = self.sparsity_compressor.decompress(model_path)
-            self._replace_weights(dense_gen, model)
+            # The compressor will try to load any quantization parameters as well;
+            # params_to_skip_load prevents quantization params from being loaded here
+            dense_gen = self.sparsity_compressor.decompress(
+                model_path, params_to_skip_load=params_to_ignore
+            )
+            self._replace_sparsity_weights(dense_gen, model)
             setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
             sparse_decompressed = True

@@ -431,13 +448,27 @@ def decompress(self, model_path: str, model: Module):
             # quantization during apply_quantization_config. This ensures
             # that the dtypes of the weights are not unintentionally updated.
             # The status is restored after quantization params are loaded.
+
             with override_quantization_status(
                 self.quantization_config, QuantizationStatus.FROZEN
             ):
+
                 names_to_scheme = apply_quantization_config(
                     model, self.quantization_config
                 )
-                load_pretrained_quantization(model, model_path)
+                # Load activation scales/zp, or any other quantization parameters.
+                # Conditionally load the weight quantization parameters if we have
+                # a dense compressor or if a sparsity compressor has already been applied
+                load_pretrained_quantization_parameters(
+                    model,
+                    model_path,
+                    # TODO: all weight quantization params will be moved to the
+                    # compressor in a follow-up, including initialization
+                    load_weight_quantization=(
+                        sparse_decompressed
+                        or isinstance(self.quantization_compressor, DenseCompressor)
+                    ),
+                )

             model_path_or_state_dict = (
                 model.state_dict() if sparse_decompressed else model_path
@@ -446,6 +477,8 @@ def decompress(self, model_path: str, model: Module):
             dense_gen = self.quantization_compressor.decompress(
                 model_path_or_state_dict, names_to_scheme=names_to_scheme
             )
+            # TODO: all weight quantization params will be moved to the compressor
+            # to prevent duplicate parameter updates in update_parameter_data
             self._replace_weights(dense_gen, model)

         def freeze_quantization_status(module):
@@ -501,7 +534,7 @@ def update_config(self, save_directory: str):
         with open(config_file_path, "w") as config_file:
             json.dump(config_data, config_file, indent=2, sort_keys=True)

-    def _replace_weights(self, dense_weight_generator, model: Module):
+    def _replace_sparsity_weights(self, dense_weight_generator, model: Module):
         """
         Replace the weights of the model with the
         provided dense weights.
@@ -516,11 +549,60 @@ def _replace_weights(self, dense_weight_generator, model: Module):
         :param model: The model whose weights are to be updated.
         """
         for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
+
             split_name = name.split(".")
             prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
             module = operator.attrgetter(prefix)(model)
-            if hasattr(module, param_name):
-                update_parameter_data(module, data, param_name)
+
+            params_device = next(module.parameters()).device
+            device = "cpu" if has_offloaded_params(module) else params_device
+            delattr(module, param_name)
+            requires_grad = data.dtype in (torch.float16, torch.float32, torch.bfloat16)
+            param = torch.nn.Parameter(data.to(device), requires_grad=requires_grad)
+            register_offload_parameter(module, param_name, param)
+
+    def _replace_weights(self, dense_weight_generator, model: Module):
+        """
+        Replace the weights of the model with the
+        provided dense weights.
+
+        This method iterates over the dense_weight_generator and
+        updates the corresponding weights in the model. If a parameter
+        name does not exist in the model, it will be skipped.
+
+        :param dense_weight_generator (generator): A generator that yields
+            tuples of (name, data), where 'name' is the parameter name and
+            'data' is the updated param data
+        :param model: The model whose weights are to be updated.
+        """

+        for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
+            module = operator.attrgetter(name)(model)
+
+            params_device = next(module.parameters()).device
+            device = "cpu" if has_offloaded_params(module) else params_device
+
+            for param_name, param_data in data.items():
+                if hasattr(module, param_name):
+                    # If compressed, the param will have an incorrect dtype for
+                    # transformers >4.49
+                    # TODO: we could also just skip initialization of scales/zp during
+                    # decompression in init, to be consistent with loading, which
+                    # happens later as well; however, update_data does a good shape
+                    # check - that should be moved to the compressor
+                    if param_name == "weight":
+                        delattr(module, param_name)
+                        requires_grad = param_data.dtype in (
+                            torch.float16,
+                            torch.float32,
+                            torch.bfloat16,
+                        )
+                        param = torch.nn.Parameter(
+                            param_data.to(device), requires_grad=requires_grad
+                        )
+                        register_offload_parameter(module, param_name, param)
+                    else:
+                        # Should already be registered to the correct device
+                        # for scales/zero-points
+                        update_parameter_data(module, param_data, param_name)


 def map_modules_to_quant_args(
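
Both replacement helpers share the same re-registration pattern: delete the stale parameter (which may carry the compressed dtype under transformers >4.49), rebuild it on the offload-aware device, and enable gradients only for floating-point data, since torch.nn.Parameter rejects requires_grad=True on integer tensors. A standalone sketch of that pattern under those assumptions; reregister_param is illustrative, and plain register_parameter stands in here for register_offload_parameter:

    import torch
    from torch.nn import Module, Parameter

    _FLOAT_DTYPES = (torch.float16, torch.float32, torch.bfloat16)


    def reregister_param(module: Module, name: str, data: torch.Tensor, device) -> None:
        # Drop the stale parameter rather than copying into it, so the new
        # dtype/shape from decompression wins.
        if hasattr(module, name):
            delattr(module, name)
        # Integer-packed data must not require grad; Parameter() would raise.
        requires_grad = data.dtype in _FLOAT_DTYPES
        module.register_parameter(
            name, Parameter(data.to(device), requires_grad=requires_grad)
        )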

src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 5 additions & 3 deletions
@@ -14,7 +14,7 @@

 import logging
 from pathlib import Path
-from typing import Any, Dict, Generator, Tuple, Union
+from typing import Any, Dict, Generator, Optional, Tuple, Union

 import torch
 from compressed_tensors.compressors.base import BaseCompressor
@@ -199,7 +199,8 @@ def _decompress_from_path(self, path_to_model, names_to_scheme, device):
             decompressed = self.decompress_weight(
                 compressed_data=weight_data, quantization_args=quant_args
             )
-            yield merge_names(weight_name, "weight"), decompressed
+            weight_data["weight"] = decompressed
+            yield weight_name, weight_data

     def _decompress_from_state_dict(self, state_dict, names_to_scheme):
         weight_mappings = get_nested_mappings_from_state_dict(
@@ -215,4 +216,5 @@ def _decompress_from_state_dict(self, state_dict, names_to_scheme):
             decompressed = self.decompress_weight(
                 compressed_data=weight_data, quantization_args=quant_args
             )
-            yield merge_names(weight_name, "weight"), decompressed
+            weight_data["weight"] = decompressed
+            yield weight_name, weight_data
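
The net effect is a new generator contract: instead of one (parameter name, tensor) pair per weight, each yield now carries the module name and the full weight_data dict, with the decompressed weight slotted in alongside the quantization parameters. A hedged sketch of a consumer of that shape; the function and the module/parameter names are illustrative, not part of this PR:

    from typing import Dict, Generator, Tuple

    import torch


    def consume_decompressed(
        dense_gen: Generator[Tuple[str, Dict[str, torch.Tensor]], None, None]
    ):
        # Each item now pairs a module FQN with its full parameter dict, e.g.
        #   ("model.layers.0.self_attn.q_proj",
        #    {"weight": w, "weight_scale": s, "weight_zero_point": zp})
        # instead of the old ("model.layers.0.self_attn.q_proj.weight", w) pairs.
        for module_name, weight_data in dense_gen:
            dense_weight = weight_data["weight"]       # decompressed weight tensor
            scale = weight_data.get("weight_scale")    # quantization params ride along
            yield module_name, dense_weight, scale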

src/compressed_tensors/compressors/sparse_compressors/base.py

Lines changed: 21 additions & 4 deletions
@@ -98,7 +98,11 @@ def compress(
         return compressed_dict

     def decompress(
-        self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
+        self,
+        path_to_model_or_tensors: str,
+        device: str = "cpu",
+        params_to_skip_load: Optional[Tuple] = None,
+        **kwargs,
     ) -> Generator[Tuple[str, Tensor], None, None]:
         """
         Reads a bitmask compressed state dict located
@@ -108,6 +112,11 @@ def decompress(
         :param model_path: path to compressed safetensors model (directory with
             one or more safetensors files) or compressed tensors file
         :param device: device to load decompressed weights onto
+        :param params_to_skip_load: a list of non-sparsity parameters (e.g. quantization
+            parameters) that we want to skip loading. As the sparsity compressor does
+            not handle quantized decompression, this should contain any quantization
+            parameters when decompressing stacked compressors; we want these parameters
+            to be handled by the quantization decompressor
         :return: iterator for generating decompressed weights
         """
         weight_mappings, ignored_params = get_nested_weight_mappings(
@@ -121,13 +130,21 @@ def decompress(
                 full_name = merge_names(weight_name, param_name)
                 with safe_open(safe_path, framework="pt", device=device) as f:
                     weight_data[param_name] = f.get_tensor(full_name)
+
             decompressed = self.decompress_weight(weight_data)
             yield merge_names(weight_name, "weight"), decompressed

         for ignored_param_name, safe_path in ignored_params.items():
-            with safe_open(safe_path, framework="pt", device=device) as f:
-                value = f.get_tensor(ignored_param_name)
-            yield ignored_param_name, value
+            should_skip = False
+            if params_to_skip_load is not None:
+                for param_to_skip in params_to_skip_load:
+                    if param_to_skip in ignored_param_name:
+                        should_skip = True
+
+            if not should_skip:
+                with safe_open(safe_path, framework="pt", device=device) as f:
+                    value = f.get_tensor(ignored_param_name)
+                yield ignored_param_name, value

     @staticmethod
     def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool:
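
Note that the skip test is a substring match against the fully-qualified parameter name, so a compressor-level name like "weight_scale" filters that parameter for every module. A small sketch of the filter in isolation; should_skip_load is the loop above extracted and renamed for illustration:

    from typing import Optional, Tuple


    def should_skip_load(
        param_name: str, params_to_skip_load: Optional[Tuple[str, ...]]
    ) -> bool:
        # Substring match: "weight_scale" skips
        # "model.layers.0.self_attn.q_proj.weight_scale" and every other module's
        # copy, leaving those params for the quantization decompressor to load.
        if params_to_skip_load is None:
            return False
        return any(skip in param_name for skip in params_to_skip_load)


    # e.g. should_skip_load("model.layers.0.q_proj.weight_scale",
    #                       ("weight_scale", "weight_zero_point"))  -> True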
