Initial commit

city96 committed Aug 15, 2024
1 parent 7368f2d commit a8ce4e8
Showing 6 changed files with 241 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -1,3 +1,7 @@
*.bin
*.gguf
*.safetensors

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
9 changes: 9 additions & 0 deletions __init__.py
@@ -0,0 +1,9 @@
# only import if running as a custom node
try:
    import comfy.utils
except ImportError:
    pass
else:
    from .nodes import NODE_CLASS_MAPPINGS
    NODE_DISPLAY_NAME_MAPPINGS = {k:v.TITLE for k,v in NODE_CLASS_MAPPINGS.items()}
    __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
85 changes: 85 additions & 0 deletions dequant.py
@@ -0,0 +1,85 @@
# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
import gguf
import torch
import numpy as np

def dequantize_tensor(tensor, dtype=torch.float16):
    if tensor is None:
        return None

    data = torch.tensor(tensor.data)
    qtype = tensor.tensor_type
    oshape = tensor.tensor_shape

    if qtype == gguf.GGMLQuantizationType.F32:
        return data.to(dtype)
    elif qtype == gguf.GGMLQuantizationType.F16:
        return data.to(dtype)
    elif qtype in dequantize_functions:
        # this is the main pytorch op
        return dequantize(data, qtype, oshape).to(dtype)
    else:
        # this is incredibly slow
        new = gguf.quants.dequantize(data.cpu().numpy(), qtype)
        return torch.from_numpy(new).to(data.device, dtype=dtype)

def dequantize(data, qtype, oshape):
    """
    Dequantize tensor back to usable shape/dtype
    """
    block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
    dequantize_blocks = dequantize_functions[qtype]

    rows = data.reshape(
        (-1, data.shape[-1])
    ).view(torch.uint8)

    n_blocks = rows.numel() // type_size
    blocks = rows.reshape((n_blocks, type_size))
    blocks = dequantize_blocks(blocks, block_size, type_size)
    return blocks.reshape(oshape)

def to_uint32(x):
    # no uint32 :(
    x = x.view(torch.uint8).to(torch.int32)
    return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1)

def dequantize_blocks_Q8_0(blocks, block_size, type_size):
    # block layout: 2-byte float16 scale, then 32 int8 quants
    d = blocks[:, :2].view(torch.float16)
    x = blocks[:, 2:].view(torch.int8).to(torch.float16)
    return (x * d)

def dequantize_blocks_Q5_0(blocks, block_size, type_size):
    # block layout: 2-byte float16 scale, 4 bytes of packed high bits,
    # 16 bytes of low nibbles (32 x 5-bit quants total)
    n_blocks = blocks.shape[0]

    d = blocks[:, :2]
    qh = blocks[:, 2:6]
    qs = blocks[:, 6: ]

    d = d.view(torch.float16).to(torch.float32)
    qh = to_uint32(qh)

    qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
    ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)

    qh = (qh & 1).to(torch.uint8)
    ql = (ql & 0x0F).reshape(n_blocks, -1)

    qs = (ql | (qh << 4)).to(torch.int8) - 16
    return (d * qs)

def dequantize_blocks_Q4_0(blocks, block_size, type_size):
    # block layout: 2-byte float16 scale, 16 bytes packing 32 4-bit quants
    # (low nibbles first, then high nibbles)
    n_blocks = blocks.shape[0]

    d = blocks[:, :2].view(torch.float16)
    qs = blocks[:, 2:]

    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
    return (d * qs)

dequantize_functions = {
    gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
    gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0,
    gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0,
}
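
For reference, Q8_0 packs each block of 32 weights as a 2-byte float16 scale followed by 32 int8 quants, 34 bytes per block (the pair returned by gguf.GGML_QUANT_SIZES). A minimal round-trip sketch against the kernel above, using a hand-built block with arbitrary values and assuming it is run from the repo directory:

# Sanity-check sketch: hand-build one Q8_0 block and confirm the kernel reverses it.
import torch
from dequant import dequantize_blocks_Q8_0

d = torch.tensor([0.05], dtype=torch.float16)           # per-block scale
q = torch.randint(-127, 128, (32,), dtype=torch.int8)   # 32 quantized weights
block = torch.cat([d.view(torch.uint8), q.view(torch.uint8)]).unsqueeze(0)

out = dequantize_blocks_Q8_0(block, 32, 34)  # block_size=32, type_size=34
assert torch.allclose(out, q.to(torch.float16) * d)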
64 changes: 64 additions & 0 deletions nodes.py
@@ -0,0 +1,64 @@
# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
import gguf
import logging

import comfy.sd
import comfy.utils
import comfy.model_management
import folder_paths

from .ops import GGMLTensor, GGMLOps

# TODO: This causes gguf files to show up in the main unet loader
folder_paths.folder_names_and_paths["unet"][1].add(".gguf")

def gguf_sd_loader(path):
    """
    Read state dict as fake tensors
    """
    reader = gguf.GGUFReader(path)
    sd = {}
    dt = {}
    for tensor in reader.tensors:
        sd[str(tensor.name)] = GGMLTensor(tensor)
        dt[str(tensor.tensor_type)] = dt.get(str(tensor.tensor_type), 0) + 1

    # TODO: .to doesn't get passed later on for some reason, this is a hotfix
    #sd = {k:v.to(comfy.model_management.get_torch_device()) for k,v in sd.items()}

    # sanity check debug print
    print("\nggml_sd_loader:")
    for k,v in dt.items():
        print(f" {k:30}{v:3}")
    print("\n")
    return sd

class UnetLoaderGGUF:
    @classmethod
    def INPUT_TYPES(s):
        unet_names = [x for x in folder_paths.get_filename_list("unet") if x.endswith(".gguf")]
        return {
            "required": {
                "unet_name": (unet_names,),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_unet"
    CATEGORY = "bootleg"
    TITLE = "Unet Loader (GGUF)"

    def load_unet(self, unet_name):
        unet_path = folder_paths.get_full_path("unet", unet_name)
        sd = gguf_sd_loader(unet_path)
        model = comfy.sd.load_diffusion_model_state_dict(
            sd, model_options={"custom_operations": GGMLOps}
        )
        if model is None:
            logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
            raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
        return (model,)

NODE_CLASS_MAPPINGS = {
    "UnetLoaderGGUF": UnetLoaderGGUF,
}
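
For context, gguf.GGUFReader (from the gguf package pinned in requirements.txt) exposes each tensor's name, quantization type, shape, and raw block data, which is everything gguf_sd_loader consumes. A quick inspection sketch; the filename is a hypothetical placeholder:

import gguf

# Inspection sketch; "model-Q8_0.gguf" is a hypothetical placeholder path.
reader = gguf.GGUFReader("model-Q8_0.gguf")
for tensor in reader.tensors[:5]:
    print(tensor.name, tensor.tensor_type, tuple(tensor.shape))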
78 changes: 78 additions & 0 deletions ops.py
@@ -0,0 +1,78 @@
# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
import gguf
import torch
import numpy as np

import comfy.ops
from .dequant import dequantize_tensor

class GGMLTensor(torch.Tensor):
    """
    Main tensor-like class for storing quantized weights
    """
    def __init__(self, tensor, *args, **kwargs):
        super().__init__()
        self.tensor_type = tensor.tensor_type
        # gguf stores dims in reverse order relative to torch
        self.tensor_shape = torch.Size(
            np.flip(list(tensor.shape))
        )

    def __new__(cls, tensor, *args, **kwargs):
        data = torch.tensor(tensor.data)
        return super().__new__(cls, data, *args, **kwargs)

    def to(self, *args, **kwargs):
        new = super().to(*args, **kwargs)
        new.tensor_type = self.tensor_type
        new.tensor_shape = self.tensor_shape
        return new

    @property
    def shape(self):
        return self.tensor_shape

class GGMLLayer(torch.nn.Module):
    """
    This is (or should be) responsible for de-quantizing on the fly
    """
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.weight = None
        self.bias = None

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        for k,v in state_dict.items():
            if k[len(prefix):] == "weight":
                self.weight = v
            elif k[len(prefix):] == "bias":
                self.bias = v
            else:
                # keys that are present but unrecognized are unexpected, not missing
                unexpected_keys.append(k)

    def _apply(self, fn):
        if self.weight is not None:
            self.weight = fn(self.weight)
        if self.bias is not None:
            self.bias = fn(self.bias)
        super()._apply(fn)
        return self

    def get_weights(self, dtype=torch.float16):
        weight = dequantize_tensor(self.weight, dtype)
        bias = dequantize_tensor(self.bias, dtype)
        return (weight, bias)

class GGMLOps(comfy.ops.disable_weight_init):
    """
    Dequantize weights on the fly before doing the compute
    """
    class Linear(GGMLLayer):
        def __init__(self, *args, device=None, dtype=None, **kwargs):
            super().__init__(device=device, dtype=dtype)
            self.parameters_manual_cast = torch.float32

        def forward(self, x):
            weight, bias = self.get_weights(x.dtype)
            x = torch.nn.functional.linear(x, weight, bias)
            del weight, bias
            return x
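
Putting the pieces together: a GGMLTensor keeps the raw quantized blocks while reporting the logical (torch-order) shape, and dequantize_tensor expands it into a dense tensor, which is the same work GGMLOps.Linear repeats on every forward pass. A standalone sketch, assuming ComfyUI is importable (ops.py pulls in comfy.ops) and reusing the hypothetical file from above:

import torch
import gguf
from ops import GGMLTensor
from dequant import dequantize_tensor

# Standalone sketch; "model-Q8_0.gguf" is the hypothetical file from above.
reader = gguf.GGUFReader("model-Q8_0.gguf")
raw = GGMLTensor(reader.tensors[0])           # still quantized, reports logical shape
full = dequantize_tensor(raw, torch.float16)  # dense float16 weight
print(raw.tensor_type, tuple(raw.shape), "->", full.dtype, tuple(full.shape))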
1 change: 1 addition & 0 deletions requirements.txt
@@ -0,0 +1 @@
gguf>=0.9.1
