Skip to content

Commit 8d80802

Browse files
authored
improve support for V2 variant legacy checkpoints (#2926)
This commit enhances support for V2 variant (epsilon and v-predict) import and conversion to diffusers, by prompting the user to select the proper config file during startup time autoimport as well as in the invokeai installer script. Previously the user was only prompted when doing an `!import` from the command line or when using the WebUI Model Manager.
2 parents 61d5cb2 + 694925f commit 8d80802

File tree

3 files changed

+65
-22
lines changed

3 files changed

+65
-22
lines changed

ldm/invoke/CLI.py

+15-19
Original file line numberDiff line numberDiff line change
@@ -158,10 +158,16 @@ def main():
158158
except Exception as e:
159159
report_model_error(opt, e)
160160

161+
# completer is the readline object
162+
completer = get_completer(opt, models=gen.model_manager.list_models())
163+
161164
# try to autoconvert new models
162165
if path := opt.autoimport:
163166
gen.model_manager.heuristic_import(
164-
str(path), convert=False, commit_to_conf=opt.conf
167+
str(path),
168+
convert=False,
169+
commit_to_conf=opt.conf,
170+
config_file_callback=lambda x: _pick_configuration_file(completer,x),
165171
)
166172

167173
if path := opt.autoconvert:
@@ -180,7 +186,7 @@ def main():
180186
)
181187

182188
try:
183-
main_loop(gen, opt)
189+
main_loop(gen, opt, completer)
184190
except KeyboardInterrupt:
185191
print(
186192
f'\nGoodbye!\nYou can start InvokeAI again by running the "invoke.bat" (or "invoke.sh") script from {Globals.root}'
@@ -191,7 +197,7 @@ def main():
191197

192198

193199
# TODO: main_loop() has gotten busy. Needs to be refactored.
194-
def main_loop(gen, opt):
200+
def main_loop(gen, opt, completer):
195201
"""prompt/read/execute loop"""
196202
global infile
197203
done = False
@@ -202,7 +208,6 @@ def main_loop(gen, opt):
202208
# The readline completer reads history from the .dream_history file located in the
203209
# output directory specified at the time of script launch. We do not currently support
204210
# changing the history file midstream when the output directory is changed.
205-
completer = get_completer(opt, models=gen.model_manager.list_models())
206211
set_default_output_dir(opt, completer)
207212
if gen.model:
208213
add_embedding_terms(gen, completer)
@@ -661,17 +666,8 @@ def import_model(model_path: str, gen, opt, completer, convert=False):
661666
model_name=model_name,
662667
description=model_desc,
663668
convert=convert,
669+
config_file_callback=lambda x: _pick_configuration_file(completer,x),
664670
)
665-
666-
if not imported_name:
667-
if config_file := _pick_configuration_file(completer):
668-
imported_name = gen.model_manager.heuristic_import(
669-
model_path,
670-
model_name=model_name,
671-
description=model_desc,
672-
convert=convert,
673-
model_config_file=config_file,
674-
)
675671
if not imported_name:
676672
print("** Aborting import.")
677673
return
@@ -687,14 +683,14 @@ def import_model(model_path: str, gen, opt, completer, convert=False):
687683
completer.update_models(gen.model_manager.list_models())
688684
print(f">> {imported_name} successfully installed")
689685

690-
def _pick_configuration_file(completer)->Path:
686+
def _pick_configuration_file(completer, checkpoint_path: Path)->Path:
691687
print(
692-
"""
693-
Please select the type of this model:
688+
f"""
689+
Please select the type of the model at checkpoint {checkpoint_path}:
694690
[1] A Stable Diffusion v1.x ckpt/safetensors model
695691
[2] A Stable Diffusion v1.x inpainting ckpt/safetensors model
696-
[3] A Stable Diffusion v2.x base model (512 pixels)
697-
[4] A Stable Diffusion v2.x v-predictive model (768 pixels)
692+
[3] A Stable Diffusion v2.x base model (512 pixels; there should be no 'parameterization:' line in its yaml file)
693+
[4] A Stable Diffusion v2.x v-predictive model (768 pixels; look for a 'parameterization: "v"' line in its yaml file)
698694
[5] Other (you will be prompted to enter the config file path)
699695
[Q] I have no idea! Skip the import.
700696
""")

ldm/invoke/config/model_install_backend.py

+40
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def install_requested_models(
109109
model_manager.heuristic_import(
110110
path_url_or_repo,
111111
convert=convert_to_diffusers,
112+
config_file_callback=_pick_configuration_file,
112113
commit_to_conf=config_file_path
113114
)
114115
except KeyboardInterrupt:
@@ -138,6 +139,45 @@ def yes_or_no(prompt: str, default_yes=True):
138139
else:
139140
return response[0] in ("y", "Y")
140141

142+
# -------------------------------------
143+
def _pick_configuration_file(checkpoint_path: Path)->Path:
144+
print(
145+
"""
146+
Please select the type of this model:
147+
[1] A Stable Diffusion v1.x ckpt/safetensors model
148+
[2] A Stable Diffusion v1.x inpainting ckpt/safetensors model
149+
[3] A Stable Diffusion v2.x base model (512 pixels; no 'parameterization:' in its yaml file)
150+
[4] A Stable Diffusion v2.x v-predictive model (768 pixels; look for 'parameterization: "v"' in its yaml file)
151+
[5] Other (you will be prompted to enter the config file path)
152+
[Q] I have no idea! Skip the import.
153+
""")
154+
choices = [
155+
global_config_dir() / 'stable-diffusion' / x
156+
for x in [
157+
'v1-inference.yaml',
158+
'v1-inpainting-inference.yaml',
159+
'v2-inference.yaml',
160+
'v2-inference-v.yaml',
161+
]
162+
]
163+
164+
ok = False
165+
while not ok:
166+
try:
167+
choice = input('select 0-5, Q > ').strip()
168+
if choice.startswith(('q','Q')):
169+
return
170+
if choice == '5':
171+
choice = Path(input('Select config file for this model> ').strip()).absolute()
172+
ok = choice.exists()
173+
else:
174+
choice = choices[int(choice)-1]
175+
ok = True
176+
except (ValueError, IndexError):
177+
print(f'{choice} is not a valid choice')
178+
except EOFError:
179+
return
180+
return choice
141181

142182
# -------------------------------------
143183
def get_root(root: str = None) -> str:

ldm/invoke/model_manager.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from enum import Enum
2020
from pathlib import Path
2121
from shutil import move, rmtree
22-
from typing import Any, Optional, Union
22+
from typing import Any, Optional, Union, Callable
2323

2424
import safetensors
2525
import safetensors.torch
@@ -765,6 +765,7 @@ def heuristic_import(
765765
description: str = None,
766766
model_config_file: Path = None,
767767
commit_to_conf: Path = None,
768+
config_file_callback: Callable[[Path],Path] = None,
768769
) -> str:
769770
"""
770771
Accept a string which could be:
@@ -838,7 +839,10 @@ def heuristic_import(
838839
Path(thing).rglob("*.safetensors")
839840
):
840841
if model_name := self.heuristic_import(
841-
str(m), convert, commit_to_conf=commit_to_conf
842+
str(m),
843+
convert,
844+
commit_to_conf=commit_to_conf,
845+
config_file_callback=config_file_callback,
842846
):
843847
print(f" >> {model_name} successfully imported")
844848
return model_name
@@ -901,11 +905,14 @@ def heuristic_import(
901905
print(
902906
f"** {thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
903907
)
904-
return
905908
else:
906909
print(
907910
f"** {thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
908911
)
912+
913+
if not model_config_file and config_file_callback:
914+
model_config_file = config_file_callback(model_path)
915+
if not model_config_file:
909916
return
910917

911918
if model_config_file.name.startswith('v2'):

0 commit comments

Comments
 (0)