 import cuda_ext
 import json
 import math
+import gc
 from enum import Enum


class ParsedEnum(Enum):
@@ -50,7 +51,6 @@ def __init__(self, model_config_path):
         self.intermediate_size = read_config["intermediate_size"]
         self.num_attention_heads = read_config["num_attention_heads"]
         self.num_hidden_layers = read_config["num_hidden_layers"]
-        self.num_attention_heads = read_config["num_attention_heads"]
         self.rms_norm_eps = read_config["rms_norm_eps"]
         self.vocab_size = read_config["vocab_size"]

@@ -75,6 +75,7 @@ def __init__(self, model_config_path):
         self.alpha_value = 1.0  # Alpha value for NTK RoPE scaling. Similar to compress_pos_emb, higher values increase context length but add perplexity.
         self.gpu_peer_fix = False  # Apparently Torch can have problems transferring tensors directly from one GPU to another sometimes. Enable this to explicitly move tensors via system RAM instead, where needed
         self.auto_map = None  # List of floats with memory allocation in GB, per CUDA device, overrides device_map
+
         # Tuning

         self.matmul_recons_thd = 8
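A note on `alpha_value`: NTK-aware RoPE scaling works by raising the rotary embedding base rather than compressing position indices the way `compress_pos_emb` does. A minimal sketch of the commonly used formula; the function name is ours and the exact expression this codebase uses may differ:

```python
# Illustrative sketch of NTK-aware RoPE base scaling (not lifted from this
# file). Raising the base stretches the usable context window at a small
# perplexity cost; compress_pos_emb instead rescales positions linearly.
def ntk_scaled_base(head_dim: int, alpha: float, base: float = 10000.0) -> float:
    # Common form: base' = base * alpha^(d / (d - 2))
    return base * alpha ** (head_dim / (head_dim - 2))

print(ntk_scaled_base(128, 1.0))  # 10000.0 -> unscaled
print(ntk_scaled_base(128, 2.0))  # ~20222  -> roughly doubles usable context
```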
@@ -409,7 +410,7 @@ def forward(self, hidden_states, cache, buffer, lora):
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
         attn_weights /= math.sqrt(self.config.head_dim)
         if buffer.attn_mask is not None: attn_weights = attn_weights + buffer.attn_mask
-        attn_weights = nn.functional.softmax(attn_weights, dim = -1, dtype = torch.float16).to(query_states.dtype)
+        attn_weights = nn.functional.softmax(attn_weights, dim = -1, dtype = torch.float16)
         attn_output = torch.matmul(attn_weights, value_states)
         attn_output = attn_output.transpose(1, 2)

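The dropped `.to(query_states.dtype)` was a redundant cast: with the model running in fp16, `softmax(..., dtype = torch.float16)` already returns a tensor in the query dtype. A minimal standalone sketch (shapes hypothetical):

```python
import math
import torch
from torch import nn

# Hypothetical attention scores: (batch, heads, q_len, k_len)
scores = torch.randn(1, 8, 16, 16) / math.sqrt(64)

# dtype = torch.float16 makes softmax return fp16 directly, so when
# query_states is itself fp16 a trailing .to(query_states.dtype) adds nothing.
attn_weights = nn.functional.softmax(scores, dim = -1, dtype = torch.float16)
assert attn_weights.dtype == torch.float16
```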
@@ -577,7 +578,12 @@ def get_layers_devs(self):
         return sorted(list(set(self.layers)))


-    def map(self, key, loading = False):
+    def get_all_devs(self):
+
+        return sorted(list(set(self.layers + [self.lm_head, self.norm, self.embed_tokens])))
+
+
+    def map(self, key):

         if key.startswith("lm_head."): return self.lm_head
         if key.startswith("model.embed_tokens."): return self.embed_tokens
@@ -629,6 +635,14 @@ def _move_tensor(tensor, new_device, name, config):
     tensor = tensor.to("cpu")
     return tensor.to(new_device)

+def _layer_dtype_size(key):
+    if key.endswith(".weight"): return 2
+    if key.endswith(".qweight"): return 4
+    if key.endswith(".qzeros"): return 4
+    if key.endswith(".scales"): return 2
+    if key.endswith(".g_idx"): return 0
+    raise ValueError("Unrecognized layer: " + key)
+

class ExLlama:

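`_layer_dtype_size` maps a tensor key to its bytes per element in a GPTQ checkpoint: fp16 `.weight`/`.scales` entries take 2 bytes, int32-packed `.qweight`/`.qzeros` take 4, and `.g_idx` is counted as zero, presumably because it is negligible next to the packed weights. A worked example with hypothetical layer dimensions:

```python
import math

# Hypothetical 4-bit GPTQ layer, 4096 -> 4096, group size 128.
# qweight packs eight 4-bit weights per int32, so its shape is (4096 // 8, 4096).
qweight_shape = (512, 4096)
scales_shape = (32, 4096)          # one fp16 scale row per 128-row group

qweight_bytes = math.prod(qweight_shape) * 4   # _layer_dtype_size(".qweight") == 4
scales_bytes = math.prod(scales_shape) * 2     # _layer_dtype_size(".scales") == 2

print(qweight_bytes)  # 8388608 bytes == 8 MiB
print(scales_bytes)   # 262144 bytes == 256 KiB
```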
@@ -643,7 +657,7 @@ def __init__(self, config):
         # Load model weights

         tensors = {}
-        with safe_open(self.config.model_path, framework = "pt", device = "cpu") as f:
+        with safe_open(self.config.model_path, framework = "pt", device = "cpu") as f:

             # Begin auto mapping if enabled

@@ -662,16 +676,22 @@ def __init__(self, config):
                     if _skip_key(key): continue

                     if key.startswith("model.layers.0."):
-                        tensor = f.get_tensor(key)
-                        decoder_size += tensor.numel() * tensor.element_size()
+                        tensor_slice = f.get_slice(key)
+                        shape = tensor_slice.get_shape()
+                        decoder_size += math.prod(shape) * _layer_dtype_size(key)
+                        del tensor_slice

                     if key.startswith("model.norm."):
-                        tensor = f.get_tensor(key)
-                        norm_size += tensor.numel() * tensor.element_size()
+                        tensor_slice = f.get_slice(key)
+                        shape = tensor_slice.get_shape()
+                        norm_size += math.prod(shape) * _layer_dtype_size(key)
+                        del tensor_slice

                     if key.startswith("lm_head."):
-                        tensor = f.get_tensor(key)
-                        head_size += tensor.numel() * tensor.element_size()
+                        tensor_slice = f.get_slice(key)
+                        shape = tensor_slice.get_shape()
+                        head_size += math.prod(shape) * _layer_dtype_size(key)
+                        del tensor_slice

                     # Assign layers automatically
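The switch from `get_tensor` to `get_slice` means the auto-mapper no longer materializes any tensor data just to measure it: `f.get_slice(key)` returns a lazy handle whose `.get_shape()` comes from the file header alone. A self-contained sketch of the same technique (the path is hypothetical, and `bytes_per_element` inlines the `_layer_dtype_size` helper added above):

```python
import math
from safetensors import safe_open

# Bytes per element by tensor kind, mirroring _layer_dtype_size() above.
def bytes_per_element(key):
    for suffix, size in ((".qweight", 4), (".qzeros", 4), (".scales", 2), (".g_idx", 0), (".weight", 2)):
        if key.endswith(suffix): return size
    raise ValueError("Unrecognized layer: " + key)

decoder_size = 0
with safe_open("model.safetensors", framework = "pt", device = "cpu") as f:  # hypothetical path
    for key in f.keys():
        if not key.startswith("model.layers.0."): continue
        tensor_slice = f.get_slice(key)      # lazy handle; no tensor data is read
        shape = tensor_slice.get_shape()     # shape comes straight from the header
        decoder_size += math.prod(shape) * bytes_per_element(key)

print(f"estimated size of one decoder layer: {decoder_size / 1024**2:.1f} MiB")
```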
@@ -701,29 +721,47 @@ def __init__(self, config):
                     device_usage += this_layer_size
                     layer_index_device += 1

-            # Load tensors, move to device(s)
-
-            max_dq_buffer_size = 0
+        # Read tensor list from file

+        load_keys = []
+        with safe_open(self.config.model_path, framework = "pt", device = "cpu") as f:
             for key in f.keys():
+                load_keys.append(key)
+
+        # Load up to 1 GB of tensors at a time, closing and reopening the file in between each chunk
+
+        max_dq_buffer_size = 0
+        f = None
+        st_mem = 0
+        MAX_ST_MEM = 1024 ** 3
+
+        for key in load_keys:
+
+            if _skip_key(key): continue
+            device = self.config.device_map.map(key)

-                if _skip_key(key): continue
+            if f is None or st_mem > MAX_ST_MEM:
+                if f is not None: del f
+                f = safe_open(self.config.model_path, framework = "pt", device = "cpu")
+                st_mem = 0

-                device = self.config.device_map.map(key, loading = True)
-                tensor = f.get_tensor(key)
+            tensor = f.get_tensor(key)
+            size = tensor.numel() * tensor.element_size()
+            st_mem += size

-                if key.endswith(".scales"): tensor = tensor.half()
-                if key == "lm_head.weight": tensor = tensor.float() if device == "cpu" else tensor.half()
-                if key == "model.norm.weight": tensor = tensor.half()
-                if key.endswith(".embed_tokens.weight"): tensor = tensor.half()
-                if key.endswith(".input_layernorm.weight"): tensor = tensor.half()
-                if key.endswith(".post_attention_layernorm.weight"): tensor = tensor.half()
+            if key.endswith(".scales"): tensor = tensor.half()
+            if key == "lm_head.weight": tensor = tensor.float() if device == "cpu" else tensor.half()
+            if key == "model.norm.weight": tensor = tensor.half()
+            if key.endswith(".embed_tokens.weight"): tensor = tensor.half()
+            if key.endswith(".input_layernorm.weight"): tensor = tensor.half()
+            if key.endswith(".post_attention_layernorm.weight"): tensor = tensor.half()

-                tensor = tensor.to(device, non_blocking = True)
+            tensor = tensor.to(device, non_blocking = True)
+            if key.endswith(".qweight"): max_dq_buffer_size = max(max_dq_buffer_size, tensor.numel() * 8)

-                if key.endswith(".qweight"): max_dq_buffer_size = max(max_dq_buffer_size, tensor.numel() * 8)
+            tensors[key] = tensor

-                tensors[key] = tensor
+        del f

         # Head

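This hunk is the heart of the commit: instead of holding one safetensors handle open for the whole load, it reads the key list once, then pulls tensors in roughly 1 GB chunks, discarding and reopening the handle in between, presumably so memory held by the old handle can be reclaimed. (The `import gc` added at the top of the file suggests explicit collection around these drops, though its call site isn't visible in this diff.) A stripped-down sketch of the same pattern, with a hypothetical path:

```python
import torch
from safetensors import safe_open

# Sketch of the chunked-read pattern above (path hypothetical). Re-creating
# the handle after each ~1 GB read lets the previous one be released, so peak
# system RAM tracks MAX_ST_MEM rather than the full checkpoint size.
model_path = "model.safetensors"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
MAX_ST_MEM = 1024 ** 3

with safe_open(model_path, framework = "pt", device = "cpu") as f:
    load_keys = list(f.keys())

tensors = {}
f = None
st_mem = 0

for key in load_keys:
    if f is None or st_mem > MAX_ST_MEM:
        if f is not None: del f              # drop the old handle between chunks
        f = safe_open(model_path, framework = "pt", device = "cpu")
        st_mem = 0
    tensor = f.get_tensor(key)
    st_mem += tensor.numel() * tensor.element_size()
    tensors[key] = tensor.to(device, non_blocking = True)

del f
```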