diff --git a/loader.py b/loader.py
index 54c0bfc..7c8120e 100644
--- a/loader.py
+++ b/loader.py
@@ -3,6 +3,7 @@
 import gguf
 
 from .ops import GGMLTensor
+from .dequant import is_quantized
 
 IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "ltxv", "hyvid"}
 TXT_ARCH_LIST = {"t5", "t5encoder", "llama"}
@@ -78,6 +79,12 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=Fal
         state_dict[sd_key] = GGMLTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape)
         qtype_dict[tensor_type_str] = qtype_dict.get(tensor_type_str, 0) + 1
 
+    # mark largest tensor for vram estimation
+    qsd = {k:v for k,v in state_dict.items() if is_quantized(v)}
+    if len(qsd) > 0:
+        max_key = max(qsd.keys(), key=lambda k: qsd[k].numel())
+        state_dict[max_key].is_largest_weight = True
+
     # sanity check debug print
     print("\nggml_sd_loader:")
     for k,v in qtype_dict.items():
diff --git a/ops.py b/ops.py
index 22342a0..422ef2e 100644
--- a/ops.py
+++ b/ops.py
@@ -62,6 +62,7 @@ class GGMLLayer(torch.nn.Module):
     comfy_cast_weights = True
     dequant_dtype = None
     patch_dtype = None
+    largest_layer = False
     torch_compatible_tensor_types = {None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}
 
     def is_ggml_quantized(self, *, weight=None, bias=None):
@@ -87,12 +88,17 @@ def ggml_load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                 self.bias = torch.nn.Parameter(v, requires_grad=False)
             else:
                 unexpected_keys.append(k)
+
         # For Linear layer with missing weight
         if self.weight is None and isinstance(self, torch.nn.Linear):
             v = torch.zeros(self.in_features, self.out_features)
             self.weight = torch.nn.Parameter(v, requires_grad=False)
             missing_keys.append(prefix+"weight")
 
+        # for vram estimation (TODO: less fragile logic?)
+        if getattr(self.weight, "is_largest_weight", False):
+            self.largest_layer = True
+
     def _save_to_state_dict(self, *args, **kwargs):
         if self.is_ggml_quantized():
             return self.ggml_save_to_state_dict(*args, **kwargs)
@@ -105,9 +111,16 @@ def ggml_save_to_state_dict(self, destination, prefix, keep_vars):
         if self.bias is not None:
             bias = torch.zeros_like(self.bias, device=torch.device("meta"))
             destination[prefix + "bias"] = bias
 
-        return
-        # This would return the actual state dict
+        # Take into account space required for dequantizing the largest tensor
+        if self.largest_layer:
+            shape = getattr(self.weight, "tensor_shape", self.weight.shape)
+            dtype = self.dequant_dtype or torch.float16
+            temp = torch.empty(*shape, device=torch.device("meta"), dtype=dtype)
+            destination[prefix + "temp.weight"] = temp
+
+        return
+        # This would return the dequantized state dict
         destination[prefix + "weight"] = self.get_weight(self.weight)
         if bias is not None:
             destination[prefix + "bias"] = self.get_weight(self.bias)
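
For reference, `is_quantized` comes from the sibling `dequant` module. Below is a minimal sketch of the check it presumably performs, inferred from the `torch_compatible_tensor_types` set in ops.py above; the constant name is ours, and this is not necessarily the exact upstream implementation:

```python
import gguf

# GGML types torch can consume directly; everything else needs a dequantize step.
TORCH_COMPATIBLE_QTYPES = {None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}

def is_quantized(tensor):
    # GGMLTensor instances carry a tensor_type attribute set by the loader;
    # plain torch tensors fall back to None and count as not quantized.
    return getattr(tensor, "tensor_type", None) not in TORCH_COMPATIBLE_QTYPES
```

This is why the loader builds `qsd` before picking the largest tensor: F16/F32 tensors are used as-is, so only quantized tensors can produce the dequantization spike the `is_largest_weight` flag is meant to model.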
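To show what the meta-device fake state dict buys: estimating a module's VRAM footprint reduces to summing `numel() * element_size()` over the tensors it returns, and the `temp.weight` entry emitted for the single flagged layer adds one dequantized tensor's worth of headroom for the scratch buffer. The `estimate_module_vram` helper below is a hypothetical consumer written for illustration, not ComfyUI's actual accounting code; meta tensors carry shape and dtype but allocate no real storage, so this stays cheap even for multi-billion-parameter models:

```python
import torch

def estimate_module_vram(module: torch.nn.Module) -> int:
    # nn.Module.state_dict() routes through _save_to_state_dict, so GGML
    # layers contribute zero-filled meta tensors sized like their quantized
    # weights, plus the temp buffer for the one layer marked as largest.
    return sum(
        t.numel() * t.element_size()
        for t in module.state_dict().values()
        if isinstance(t, torch.Tensor)
    )
```

Since only one layer ever sets `largest_layer`, the estimate grows by a single dequantized tensor rather than doubling every quantized weight, reflecting that weights are dequantized one layer at a time at run time.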