Showing 6 changed files with 241 additions and 0 deletions.
.gitignore
@@ -1,3 +1,7 @@
*.bin
*.gguf
*.safetensors

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
__init__.py
@@ -0,0 +1,9 @@
# only import if running as a custom node
try:
    import comfy.utils
except ImportError:
    pass
else:
    from .nodes import NODE_CLASS_MAPPINGS
    NODE_DISPLAY_NAME_MAPPINGS = {k:v.TITLE for k,v in NODE_CLASS_MAPPINGS.items()}
    __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
dequant.py
@@ -0,0 +1,85 @@
# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
import gguf
import torch
import numpy as np

def dequantize_tensor(tensor, dtype=torch.float16):
    if tensor is None:
        return None

    data = torch.tensor(tensor.data)
    qtype = tensor.tensor_type
    oshape = tensor.tensor_shape

    if qtype == gguf.GGMLQuantizationType.F32:
        return data.to(dtype)
    elif qtype == gguf.GGMLQuantizationType.F16:
        return data.to(dtype)
    elif qtype in dequantize_functions:
        # this is the main pytorch op
        return dequantize(data, qtype, oshape).to(dtype)
    else:
        # this is incredibly slow
        new = gguf.quants.dequantize(data.cpu().numpy(), qtype)
        return torch.from_numpy(new).to(data.device, dtype=dtype)

def dequantize(data, qtype, oshape):
    """
    Dequantize tensor back to usable shape/dtype
    """
    block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
    dequantize_blocks = dequantize_functions[qtype]

    rows = data.reshape(
        (-1, data.shape[-1])
    ).view(torch.uint8)

    n_blocks = rows.numel() // type_size
    blocks = rows.reshape((n_blocks, type_size))
    blocks = dequantize_blocks(blocks, block_size, type_size)
    return blocks.reshape(oshape)

def to_uint32(x):
    # no uint32 :(
    x = x.view(torch.uint8).to(torch.int32)
    return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1)

def dequantize_blocks_Q8_0(blocks, block_size, type_size):
    d = blocks[:, :2].view(torch.float16)
    x = blocks[:, 2:].view(torch.int8).to(torch.float16)
    return (x * d)

def dequantize_blocks_Q5_0(blocks, block_size, type_size):
    n_blocks = blocks.shape[0]

    d = blocks[:, :2]
    qh = blocks[:, 2:6]
    qs = blocks[:, 6: ]

    d = d.view(torch.float16).to(torch.float32)
    qh = to_uint32(qh)

    qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
    ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)

    qh = (qh & 1).to(torch.uint8)
    ql = (ql & 0x0F).reshape(n_blocks, -1)

    qs = (ql | (qh << 4)).to(torch.int8) - 16
    return (d * qs)

def dequantize_blocks_Q4_0(blocks, block_size, type_size):
    n_blocks = blocks.shape[0]

    d = blocks[:, :2].view(torch.float16)
    qs = blocks[:, 2:]

    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
    return (d * qs)

dequantize_functions = {
    gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
    gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0,
    gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0,
}
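For context, dequantize_tensor can be exercised outside of ComfyUI as well. The sketch below is only an illustration under stated assumptions: the file path is a placeholder, the import assumes dequant.py is importable on its own, and the SimpleNamespace wrapper stands in for the GGMLTensor class from ops.py, which normally supplies the .data / .tensor_type / .tensor_shape attributes the function reads:

import gguf
import torch
import numpy as np
from types import SimpleNamespace
from dequant import dequantize_tensor  # assumed import path

reader = gguf.GGUFReader("model-Q8_0.gguf")  # placeholder file name
raw = reader.tensors[0]

# mimic the attributes GGMLTensor would normally provide
wrapped = SimpleNamespace(
    data=raw.data,
    tensor_type=raw.tensor_type,
    tensor_shape=torch.Size(np.flip(list(raw.shape))),  # gguf stores dims in reverse order
)
weight = dequantize_tensor(wrapped, dtype=torch.float16)
print(raw.name, weight.shape, weight.dtype)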
nodes.py
@@ -0,0 +1,64 @@
# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
import gguf
import logging

import comfy.sd
import comfy.utils
import comfy.model_management
import folder_paths

from .ops import GGMLTensor, GGMLOps

# TODO: This causes gguf files to show up in the main unet loader
folder_paths.folder_names_and_paths["unet"][1].add(".gguf")

def gguf_sd_loader(path):
    """
    Read state dict as fake tensors
    """
    reader = gguf.GGUFReader(path)
    sd = {}
    dt = {}
    for tensor in reader.tensors:
        sd[str(tensor.name)] = GGMLTensor(tensor)
        dt[str(tensor.tensor_type)] = dt.get(str(tensor.tensor_type), 0) + 1

    # TODO: .to doesn't get passed later on for some reason, this is a hotfix
    #sd = {k:v.to(comfy.model_management.get_torch_device()) for k,v in sd.items()}

    # sanity check debug print
    print("\nggml_sd_loader:")
    for k,v in dt.items():
        print(f" {k:30}{v:3}")
    print("\n")
    return sd

class UnetLoaderGGUF:
    @classmethod
    def INPUT_TYPES(s):
        unet_names = [x for x in folder_paths.get_filename_list("unet") if x.endswith(".gguf")]
        return {
            "required": {
                "unet_name": (unet_names,),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_unet"
    CATEGORY = "bootleg"
    TITLE = "Unet Loader (GGUF)"

    def load_unet(self, unet_name):
        unet_path = folder_paths.get_full_path("unet", unet_name)
        sd = gguf_sd_loader(unet_path)
        model = comfy.sd.load_diffusion_model_state_dict(
            sd, model_options={"custom_operations": GGMLOps}
        )
        if model is None:
            logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
            raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
        return (model,)

NODE_CLASS_MAPPINGS = {
    "UnetLoaderGGUF": UnetLoaderGGUF,
}
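As a rough usage note: in the ComfyUI interface this node appears under the "bootleg" category as "Unet Loader (GGUF)". A headless call would look roughly like the sketch below, with the caveat that it assumes comfy and folder_paths are importable (i.e. ComfyUI is initialized on the path) and that the named .gguf file exists under the unet models folder; both the import path and the file name are placeholders:

from nodes import UnetLoaderGGUF  # placeholder import path for this file

loader = UnetLoaderGGUF()
(model,) = loader.load_unet("flux1-dev-Q8_0.gguf")  # placeholder file name
print(type(model))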
ops.py
@@ -0,0 +1,78 @@
# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
import gguf
import torch
import numpy as np

import comfy.ops
from .dequant import dequantize_tensor

class GGMLTensor(torch.Tensor):
    """
    Main tensor-like class for storing quantized weights
    """
    def __init__(self, tensor, *args, **kwargs):
        super().__init__()
        self.tensor_type = tensor.tensor_type
        self.tensor_shape = torch.Size(
            np.flip(list(tensor.shape))
        )

    def __new__(cls, tensor, *args, **kwargs):
        data = torch.tensor(tensor.data)
        return super().__new__(cls, data, *args, **kwargs)

    def to(self, *args, **kwargs):
        new = super().to(*args, **kwargs)
        new.tensor_type = self.tensor_type
        new.tensor_shape = self.tensor_shape
        return new

    @property
    def shape(self):
        return self.tensor_shape

class GGMLLayer(torch.nn.Module):
    """
    This (should) be responsible for de-quantizing on the fly
    """
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.weight = None
        self.bias = None

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        for k,v in state_dict.items():
            if k[len(prefix):] == "weight":
                self.weight = v
            elif k[len(prefix):] == "bias":
                self.bias = v
            else:
                missing_keys.append(k)

    def _apply(self, fn):
        if self.weight is not None:
            self.weight = fn(self.weight)
        if self.bias is not None:
            self.bias = fn(self.bias)
        super()._apply(fn)
        return self

    def get_weights(self, dtype=torch.float16):
        weight = dequantize_tensor(self.weight, dtype)
        bias = dequantize_tensor(self.bias, dtype)
        return (weight, bias)

class GGMLOps(comfy.ops.disable_weight_init):
    """
    Dequantize weights on the fly before doing the compute
    """
    class Linear(GGMLLayer):
        def __init__(self, *args, device=None, dtype=None, **kwargs):
            super().__init__(device=device, dtype=dtype)
            self.parameters_manual_cast = torch.float32

        def forward(self, x):
            weight, bias = self.get_weights(x.dtype)
            x = torch.nn.functional.linear(x, weight, bias)
            del weight, bias
            return x
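The pattern in GGMLOps.Linear, keeping the compressed weight resident and only materializing a full-precision copy for the duration of a single matmul, can be shown in isolation. The toy layer below is not the repository's code; it uses a plain int8-plus-scale scheme as a stand-in for the real GGML block formats, purely to illustrate the dequantize-on-the-fly idea:

import torch

class ToyQuantLinear(torch.nn.Module):
    def __init__(self, weight_q, scale, bias=None):
        super().__init__()
        self.weight_q = weight_q  # int8 weights stay small in memory
        self.scale = scale        # single float scale (stand-in for per-block scales)
        self.bias = bias

    def forward(self, x):
        # materialize the full-precision weight only for this one matmul, then drop it
        weight = self.weight_q.to(x.dtype) * self.scale
        out = torch.nn.functional.linear(x, weight, self.bias)
        del weight
        return out

w_q = torch.randint(-127, 128, (8, 16), dtype=torch.int8)
layer = ToyQuantLinear(w_q, scale=torch.tensor(0.01))
print(layer(torch.randn(2, 16)).shape)  # torch.Size([2, 8])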
requirements.txt
@@ -0,0 +1 @@
gguf>=0.9.1