From a8ce4e8d7c3b8d99c12fbc2892fde1b42696238d Mon Sep 17 00:00:00 2001
From: City <125218114+city96@users.noreply.github.com>
Date: Thu, 15 Aug 2024 05:50:54 +0200
Subject: [PATCH] Initial commit

---
 .gitignore       |  4 +++
 __init__.py      |  9 +++++
 dequant.py       | 85 ++++++++++++++++++++++++++++++++++++++++++++++++
 nodes.py         | 64 ++++++++++++++++++++++++++++++++++++
 ops.py           | 78 ++++++++++++++++++++++++++++++++++++++++++++
 requirements.txt |  1 +
 6 files changed, 241 insertions(+)
 create mode 100644 __init__.py
 create mode 100644 dequant.py
 create mode 100644 nodes.py
 create mode 100644 ops.py
 create mode 100644 requirements.txt

diff --git a/.gitignore b/.gitignore
index 82f9275..a485d15 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,7 @@
+*.bin
+*.gguf
+*.safetensors
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..a03726e
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,9 @@
+# only import if running as a custom node
+try:
+    import comfy.utils
+except ImportError:
+    pass
+else:
+    from .nodes import NODE_CLASS_MAPPINGS
+    NODE_DISPLAY_NAME_MAPPINGS = {k:v.TITLE for k,v in NODE_CLASS_MAPPINGS.items()}
+    __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
diff --git a/dequant.py b/dequant.py
new file mode 100644
index 0000000..2be51b8
--- /dev/null
+++ b/dequant.py
@@ -0,0 +1,85 @@
+# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
+import gguf
+import torch
+import numpy as np
+
+def dequantize_tensor(tensor, dtype=torch.float16):
+    if tensor is None:
+        return None
+
+    data = torch.tensor(tensor.data)
+    qtype = tensor.tensor_type
+    oshape = tensor.tensor_shape
+
+    if qtype == gguf.GGMLQuantizationType.F32:
+        return data.to(dtype)
+    elif qtype == gguf.GGMLQuantizationType.F16:
+        return data.to(dtype)
+    elif qtype in dequantize_functions:
+        # this is the main pytorch op
+        return dequantize(data, qtype, oshape).to(dtype)
+    else:
+        # this is incredibly slow
+        new = gguf.quants.dequantize(data.cpu().numpy(), qtype)
+        return torch.from_numpy(new).to(data.device, dtype=dtype)
+
+def dequantize(data, qtype, oshape):
+    """
+    Dequantize tensor back to usable shape/dtype
+    """
+    block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
+    dequantize_blocks = dequantize_functions[qtype]
+
+    rows = data.reshape(
+        (-1, data.shape[-1])
+    ).view(torch.uint8)
+
+    n_blocks = rows.numel() // type_size
+    blocks = rows.reshape((n_blocks, type_size))
+    blocks = dequantize_blocks(blocks, block_size, type_size)
+    return blocks.reshape(oshape)
+
+def to_uint32(x):
+    # no uint32 :(
+    x = x.view(torch.uint8).to(torch.int32)
+    return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1)
+
+def dequantize_blocks_Q8_0(blocks, block_size, type_size):
+    d = blocks[:, :2].view(torch.float16)
+    x = blocks[:, 2:].view(torch.int8).to(torch.float16)
+    return (x * d)
+
+def dequantize_blocks_Q5_0(blocks, block_size, type_size):
+    n_blocks = blocks.shape[0]
+
+    d = blocks[:, :2]
+    qh = blocks[:, 2:6]
+    qs = blocks[:, 6: ]
+
+    d = d.view(torch.float16).to(torch.float32)
+    qh = to_uint32(qh)
+
+    qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
+    ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
+
+    qh = (qh & 1).to(torch.uint8)
+    ql = (ql & 0x0F).reshape(n_blocks, -1)
+
+    qs = (ql | (qh << 4)).to(torch.int8) - 16
+    return (d * qs)
+
+def dequantize_blocks_Q4_0(blocks, block_size, type_size):
+    n_blocks = blocks.shape[0]
+
+    d = blocks[:, :2].view(torch.float16)
+    qs = blocks[:, 2:]
+
+    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
+    qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
+    return (d * qs)
+
+dequantize_functions = {
+    gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
+    gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0,
+    gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0,
+}
diff --git a/nodes.py b/nodes.py
new file mode 100644
index 0000000..88b3a8d
--- /dev/null
+++ b/nodes.py
@@ -0,0 +1,64 @@
+# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
+import gguf
+import logging
+
+import comfy.sd
+import comfy.utils
+import comfy.model_management
+import folder_paths
+
+from .ops import GGMLTensor, GGMLOps
+
+# TODO: This causes gguf files to show up in the main unet loader
+folder_paths.folder_names_and_paths["unet"][1].add(".gguf")
+
+def gguf_sd_loader(path):
+    """
+    Read state dict as fake tensors
+    """
+    reader = gguf.GGUFReader(path)
+    sd = {}
+    dt = {}
+    for tensor in reader.tensors:
+        sd[str(tensor.name)] = GGMLTensor(tensor)
+        dt[str(tensor.tensor_type)] = dt.get(str(tensor.tensor_type), 0) + 1
+
+    # TODO: .to doesn't get passed later on for some reason, this is a hotfix
+    #sd = {k:v.to(comfy.model_management.get_torch_device()) for k,v in sd.items()}
+
+    # sanity check debug print
+    print("\nggml_sd_loader:")
+    for k,v in dt.items():
+        print(f" {k:30}{v:3}")
+    print("\n")
+    return sd
+
+class UnetLoaderGGUF:
+    @classmethod
+    def INPUT_TYPES(s):
+        unet_names = [x for x in folder_paths.get_filename_list("unet") if x.endswith(".gguf")]
+        return {
+            "required": {
+                "unet_name": (unet_names,),
+            }
+        }
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "load_unet"
+    CATEGORY = "bootleg"
+    TITLE = "Unet Loader (GGUF)"
+
+    def load_unet(self, unet_name):
+        unet_path = folder_paths.get_full_path("unet", unet_name)
+        sd = gguf_sd_loader(unet_path)
+        model = comfy.sd.load_diffusion_model_state_dict(
+            sd, model_options={"custom_operations": GGMLOps}
+        )
+        if model is None:
+            logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
+            raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
+        return (model,)
+
+NODE_CLASS_MAPPINGS = {
+    "UnetLoaderGGUF": UnetLoaderGGUF,
+}
diff --git a/ops.py b/ops.py
new file mode 100644
index 0000000..7c79952
--- /dev/null
+++ b/ops.py
@@ -0,0 +1,78 @@
+# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
+import gguf
+import torch
+import numpy as np
+
+import comfy.ops
+from .dequant import dequantize_tensor
+
+class GGMLTensor(torch.Tensor):
+    """
+    Main tensor-like class for storing quantized weights
+    """
+    def __init__(self, tensor, *args, **kwargs):
+        super().__init__()
+        self.tensor_type = tensor.tensor_type
+        self.tensor_shape = torch.Size(
+            np.flip(list(tensor.shape))
+        )
+
+    def __new__(cls, tensor, *args, **kwargs):
+        data = torch.tensor(tensor.data)
+        return super().__new__(cls, data, *args, **kwargs)
+
+    def to(self, *args, **kwargs):
+        new = super().to(*args, **kwargs)
+        new.tensor_type = self.tensor_type
+        new.tensor_shape = self.tensor_shape
+        return new
+
+    @property
+    def shape(self):
+        return self.tensor_shape
+
+class GGMLLayer(torch.nn.Module):
+    """
+    This (should) be responsible for de-quantizing on the fly
+    """
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        self.weight = None
+        self.bias = None
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+        for k,v in state_dict.items():
+            if k[len(prefix):] == "weight":
+                self.weight = v
+            elif k[len(prefix):] == "bias":
+                self.bias = v
+            else:
+                missing_keys.append(k)
+
+    def _apply(self, fn):
+        if self.weight is not None:
+            self.weight = fn(self.weight)
+        if self.bias is not None:
+            self.bias = fn(self.bias)
+        super()._apply(fn)
+        return self
+
+    def get_weights(self, dtype=torch.float16):
+        weight = dequantize_tensor(self.weight, dtype)
+        bias = dequantize_tensor(self.bias, dtype)
+        return (weight, bias)
+
+class GGMLOps(comfy.ops.disable_weight_init):
+    """
+    Dequantize weights on the fly before doing the compute
+    """
+    class Linear(GGMLLayer):
+        def __init__(self, *args, device=None, dtype=None, **kwargs):
+            super().__init__(device=device, dtype=dtype)
+            self.parameters_manual_cast = torch.float32
+
+        def forward(self, x):
+            weight, bias = self.get_weights(x.dtype)
+            x = torch.nn.functional.linear(x, weight, bias)
+            del weight, bias
+            return x
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..fdb89ac
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1 @@
+gguf>=0.9.1
\ No newline at end of file
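
A minimal sketch of the Q8_0 block layout that dequantize_blocks_Q8_0 in the patch relies on: each 34-byte block is a 2-byte float16 scale followed by 32 signed 8-bit quants, and a value dequantizes as scale * quant. The blocks below (d, q, raw) are hand-built purely for illustration and are not part of the patch; only torch is required.

import torch

# Q8_0 in GGML: 32 values packed into a 34-byte block, i.e. a 2-byte float16
# scale followed by 32 int8 quants (block_size=32, type_size=34).
block_size, type_size = 32, 34

d = torch.tensor([0.5, 0.25], dtype=torch.float16)       # per-block scales (hand-picked)
q = torch.randint(-128, 128, (2, 32), dtype=torch.int8)  # per-block quants (random)

# pack two raw blocks the way they would sit on disk: scale bytes, then quants
raw = torch.cat([d.unsqueeze(1).view(torch.uint8), q.view(torch.uint8)], dim=1)
assert raw.shape == (2, type_size)

# same math as dequantize_blocks_Q8_0: reinterpret the first 2 bytes as the
# float16 scale, the remaining 32 bytes as int8, then multiply
scale = raw[:, :2].view(torch.float16)
x = raw[:, 2:].view(torch.int8).to(torch.float16)
out = x * scale

assert torch.equal(out, q.to(torch.float16) * d.unsqueeze(1))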
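
The idea behind GGMLOps.Linear (keep the weight in its compact form, expand it to the activation dtype only inside forward, then drop the temporary) can be sketched without ComfyUI or gguf installed. The class name, the single-scale int8 packing and the dequant_fn hook below are illustrative assumptions, not the patch's API.

import torch

class DequantLinearSketch(torch.nn.Module):
    # Toy stand-in for GGMLOps.Linear: the weight stays in a compact form and
    # is only expanded to the activation dtype inside forward().
    def __init__(self, packed_weight, dequant_fn):
        super().__init__()
        self.packed_weight = packed_weight  # stays quantized between calls
        self.dequant_fn = dequant_fn

    def forward(self, x):
        weight = self.dequant_fn(self.packed_weight).to(x.dtype)
        out = torch.nn.functional.linear(x, weight)
        del weight  # free the temporary full-precision copy, as the patch does
        return out

# "quantize" a random weight with a single int8 scale, dequantize on the fly
w = torch.randn(16, 8)
scale = w.abs().max() / 127.0
packed = torch.round(w / scale).to(torch.int8)

layer = DequantLinearSketch(packed, lambda qw: qw.to(torch.float32) * scale)
print(layer(torch.randn(4, 8)).shape)  # torch.Size([4, 16])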