    SPARSITY_CONFIG_NAME,
)
from compressed_tensors.compressors.base import BaseCompressor
+from compressed_tensors.compressors.sparse_compressors import DenseCompressor
from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
from compressed_tensors.quantization import (
    DEFAULT_QUANTIZATION_METHOD,
    QuantizationConfig,
    QuantizationStatus,
    apply_quantization_config,
-    load_pretrained_quantization,
+    load_pretrained_quantization_parameters,
)
from compressed_tensors.quantization.lifecycle import expand_target_names
from compressed_tensors.quantization.quant_args import QuantizationArgs

)
from compressed_tensors.utils import (
    get_safetensors_folder,
+    has_offloaded_params,
    merge_names,
+    register_offload_parameter,
    update_parameter_data,
)
from compressed_tensors.utils.helpers import (
@@ -412,6 +415,13 @@ def decompress(self, model_path: str, model: Module):

        :param model_path: path to compressed weights
        :param model: pytorch model to load decompressed weights into
+
+        Note: decompress makes use of both _replace_sparsity_weights and _replace_weights.
+        The variations in these methods are a result of the subtle differences between the
+        sparsity and quantization compressors. Specifically, quantization compressors return
+        not just the decompressed weight but also the quantization parameters (e.g. scales,
+        zero_point), whereas sparsity compressors only return the decompressed weight.
+
        """
        model_path = get_safetensors_folder(model_path)
        sparse_decompressed = False
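The note above contrasts the two generator shapes the decompression paths consume. A minimal sketch of that difference, assuming illustrative module/parameter names and shapes (not values produced by the library):

```python
import torch

def sparsity_style_generator():
    # Sparsity compressors: one decompressed tensor per fully qualified param name,
    # consumed by _replace_sparsity_weights.
    yield "model.layers.0.self_attn.q_proj.weight", torch.zeros(4096, 4096)

def quantization_style_generator():
    # Quantization compressors: a per-module dict holding the weight plus its
    # quantization parameters (e.g. scales, zero points), consumed by _replace_weights.
    yield "model.layers.0.self_attn.q_proj", {
        "weight": torch.zeros(4096, 4096),
        "weight_scale": torch.ones(1),
        "weight_zero_point": torch.zeros(1, dtype=torch.int8),
    }
```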
@@ -420,9 +430,16 @@ def decompress(self, model_path: str, model: Module):
            self.sparsity_compressor is not None
            and self.sparsity_config.format != CompressionFormat.dense.value
        ):
+            params_to_ignore = None
+            if self.quantization_compressor is not None:
+                params_to_ignore = self.quantization_compressor.compression_param_names
            # Sparse decompression is applied on the model_path
-            dense_gen = self.sparsity_compressor.decompress(model_path)
-            self._replace_weights(dense_gen, model)
+            # The compressor will also attempt to load quantization parameters;
+            # params_to_skip_load prevents those from being loaded here
+            dense_gen = self.sparsity_compressor.decompress(
+                model_path, params_to_skip_load=params_to_ignore
+            )
+            self._replace_sparsity_weights(dense_gen, model)
            setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
            sparse_decompressed = True

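A hedged sketch of the filtering that params_to_skip_load is expected to perform inside the sparsity compressor's decompress; the helper name and iteration style below are assumptions, only the keyword argument itself appears in this diff:

```python
def iter_dense_params(named_tensors, params_to_skip_load=None):
    # Skip quantization parameters (e.g. "weight_scale", "weight_zero_point")
    # so they are loaded once, later, by the quantization path in decompress().
    skip = set(params_to_skip_load or ())
    for name, tensor in named_tensors:
        if name.rsplit(".", 1)[-1] in skip:
            continue
        yield name, tensor
```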
@@ -431,13 +448,27 @@ def decompress(self, model_path: str, model: Module):
            # quantization during apply_quantization_config. This ensures
            # that the dtypes of the weights are not unintentionally updated.
            # The status is restored after quantization params are loaded.
+
            with override_quantization_status(
                self.quantization_config, QuantizationStatus.FROZEN
            ):
+
                names_to_scheme = apply_quantization_config(
                    model, self.quantization_config
                )
-                load_pretrained_quantization(model, model_path)
+                # Load activation scales/zp and any other quantization parameters;
+                # conditionally load the weight quantization parameters if we have a
+                # dense compressor or if a sparsity compressor has already been applied
+                load_pretrained_quantization_parameters(
+                    model,
+                    model_path,
+                    # TODO: all weight quantization params will be moved to the
+                    # compressor in a follow-up, including initialization
+                    load_weight_quantization=(
+                        sparse_decompressed
+                        or isinstance(self.quantization_compressor, DenseCompressor)
+                    ),
+                )

            model_path_or_state_dict = (
                model.state_dict() if sparse_decompressed else model_path
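Read as a predicate, the load_weight_quantization argument above loads weight scales/zero-points from disk only when the quantization compressor will not recreate them during decompress(). A small restatement under that reading (the function name is ours, not the library's):

```python
from compressed_tensors.compressors.sparse_compressors import DenseCompressor

def should_load_weight_qparams(sparse_decompressed: bool, quantization_compressor) -> bool:
    # A DenseCompressor has nothing packed to decompress, and after a sparse
    # decompression pass the remaining quantization params must be read directly;
    # in both cases the weight quantization parameters are loaded here.
    return sparse_decompressed or isinstance(quantization_compressor, DenseCompressor)
```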
@@ -446,6 +477,8 @@ def decompress(self, model_path: str, model: Module):
            dense_gen = self.quantization_compressor.decompress(
                model_path_or_state_dict, names_to_scheme=names_to_scheme
            )
+            # TODO: all weight quantization params will be moved to the compressor
+            # to prevent duplicate parameter updates in update_parameter_data
            self._replace_weights(dense_gen, model)

            def freeze_quantization_status(module):
@@ -501,7 +534,7 @@ def update_config(self, save_directory: str):
        with open(config_file_path, "w") as config_file:
            json.dump(config_data, config_file, indent=2, sort_keys=True)

-    def _replace_weights(self, dense_weight_generator, model: Module):
+    def _replace_sparsity_weights(self, dense_weight_generator, model: Module):
        """
        Replace the weights of the model with the
        provided dense weights.
@@ -516,11 +549,60 @@ def _replace_weights(self, dense_weight_generator, model: Module):
        :param model: The model whose weights are to be updated.
        """
        for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
+
            split_name = name.split(".")
            prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
            module = operator.attrgetter(prefix)(model)
-            if hasattr(module, param_name):
-                update_parameter_data(module, data, param_name)
+
+            params_device = next(module.parameters()).device
+            device = "cpu" if has_offloaded_params(module) else params_device
+            delattr(module, param_name)
+            requires_grad = data.dtype in (torch.float16, torch.float32, torch.bfloat16)
+            param = torch.nn.Parameter(data.to(device), requires_grad=requires_grad)
+            register_offload_parameter(module, param_name, param)
+
+    def _replace_weights(self, dense_weight_generator, model: Module):
+        """
+        Replace the weights of the model with the
+        provided dense weights.
+
+        This method iterates over the dense_weight_generator and
+        updates the corresponding weights in the model. If a parameter
+        name does not exist in the model, it will be skipped.
+
+        :param dense_weight_generator (generator): A generator that yields
+            tuples of (name, data), where 'name' is the module name and 'data'
+            is a dict mapping parameter names to their updated data
+        :param model: The model whose weights are to be updated.
+        """
+
+        for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
+            module = operator.attrgetter(name)(model)
+
+            params_device = next(module.parameters()).device
+            device = "cpu" if has_offloaded_params(module) else params_device
+
+            for param_name, param_data in data.items():
+                if hasattr(module, param_name):
+                    # If compressed, the weight will have an incorrect dtype for transformers > 4.49
+                    # TODO: we could also skip initialization of scales/zp during decompression
+                    # to be consistent with the loading that happens later as well;
+                    # however, update_parameter_data does a useful shape check - it should be moved to the compressor
+                    if param_name == "weight":
+                        delattr(module, param_name)
+                        requires_grad = param_data.dtype in (
+                            torch.float16,
+                            torch.float32,
+                            torch.bfloat16,
+                        )
+                        param = torch.nn.Parameter(
+                            param_data.to(device), requires_grad=requires_grad
+                        )
+                        register_offload_parameter(module, param_name, param)
+                    else:
+                        # Scales/zero-points should already be registered to
+                        # the correct device
+                        update_parameter_data(module, param_data, param_name)


def map_modules_to_quant_args(
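Both replacement methods rebuild the weight parameter rather than copying into the existing one, because the decompressed tensor's dtype (and potentially shape) can differ from the placeholder initialized at load time. A minimal sketch of that pattern in plain torch, assuming a module that already has parameters; the diff itself uses has_offloaded_params / register_offload_parameter from compressed_tensors.utils to also handle accelerate-offloaded modules:

```python
import torch

def rebuild_param(module: torch.nn.Module, param_name: str, data: torch.Tensor) -> None:
    # Place the new parameter on the device of the module's existing parameters.
    device = next(module.parameters()).device
    if hasattr(module, param_name):
        delattr(module, param_name)
    # Only floating point dtypes remain trainable after decompression.
    requires_grad = data.dtype in (torch.float16, torch.float32, torch.bfloat16)
    module.register_parameter(
        param_name,
        torch.nn.Parameter(data.to(device), requires_grad=requires_grad),
    )
```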