convert.py/patch refactor #216

Draft · wants to merge 3 commits into main
7 changes: 4 additions & 3 deletions tools/README.md
@@ -15,12 +15,10 @@ python convert.py --src E:\models\unet\flux1-dev.safetensors
To quantize the model, first apply the provided patch to the llama.cpp repo you've just cloned. If you get a "corrupt patch" error, you may have to [change the line endings in the patch file](https://github.com/city96/ComfyUI-GGUF/issues/90#issuecomment-2323011648).
```
cd llama.cpp
git checkout tags/b3600
git checkout tags/b3962
git apply ..\lcpp.patch
```

If you wish to quantize **SD3** or **AuraFlow** models, you should use the patch named `lcpp_sd3.patch`, which has slightly modified logic for quantizing. For this you'll want to target `tags/b3962` instead.


Then, compile the llama-quantize binary. This example uses cmake, on linux you can just use make.
```
@@ -41,6 +39,9 @@ llama.cpp\build\bin\Debug\llama-quantize.exe E:\models\unet\flux1-dev-BF16.gguf

You can extract the patch again with `git diff src\llama.cpp > lcpp.patch` if you wish to change something and contribute back.

> [!WARNING]
> For hunyuan video, you will have to uncomment the block in `convert.py` that deals with 5D tensors. This will first save a **non-functional** model to disk, which you can then quantize. After quantization, run `fix_5d_tensors.py` to add back the missing key that was stashed by the conversion code. You will have to edit that file to set the correct paths/architecture. This may change in the future.


> [!WARNING]
> Do not use the diffusers UNET for flux; it won't work. Use the default/reference checkpoint format. This is due to q/k/v being merged into one qkv key. You can convert it by loading it in ComfyUI and saving it with the built-in "ModelSave" node.
52 changes: 33 additions & 19 deletions tools/convert.py
@@ -10,13 +10,17 @@
QUANTIZATION_THRESHOLD = 1024
REARRANGE_THRESHOLD = 512
MAX_TENSOR_NAME_LENGTH = 127
MAX_TENSOR_DIMS = 4

class ModelTemplate:
arch = "invalid" # string describing architecture
shape_fix = False # whether to reshape tensors
keys_detect = [] # list of lists to match in state dict
keys_banned = [] # list of keys that should mark model as invalid for conversion

def handle_nd_tensor(self, key, data):
raise NotImplementedError(f"Tensor detected that exceeds dims supported by C++ code! ({key} @ {data.shape})")

class ModelFlux(ModelTemplate):
arch = "flux"
keys_detect = [
@@ -41,6 +45,24 @@ class ModelAura(ModelTemplate):
]
keys_banned = ["joint_transformer_blocks.3.ff_context.out_projection.weight",]

class ModelHyVid(ModelTemplate):
arch = "hyvid"
keys_detect = [
(
"double_blocks.0.img_attn_proj.weight",
"txt_in.individual_token_refiner.blocks.1.self_attn_qkv.weight",
)
]

# def handle_nd_tensor(self, key, data):
# # hacky but this is the only arch that uses it
# path = f"./fix_5d_tensors_{self.arch}.pt"
# if os.path.isfile(path):
# raise RuntimeError(f"5D tensor fix file already exists! {path}")
# fsd = {key: data}
# tqdm.write(f"5D key found in state dict! Manual fix required! - {key} {data.shape}")
# torch.save(fsd, path)

class ModelLTXV(ModelTemplate):
arch = "ltxv"
keys_detect = [
@@ -75,7 +97,7 @@ class ModelSD1(ModelTemplate):
]

# The architectures are checked in order and the first successful match terminates the search.
arch_list = [ModelFlux, ModelSD3, ModelAura, ModelLTXV, ModelSDXL, ModelSD1]
arch_list = [ModelFlux, ModelSD3, ModelAura, ModelLTXV, ModelHyVid, ModelSDXL, ModelSD1]

def is_model_arch(model, state_dict):
# check if model is correct
@@ -93,7 +115,7 @@ def detect_arch(state_dict):
model_arch = None
for arch in arch_list:
if is_model_arch(arch, state_dict):
model_arch = arch
model_arch = arch()
break
assert model_arch is not None, "Unknown model architecture!"
return model_arch
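
A minimal usage sketch of the detection path above, assuming these definitions are importable from `convert.py` (or pasted into one session) and that `is_model_arch` matches when every key in one of an architecture's `keys_detect` tuples is present. Note that `detect_arch` now returns an *instance*, so `handle_nd_tensor` can later be called as a bound method:

```python
import torch

# Hypothetical, trimmed-down state dict containing just the ModelHyVid detection keys.
state_dict = {
    "double_blocks.0.img_attn_proj.weight": torch.zeros(1),
    "txt_in.individual_token_refiner.blocks.1.self_attn_qkv.weight": torch.zeros(1),
}

model_arch = detect_arch(state_dict)
print(model_arch.arch)  # "hyvid" -- a ModelHyVid() instance, not the class itself
```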
@@ -113,6 +135,8 @@ def load_state_dict(path):
if any(path.endswith(x) for x in [".ckpt", ".pt", ".bin", ".pth"]):
state_dict = torch.load(path, map_location="cpu", weights_only=True)
state_dict = state_dict.get("model", state_dict)
if len(state_dict) < 20:
raise RuntimeError(f"pt subkey load failed: {state_dict.keys()}")
else:
state_dict = load_file(path)
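
An illustration (hypothetical checkpoint layouts, not part of the diff) of what the new `len(state_dict) < 20` guard is meant to catch: a `.pt` file whose weights sit somewhere the `get("model", ...)` unwrap does not reach:

```python
# Training-style checkpoint with weights under "model": the unwrap recovers them.
ckpt_ok = {"model": {f"double_blocks.{i}.img_attn_proj.weight": None for i in range(40)}}
sd = ckpt_ok.get("model", ckpt_ok)          # full state dict, len(sd) >= 20

# Checkpoint wrapped under a different key: the unwrap returns the wrapper itself.
ckpt_bad = {"state_dict": {"some.key": None}, "epoch": 10}
sd_bad = ckpt_bad.get("model", ckpt_bad)    # still the 2-key wrapper
assert len(sd_bad) < 20                     # -> convert.py raises "pt subkey load failed"
```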

@@ -140,7 +164,7 @@ def load_model(path):
writer = gguf.GGUFWriter(path=None, arch=model_arch.arch)
return (writer, state_dict, model_arch)

def handle_tensors(args, writer, state_dict, model_arch):
def handle_tensors(writer, state_dict, model_arch):
name_lengths = tuple(sorted(
((key, len(key)) for key in state_dict.keys()),
key=lambda item: item[1],
@@ -170,23 +194,16 @@ def handle_tensors(args, writer, state_dict, model_arch):
"BF16" if old_dtype == torch.bfloat16 else "F16"
)

# The max no. of dimensions that can be handled by the quantization code is 4
if len(data.shape) > MAX_TENSOR_DIMS:
model_arch.handle_nd_tensor(key, data)
continue # needs to be added back later

# get number of parameters (AKA elements) in this tensor
n_params = 1
for dim_size in data_shape:
n_params *= dim_size

# keys to keep as max precision
blacklist = {
"time_embedding.",
"add_embedding.",
"time_in.",
"txt_in.",
"vector_in.",
"img_in.",
"guidance_in.",
"final_layer.",
}

if old_dtype in (torch.float32, torch.bfloat16):
if n_dims == 1:
# one-dimensional tensors should be kept in F32
@@ -197,9 +214,6 @@ def handle_tensors(args, writer, state_dict, model_arch):
# very small tensors
data_qtype = gguf.GGMLQuantizationType.F32

elif ".weight" in key and any(x in key for x in blacklist):
data_qtype = gguf.GGMLQuantizationType.F32

if (model_arch.shape_fix # NEVER reshape for models such as flux
and n_dims > 1 # Skip one-dimensional tensors
and n_params >= REARRANGE_THRESHOLD # Only rearrange tensors meeting the size requirement
@@ -241,7 +255,7 @@ def handle_tensors(args, writer, state_dict, model_arch):
if os.path.isfile(out_path):
input("Output exists enter to continue or ctrl+c to abort!")

handle_tensors(path, writer, state_dict, model_arch)
handle_tensors(writer, state_dict, model_arch)
writer.write_header_to_file(path=out_path)
writer.write_kv_data_to_file()
writer.write_tensors_to_file(progress=True)
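For reference, a sketch of what the commented-out `ModelHyVid.handle_nd_tensor` override above looks like once uncommented, as the README warning instructs. This is shown purely as an illustration of the hand-off to `fix_5d_tensors.py`, not as part of the diff; it assumes the `ModelTemplate` and imports from `convert.py`, and `keys_detect` is omitted for brevity:

```python
import os
import torch
from tqdm import tqdm

class ModelHyVid(ModelTemplate):
    arch = "hyvid"

    def handle_nd_tensor(self, key, data):
        # Stash any tensor with more than MAX_TENSOR_DIMS dimensions in a
        # side-car .pt file; fix_5d_tensors.py re-inserts it after quantization.
        path = f"./fix_5d_tensors_{self.arch}.pt"
        if os.path.isfile(path):
            raise RuntimeError(f"5D tensor fix file already exists! {path}")
        tqdm.write(f"5D key found in state dict! Manual fix required! - {key} {data.shape}")
        torch.save({key: data}, path)
```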
46 changes: 46 additions & 0 deletions tools/fix_5d_tensors.py
@@ -0,0 +1,46 @@
import gguf
import torch
from tqdm import tqdm

arch = "hyvid" # TODO: Really should autodetect this
prec = "Q4_K_S" # edit manually for each step
src = fr"raw/hunyuan-video-t2v-720p-{prec}.gguf"
dst = fr"hunyuan-video-t2v-720p-{prec}.gguf"

sd5d = torch.load(f"./fix_5d_tensors_{arch}.pt")
print("5D:", sd5d.keys())

reader = gguf.GGUFReader(src)
writer = gguf.GGUFWriter(path=None, arch=arch)

writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
writer.add_file_type(getattr(gguf.LlamaFileType, f"MOSTLY_{prec}")) # TODO: also autodetect

added = []
def add_extra_key(writer, key, data):
old_dtype = data.dtype
data_qtype = gguf.GGMLQuantizationType.F32
n_dims = len(data.shape)
data_shape = data.shape
data = gguf.quants.quantize(data, data_qtype)
tqdm.write(f"Adding key {key} ({data_shape})")
writer.add_tensor(key, data, raw_dtype=data_qtype)
global added
added.append(key)

# main loop to add missing
for tensor in tqdm(reader.tensors):
writer.add_tensor(tensor.name, tensor.data, raw_dtype=tensor.tensor_type)
key5d = tensor.name.replace(".bias", ".weight")
if key5d in sd5d.keys():
add_extra_key(writer, key5d, sd5d[key5d])

# brute force for any missed
for key, data in sd5d.items():
if key not in added:
add_extra_key(writer, key, data)

writer.write_header_to_file(path=dst)
writer.write_kv_data_to_file()
writer.write_tensors_to_file(progress=True)
writer.close()
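
A small optional check, sketched under the same assumptions as the variables at the top of this script (`arch = "hyvid"`, `prec = "Q4_K_S"`), to confirm that every stashed 5D key actually landed in the fixed output file:

```python
import gguf
import torch

sd5d = torch.load("./fix_5d_tensors_hyvid.pt")   # side-car written by convert.py
reader = gguf.GGUFReader("hunyuan-video-t2v-720p-Q4_K_S.gguf")
names = {t.name for t in reader.tensors}

missing = [key for key in sd5d if key not in names]
print("missing 5D keys:", missing)               # expected: []
```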