Account for dequantization VRAM to avoid OOM
city96 committed Dec 22, 2024
1 parent eaffa0a commit 51fa2cb
Showing 2 changed files with 21 additions and 2 deletions.
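
In short: the VRAM estimate is built from the byte sizes of the tensors in the state dict, but a quantized weight also needs a temporary full-precision buffer at the moment it is dequantized for use, and that buffer is where the OOM in the title comes from. A rough sketch of the gap, with an assumed 4096x4096 layer and an assumed ~4.5 bits/weight (Q4_K-style) packing cost:

import torch

numel = 4096 * 4096                                              # illustrative layer size
stored_bytes = int(numel * 4.5 / 8)                              # ~9 MiB resident for the packed weight (assumption)
dequant_bytes = numel * torch.finfo(torch.float16).bits // 8     # ~32 MiB peak while it is dequantized to fp16

# An estimate that only counts stored_bytes misses dequant_bytes of headroom;
# this commit adds that headroom once, for the single largest quantized tensor.
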
6 changes: 6 additions & 0 deletions loader.py
@@ -3,6 +3,7 @@
import gguf

from .ops import GGMLTensor
from .dequant import is_quantized

IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "ltxv", "hyvid"}
TXT_ARCH_LIST = {"t5", "t5encoder", "llama"}
@@ -78,6 +79,11 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=Fal
state_dict[sd_key] = GGMLTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape)
qtype_dict[tensor_type_str] = qtype_dict.get(tensor_type_str, 0) + 1

# mark largest tensor for vram estimation
qsd = {k:v for k,v in state_dict.items() if is_quantized(v)}
max_key = max(qsd.keys(), key=lambda k: qsd[k].numel())
state_dict[max_key].is_largest_weight = True

# sanity check debug print
print("\nggml_sd_loader:")
for k,v in qtype_dict.items():
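A condensed, standalone sketch of the marking step above — the entries, names, and shapes here are illustrative stand-ins, not the loader's real output:

import torch

state_dict = {
    "blocks.0.attn.qkv.weight": torch.zeros(9216, 3072, device="meta"),
    "blocks.0.mlp.fc1.weight": torch.zeros(12288, 3072, device="meta"),
}

def is_quantized(t):   # stand-in for the real check imported from .dequant
    return True

qsd = {k: v for k, v in state_dict.items() if is_quantized(v)}
max_key = max(qsd.keys(), key=lambda k: qsd[k].numel())
state_dict[max_key].is_largest_weight = True
print(max_key)   # blocks.0.mlp.fc1.weight -- the flag rides along on the tensor object

Because the flag is set on the tensor itself, the GGMLLayer that later receives that tensor during state dict loading can pick it up with getattr(), which is what the ops.py change below does.
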
17 changes: 15 additions & 2 deletions ops.py
@@ -62,6 +62,7 @@ class GGMLLayer(torch.nn.Module):
comfy_cast_weights = True
dequant_dtype = None
patch_dtype = None
largest_layer = False
torch_compatible_tensor_types = {None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}

def is_ggml_quantized(self, *, weight=None, bias=None):
@@ -87,12 +88,17 @@ def ggml_load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
self.bias = torch.nn.Parameter(v, requires_grad=False)
else:
unexpected_keys.append(k)

# For Linear layer with missing weight
if self.weight is None and isinstance(self, torch.nn.Linear):
v = torch.zeros(self.in_features, self.out_features)
self.weight = torch.nn.Parameter(v, requires_grad=False)
missing_keys.append(prefix+"weight")

# for vram estimation (TODO: less fragile logic?)
if getattr(self.weight, "is_largest_weight", False):
self.largest_layer = True

def _save_to_state_dict(self, *args, **kwargs):
if self.is_ggml_quantized():
return self.ggml_save_to_state_dict(*args, **kwargs)
@@ -105,9 +111,16 @@ def ggml_save_to_state_dict(self, destination, prefix, keep_vars):
if self.bias is not None:
bias = torch.zeros_like(self.bias, device=torch.device("meta"))
destination[prefix + "bias"] = bias
return

# This would return the actual state dict
# Take into account space required for dequantizing the largest tensor
if self.largest_layer:
shape = getattr(self.weight, "tensor_shape", self.weight.shape)
dtype = self.dequant_dtype or torch.float16
temp = torch.empty(*shape, device=torch.device("meta"), dtype=dtype)
destination[prefix + "temp.weight"] = temp

return
# This would return the dequantized state dict
destination[prefix + "weight"] = self.get_weight(self.weight)
if bias is not None:
destination[prefix + "bias"] = self.get_weight(self.bias)
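Putting the ops.py pieces together: the layer whose weight carries is_largest_weight marks itself as largest_layer, and its fake state dict then contains one extra temp.weight entry on the meta device, sized for a dequantized copy (dequant_dtype, falling back to float16). A hedged sketch of what a size estimator that sums nelement() * element_size() over the state dict — roughly the spirit of ComfyUI's model size accounting — would then see; the shapes and the packed byte count are assumptions:

import torch

numel = 12288 * 3072
fake_sd = {
    # stand-in for zeros_like(self.weight) on the meta device: packed storage only (~4.5 bits/weight assumed)
    "fc1.weight": torch.zeros(int(numel * 4.5 // 8), device="meta", dtype=torch.uint8),
    # extra entry emitted only for the largest layer: room for one dequantized fp16 copy
    "fc1.temp.weight": torch.empty(12288, 3072, device="meta", dtype=torch.float16),
}

def estimate_bytes(sd):
    return sum(t.nelement() * t.element_size() for t in sd.values())

print(estimate_bytes(fake_sd) / 2**20)   # ~92 MiB with temp.weight vs ~20 MiB without it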
