
Commit 00b63af

Working inference node with quantized bnb nf4 checkpoint
1 parent a1c6213 commit 00b63af

File tree

2 files changed: +65 -11 lines changed

invokeai/app/invocations/flux_text_to_image.py

+8 -6
@@ -89,7 +89,7 @@ def _run_diffusion(
         img, img_ids = self._prepare_latent_img_patches(x)

         # HACK(ryand): Find a better way to determine if this is a schnell model or not.
-        is_schnell = "shnell" in transformer_info.config.path if transformer_info.config else ""
+        is_schnell = "schnell" in transformer_info.config.path if transformer_info.config else ""
         timesteps = get_schedule(
             num_steps=self.num_steps,
             image_seq_len=img.shape[1],
@@ -139,9 +139,9 @@ def _prepare_latent_img_patches(self, latent_img: torch.Tensor) -> tuple[torch.T
         img = repeat(img, "1 ... -> bs ...", bs=bs)

         # Generate patch position ids.
-        img_ids = torch.zeros(h // 2, w // 2, 3)
-        img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
-        img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+        img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
+        img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None]
+        img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :]
         img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

         return img, img_ids
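
The device= change above keeps the position-id tensors on the same device as the latent image, avoiding an implicit CPU-to-GPU transfer when they are later combined with CUDA tensors. A minimal standalone sketch of the same patch-id construction (the h, w, bs, and device values here are hypothetical, chosen only for illustration):

    import torch
    from einops import repeat

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    h, w, bs = 64, 64, 1  # latent height/width before patching, and batch size

    # Build an (h/2, w/2, 3) grid of patch position ids directly on the target device.
    img_ids = torch.zeros(h // 2, w // 2, 3, device=device)
    img_ids[..., 1] += torch.arange(h // 2, device=device)[:, None]  # row index per patch
    img_ids[..., 2] += torch.arange(w // 2, device=device)[None, :]  # column index per patch
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)  # flatten to (bs, num_patches, 3)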
@@ -155,8 +155,10 @@ def _run_vae_decoding(
         with vae_info as vae:
             assert isinstance(vae, AutoEncoder)
             # TODO(ryand): Test that this works with both float16 and bfloat16.
-            with torch.autocast(device_type=latents.device.type, dtype=TorchDevice.choose_torch_dtype()):
-                img = vae.decode(latents)
+            # with torch.autocast(device_type=latents.device.type, dtype=torch.float32):
+            vae.to(torch.float32)
+            latents.to(torch.float32)
+            img = vae.decode(latents)

             img.clamp(-1, 1)
             img = rearrange(img[0], "c h w -> h w c")
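
One subtlety in the new decode path: torch.nn.Module.to casts a module's parameters in place, but Tensor.to returns a new tensor, so a bare latents.to(torch.float32) call leaves the decoded latents unchanged unless the result is reassigned. A short sketch of the full-precision decode with that reassignment made explicit (vae and latents stand in for the objects in the code above):

    import torch

    def decode_full_precision(vae, latents: torch.Tensor) -> torch.Tensor:
        # Module.to converts parameters in place; the assignment is only for clarity.
        vae = vae.to(torch.float32)
        # Tensor.to returns a new tensor, so the result must be reassigned.
        latents = latents.to(torch.float32)
        return vae.decode(latents)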

invokeai/backend/model_manager/load/model_loaders/flux.py

+57 -5
@@ -1,6 +1,8 @@
 # Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
 """Class for Flux model loading in InvokeAI."""

+import accelerate
+import torch
 from dataclasses import fields
 from pathlib import Path
 from typing import Any, Optional
@@ -24,13 +26,15 @@
     CheckpointConfigBase,
     CLIPEmbedDiffusersConfig,
     MainCheckpointConfig,
+    MainBnbQuantized4bCheckpointConfig,
     T5EncoderConfig,
     VAECheckpointConfig,
 )
 from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
 from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.silence_warnings import SilenceWarnings
+from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4

 app_config = get_config()

@@ -62,7 +66,7 @@ def _load_model(
         with SilenceWarnings():
             model = load_class(params).to(self._torch_dtype)
             # load_sft doesn't support torch.device
-            sd = load_file(model_path, device=str(TorchDevice.choose_torch_device()))
+            sd = load_file(model_path)
             model.load_state_dict(sd, strict=False, assign=True)

         return model
@@ -105,9 +109,9 @@ def _load_model(

         match submodel_type:
             case SubModelType.Tokenizer2:
-                return T5Tokenizer.from_pretrained(Path(config.path) / "encoder", max_length=512)
+                return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
             case SubModelType.TextEncoder2:
-                return T5EncoderModel.from_pretrained(Path(config.path) / "tokenizer")
+                return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2")  # TODO: Fix hf subfolder install

         raise Exception("Only Checkpoint Flux models are currently supported.")
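
The corrected paths above match the standard Diffusers layout for FLUX checkpoints, where the T5 tokenizer and encoder live in tokenizer_2/ and text_encoder_2/. A hedged sketch of the equivalent loads using transformers' subfolder argument, which the TODO above hints at (model_dir is a hypothetical local path):

    from transformers import T5EncoderModel, T5Tokenizer

    model_dir = "/models/FLUX.1-schnell"  # hypothetical Diffusers-format checkout

    tokenizer_2 = T5Tokenizer.from_pretrained(model_dir, subfolder="tokenizer_2", max_length=512)
    text_encoder_2 = T5EncoderModel.from_pretrained(model_dir, subfolder="text_encoder_2")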

@@ -152,7 +156,55 @@ def _load_from_singlefile(

         with SilenceWarnings():
             model = load_class(params).to(self._torch_dtype)
-            # load_sft doesn't support torch.device
-            sd = load_file(model_path, device=str(TorchDevice.choose_torch_device()))
+            sd = load_file(model_path)
+            model.load_state_dict(sd, strict=False, assign=True)
+        return model
+
+
+@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b)
+class FluxBnbQuantizednf4bCheckpointModel(GenericDiffusersLoader):
+    """Class to load main models."""
+
+    def _load_model(
+        self,
+        config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        if not isinstance(config, CheckpointConfigBase):
+            raise Exception("Only Checkpoint Flux models are currently supported.")
+        legacy_config_path = app_config.legacy_conf_path / config.config_path
+        config_path = legacy_config_path.as_posix()
+        with open(config_path, "r") as stream:
+            try:
+                flux_conf = yaml.safe_load(stream)
+            except:
+                raise
+
+        match submodel_type:
+            case SubModelType.Transformer:
+                return self._load_from_singlefile(config, flux_conf)
+
+        raise Exception("Only Checkpoint Flux models are currently supported.")
+
+    def _load_from_singlefile(
+        self,
+        config: AnyModelConfig,
+        flux_conf: Any,
+    ) -> AnyModel:
+        assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
+        load_class = Flux
+        params = None
+        model_path = Path(config.path)
+        dataclass_fields = {f.name for f in fields(FluxParams)}
+        filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
+        params = FluxParams(**filtered_data)
+
+        with SilenceWarnings():
+            with accelerate.init_empty_weights():
+                model = load_class(params)
+            model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
+            # TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle
+            # this on GPUs without bfloat16 support.
+            sd = load_file(model_path)
             model.load_state_dict(sd, strict=False, assign=True)
         return model
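
The new loader follows a common recipe for pre-quantized single-file checkpoints: build the architecture on the meta device so no full-precision weights are ever allocated, swap in the quantized modules, then load the saved state dict with assign=True so the meta tensors are replaced outright rather than copied into. A condensed sketch of that sequence (Flux, FluxParams, and quantize_model_nf4 are the names used in the diff; the checkpoint path is hypothetical):

    import accelerate
    import torch
    from safetensors.torch import load_file

    model_path = "/models/flux1-schnell-bnb_nf4.safetensors"  # hypothetical NF4 checkpoint

    with accelerate.init_empty_weights():
        model = Flux(params)  # parameters created on the meta device: no memory allocated

    # Replace eligible Linear layers with bitsandbytes NF4 equivalents.
    model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)

    # assign=True swaps the loaded tensors in directly; copying into meta tensors would fail.
    sd = load_file(model_path)
    model.load_state_dict(sd, strict=False, assign=True)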
