Commit 617fd2a

Fix potential Safetensors memory leak and limit system RAM usage while loading model
1 parent d769533 commit 617fd2a

File tree

1 file changed: +63 -25 lines changed

model.py

Lines changed: 63 additions & 25 deletions
@@ -12,6 +12,7 @@
 import cuda_ext
 import json
 import math
+import gc
 from enum import Enum

 class ParsedEnum(Enum):
@@ -50,7 +51,6 @@ def __init__(self, model_config_path):
         self.intermediate_size = read_config["intermediate_size"]
         self.num_attention_heads = read_config["num_attention_heads"]
         self.num_hidden_layers = read_config["num_hidden_layers"]
-        self.num_attention_heads = read_config["num_attention_heads"]
         self.rms_norm_eps = read_config["rms_norm_eps"]
         self.vocab_size = read_config["vocab_size"]

@@ -75,6 +75,7 @@ def __init__(self, model_config_path):
         self.alpha_value = 1.0     # Alpha value for NTK RoPE scaling. Similar to compress_pos_emb, higher values increase ctx but add perplexity.
         self.gpu_peer_fix = False  # Apparently Torch can have problems transferring tensors directly from one GPU to another sometimes. Enable this to explicitly move tensors via system RAM instead, where needed
         self.auto_map = None       # List of floats with memory allocation in GB, per CUDA device, overrides device_map
+
         # Tuning

         self.matmul_recons_thd = 8
@@ -409,7 +410,7 @@ def forward(self, hidden_states, cache, buffer, lora):
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
         attn_weights /= math.sqrt(self.config.head_dim)
         if buffer.attn_mask is not None: attn_weights = attn_weights + buffer.attn_mask
-        attn_weights = nn.functional.softmax(attn_weights, dim = -1, dtype = torch.float16).to(query_states.dtype)
+        attn_weights = nn.functional.softmax(attn_weights, dim = -1, dtype = torch.float16)
         attn_output = torch.matmul(attn_weights, value_states)
         attn_output = attn_output.transpose(1, 2)

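Note: torch's softmax returns a tensor in the dtype passed via its dtype argument, so the .to(query_states.dtype) removed above was a redundant extra cast whenever query_states was already float16. A minimal standalone sketch of that behavior (not part of the commit):

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, dtype = torch.float16)
y = F.softmax(x, dim = -1, dtype = torch.float16)   # output is already float16
assert y.dtype == torch.float16                     # no trailing .to() cast needed
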
@@ -577,7 +578,12 @@ def get_layers_devs(self):
         return sorted(list(set(self.layers)))


-    def map(self, key, loading = False):
+    def get_all_devs(self):
+
+        return sorted(list(set(self.layers + [self.lm_head, self.norm, self.embed_tokens])))
+
+
+    def map(self, key):

         if key.startswith("lm_head."): return self.lm_head
         if key.startswith("model.embed_tokens."): return self.embed_tokens
@@ -629,6 +635,14 @@ def _move_tensor(tensor, new_device, name, config):
         tensor = tensor.to("cpu")
     return tensor.to(new_device)

+def _layer_dtype_size(key):
+    if key.endswith(".weight"): return 2
+    if key.endswith(".qweight"): return 4
+    if key.endswith(".qzeros"): return 4
+    if key.endswith(".scales"): return 2
+    if key.endswith(".g_idx"): return 0
+    raise ValueError("Unrecognized layer: " + key)
+

 class ExLlama:

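Note: _layer_dtype_size gives the per-element byte width for each kind of tensor in a GPTQ-style checkpoint: float16 weights and scales are 2 bytes, packed int32 qweight/qzeros are 4, and g_idx returns 0, so those tensors are excluded from the size estimate. A usage sketch with hypothetical key names, for illustration only:

# Hypothetical tensor names; ".qweight" does not match the ".weight" check
print(_layer_dtype_size("model.layers.0.self_attn.q_proj.qweight"))  # 4 (packed int32)
print(_layer_dtype_size("model.layers.0.self_attn.q_proj.scales"))   # 2 (float16)
print(_layer_dtype_size("model.layers.0.self_attn.q_proj.g_idx"))    # 0 (ignored in estimate)
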
@@ -643,7 +657,7 @@ def __init__(self, config):
         # Load model weights

         tensors = {}
-        with safe_open(self.config.model_path, framework="pt", device="cpu") as f:
+        with safe_open(self.config.model_path, framework = "pt", device = "cpu") as f:

             # Begin auto mapping if enabled

@@ -662,16 +676,22 @@ def __init__(self, config):
                     if _skip_key(key): continue

                     if key.startswith("model.layers.0."):
-                        tensor = f.get_tensor(key)
-                        decoder_size += tensor.numel() * tensor.element_size()
+                        tensor_slice = f.get_slice(key)
+                        shape = tensor_slice.get_shape()
+                        decoder_size += math.prod(shape) * _layer_dtype_size(key)
+                        del tensor_slice

                     if key.startswith("model.norm."):
-                        tensor = f.get_tensor(key)
-                        norm_size += tensor.numel() * tensor.element_size()
+                        tensor_slice = f.get_slice(key)
+                        shape = tensor_slice.get_shape()
+                        norm_size += math.prod(shape) * _layer_dtype_size(key)
+                        del tensor_slice

                     if key.startswith("lm_head."):
-                        tensor = f.get_tensor(key)
-                        head_size += tensor.numel() * tensor.element_size()
+                        tensor_slice = f.get_slice(key)
+                        shape = tensor_slice.get_shape()
+                        head_size += math.prod(shape) * _layer_dtype_size(key)
+                        del tensor_slice

             # Assign layers automatically

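Note: the auto-mapping pass above previously materialized every tensor with f.get_tensor just to measure it. safetensors' f.get_slice returns a lazy handle whose get_shape() only reads header metadata, so sizes can now be computed without loading tensor data into RAM. A standalone sketch of the same idea (the path argument and byte-size table are illustrative assumptions, not from the commit):

import math
from safetensors import safe_open

BYTES_PER_ELEMENT = {"qweight": 4, "qzeros": 4, "scales": 2, "weight": 2}

def estimate_file_size(path):
    # Sum estimated tensor sizes from metadata alone; no tensor data is read
    total = 0
    with safe_open(path, framework = "pt", device = "cpu") as f:
        for key in f.keys():
            shape = f.get_slice(key).get_shape()          # header metadata only
            suffix = key.rsplit(".", 1)[-1]
            total += math.prod(shape) * BYTES_PER_ELEMENT.get(suffix, 2)
    return total
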
@@ -701,29 +721,47 @@ def __init__(self, config):
                     device_usage += this_layer_size
                     layer_index_device += 1

-            # Load tensors, move to device(s)
-
-            max_dq_buffer_size = 0
+        # Read tensor list from file

+        load_keys = []
+        with safe_open(self.config.model_path, framework = "pt", device = "cpu") as f:
             for key in f.keys():
+                load_keys.append(key)
+
+        # Load up to 1 GB of tensors at a time, closing and reopening the file in between each chunk
+
+        max_dq_buffer_size = 0
+        f = None
+        st_mem = 0
+        MAX_ST_MEM = 1024**3
+
+        for key in load_keys:
+
+            if _skip_key(key): continue
+            device = self.config.device_map.map(key)

-                if _skip_key(key): continue
+            if f is None or st_mem > MAX_ST_MEM:
+                if f is not None: del f
+                f = safe_open(self.config.model_path, framework = "pt", device = "cpu")
+                st_mem = 0

-                device = self.config.device_map.map(key, loading = True)
-                tensor = f.get_tensor(key)
+            tensor = f.get_tensor(key)
+            size = tensor.numel() * tensor.element_size()
+            st_mem += size

-                if key.endswith(".scales"): tensor = tensor.half()
-                if key == "lm_head.weight": tensor = tensor.float() if device == "cpu" else tensor.half()
-                if key == "model.norm.weight": tensor = tensor.half()
-                if key.endswith(".embed_tokens.weight"): tensor = tensor.half()
-                if key.endswith(".input_layernorm.weight"): tensor = tensor.half()
-                if key.endswith(".post_attention_layernorm.weight"): tensor = tensor.half()
+            if key.endswith(".scales"): tensor = tensor.half()
+            if key == "lm_head.weight": tensor = tensor.float() if device == "cpu" else tensor.half()
+            if key == "model.norm.weight": tensor = tensor.half()
+            if key.endswith(".embed_tokens.weight"): tensor = tensor.half()
+            if key.endswith(".input_layernorm.weight"): tensor = tensor.half()
+            if key.endswith(".post_attention_layernorm.weight"): tensor = tensor.half()

-                tensor = tensor.to(device, non_blocking = True)
+            tensor = tensor.to(device, non_blocking = True)
+            if key.endswith(".qweight"): max_dq_buffer_size = max(max_dq_buffer_size, tensor.numel() * 8)

-                if key.endswith(".qweight"): max_dq_buffer_size = max(max_dq_buffer_size, tensor.numel() * 8)
+            tensors[key] = tensor

-                tensors[key] = tensor
+        del f

         # Head

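Note: the rewritten loop bounds how much of the safetensors file can stay resident in system RAM at once. It collects all keys up front, then re-opens the file whenever more than MAX_ST_MEM (1 GB) has been read through the current handle, so the old handle's buffers can be released once its tensors have been copied to their target devices. A standalone sketch of the same pattern (load_in_chunks and device_for are hypothetical names, not from the commit):

from safetensors import safe_open

MAX_ST_MEM = 1024 ** 3   # re-open the file after ~1 GB has been read

def load_in_chunks(path, keys, device_for):
    # device_for: hypothetical callable mapping a tensor name to its target device
    tensors = {}
    f = None
    st_mem = 0
    for key in keys:
        if f is None or st_mem > MAX_ST_MEM:
            if f is not None: del f               # drop the handle and its cached buffers
            f = safe_open(path, framework = "pt", device = "cpu")
            st_mem = 0
        tensor = f.get_tensor(key)
        st_mem += tensor.numel() * tensor.element_size()
        tensors[key] = tensor.to(device_for(key), non_blocking = True)
    del f
    return tensors
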