|
1 | 1 | # Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
|
2 | 2 | """Class for Flux model loading in InvokeAI."""
|
3 | 3 |
|
| 4 | +import accelerate |
| 5 | +import torch |
4 | 6 | from dataclasses import fields
|
5 | 7 | from pathlib import Path
|
6 | 8 | from typing import Any, Optional
|
|
24 | 26 | CheckpointConfigBase,
|
25 | 27 | CLIPEmbedDiffusersConfig,
|
26 | 28 | MainCheckpointConfig,
|
| 29 | + MainBnbQuantized4bCheckpointConfig, |
27 | 30 | T5EncoderConfig,
|
28 | 31 | VAECheckpointConfig,
|
29 | 32 | )
|
30 | 33 | from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
31 | 34 | from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
32 | 35 | from invokeai.backend.util.devices import TorchDevice
|
33 | 36 | from invokeai.backend.util.silence_warnings import SilenceWarnings
|
| 37 | +from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4 |
34 | 38 |
|
35 | 39 | app_config = get_config()
|
36 | 40 |
|
@@ -62,7 +66,7 @@ def _load_model(
|
62 | 66 | with SilenceWarnings():
|
63 | 67 | model = load_class(params).to(self._torch_dtype)
|
64 | 68 | # load_sft doesn't support torch.device
|
65 |
| - sd = load_file(model_path, device=str(TorchDevice.choose_torch_device())) |
| 69 | + sd = load_file(model_path) |
66 | 70 | model.load_state_dict(sd, strict=False, assign=True)
|
67 | 71 |
|
68 | 72 | return model
|
@@ -105,9 +109,9 @@ def _load_model(
|
105 | 109 |
|
106 | 110 | match submodel_type:
|
107 | 111 | case SubModelType.Tokenizer2:
|
108 |
| - return T5Tokenizer.from_pretrained(Path(config.path) / "encoder", max_length=512) |
| 112 | + return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512) |
109 | 113 | case SubModelType.TextEncoder2:
|
110 |
| - return T5EncoderModel.from_pretrained(Path(config.path) / "tokenizer") |
| 114 | + return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2") #TODO: Fix hf subfolder install |
111 | 115 |
|
112 | 116 | raise Exception("Only Checkpoint Flux models are currently supported.")
|
113 | 117 |
|
@@ -152,7 +156,55 @@ def _load_from_singlefile(
|
152 | 156 |
|
153 | 157 | with SilenceWarnings():
|
154 | 158 | model = load_class(params).to(self._torch_dtype)
|
155 |
| - # load_sft doesn't support torch.device |
156 |
| - sd = load_file(model_path, device=str(TorchDevice.choose_torch_device())) |
| 159 | + sd = load_file(model_path) |
| 160 | + model.load_state_dict(sd, strict=False, assign=True) |
| 161 | + return model |
| 162 | + |
| 163 | + |
| 164 | +@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b) |
| 165 | +class FluxBnbQuantizednf4bCheckpointModel(GenericDiffusersLoader): |
| 166 | + """Class to load main models.""" |
| 167 | + |
| 168 | + def _load_model( |
| 169 | + self, |
| 170 | + config: AnyModelConfig, |
| 171 | + submodel_type: Optional[SubModelType] = None, |
| 172 | + ) -> AnyModel: |
| 173 | + if not isinstance(config, CheckpointConfigBase): |
| 174 | + raise Exception("Only Checkpoint Flux models are currently supported.") |
| 175 | + legacy_config_path = app_config.legacy_conf_path / config.config_path |
| 176 | + config_path = legacy_config_path.as_posix() |
| 177 | + with open(config_path, "r") as stream: |
| 178 | + try: |
| 179 | + flux_conf = yaml.safe_load(stream) |
| 180 | + except: |
| 181 | + raise |
| 182 | + |
| 183 | + match submodel_type: |
| 184 | + case SubModelType.Transformer: |
| 185 | + return self._load_from_singlefile(config, flux_conf) |
| 186 | + |
| 187 | + raise Exception("Only Checkpoint Flux models are currently supported.") |
| 188 | + |
| 189 | + def _load_from_singlefile( |
| 190 | + self, |
| 191 | + config: AnyModelConfig, |
| 192 | + flux_conf: Any, |
| 193 | + ) -> AnyModel: |
| 194 | + assert isinstance(config, MainBnbQuantized4bCheckpointConfig) |
| 195 | + load_class = Flux |
| 196 | + params = None |
| 197 | + model_path = Path(config.path) |
| 198 | + dataclass_fields = {f.name for f in fields(FluxParams)} |
| 199 | + filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields} |
| 200 | + params = FluxParams(**filtered_data) |
| 201 | + |
| 202 | + with SilenceWarnings(): |
| 203 | + with accelerate.init_empty_weights(): |
| 204 | + model = load_class(params) |
| 205 | + model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16) |
| 206 | + # TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle |
| 207 | + # this on GPUs without bfloat16 support. |
| 208 | + sd = load_file(model_path) |
157 | 209 | model.load_state_dict(sd, strict=False, assign=True)
|
158 | 210 | return model
|
0 commit comments