
Commit 694925f

improve support for V2 variant legacy checkpoints
This commit enhances support for importing V2 variant (epsilon and v-predict) legacy checkpoints and converting them to diffusers, by prompting the user to select the proper config file both during startup-time autoimport and in the invokeai installer script.
1 parent 61d5cb2 commit 694925f
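
The mechanism, visible in the diffs below, is a new config_file_callback keyword on ModelManager.heuristic_import(): whenever the manager cannot work out the right legacy .yaml config for a checkpoint on its own, it hands the checkpoint path to the callback and uses whatever config path comes back. A minimal sketch of that contract, not taken from the commit itself; the manager object and the config location are illustrative only:

    from pathlib import Path
    from typing import Optional

    def choose_config(checkpoint_path: Path) -> Optional[Path]:
        # Stand-in for the interactive pickers added by this commit: map the
        # checkpoint to a legacy config file, or return None to skip the import.
        print(f"selecting a config for {checkpoint_path}")
        return Path("configs/stable-diffusion/v2-inference-v.yaml")  # hypothetical fixed choice

    # manager.heuristic_import(                        # 'manager' is a ModelManager instance (illustrative)
    #     "models/some-v2-checkpoint.safetensors",     # hypothetical checkpoint path
    #     convert=False,
    #     config_file_callback=choose_config,
    # )

The diff annotates the callback as Callable[[Path], Path]; in practice a None (falsy) return is what makes heuristic_import() abort the import.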

3 files changed (+65, -22 lines):

ldm/invoke/CLI.py
ldm/invoke/config/model_install_backend.py
ldm/invoke/model_manager.py

ldm/invoke/CLI.py (+15, -19)
@@ -158,10 +158,16 @@ def main():
     except Exception as e:
         report_model_error(opt, e)
 
+    # completer is the readline object
+    completer = get_completer(opt, models=gen.model_manager.list_models())
+
     # try to autoconvert new models
     if path := opt.autoimport:
         gen.model_manager.heuristic_import(
-            str(path), convert=False, commit_to_conf=opt.conf
+            str(path),
+            convert=False,
+            commit_to_conf=opt.conf,
+            config_file_callback=lambda x: _pick_configuration_file(completer,x),
         )
 
     if path := opt.autoconvert:
@@ -180,7 +186,7 @@ def main():
         )
 
     try:
-        main_loop(gen, opt)
+        main_loop(gen, opt, completer)
     except KeyboardInterrupt:
         print(
             f'\nGoodbye!\nYou can start InvokeAI again by running the "invoke.bat" (or "invoke.sh") script from {Globals.root}'
@@ -191,7 +197,7 @@ def main():
 
 
 # TODO: main_loop() has gotten busy. Needs to be refactored.
-def main_loop(gen, opt):
+def main_loop(gen, opt, completer):
     """prompt/read/execute loop"""
     global infile
     done = False
@@ -202,7 +208,6 @@ def main_loop(gen, opt):
     # The readline completer reads history from the .dream_history file located in the
     # output directory specified at the time of script launch. We do not currently support
     # changing the history file midstream when the output directory is changed.
-    completer = get_completer(opt, models=gen.model_manager.list_models())
     set_default_output_dir(opt, completer)
     if gen.model:
         add_embedding_terms(gen, completer)
@@ -661,17 +666,8 @@ def import_model(model_path: str, gen, opt, completer, convert=False):
         model_name=model_name,
         description=model_desc,
         convert=convert,
+        config_file_callback=lambda x: _pick_configuration_file(completer,x),
     )
-
-    if not imported_name:
-        if config_file := _pick_configuration_file(completer):
-            imported_name = gen.model_manager.heuristic_import(
-                model_path,
-                model_name=model_name,
-                description=model_desc,
-                convert=convert,
-                model_config_file=config_file,
-            )
     if not imported_name:
         print("** Aborting import.")
         return
@@ -687,14 +683,14 @@ def import_model(model_path: str, gen, opt, completer, convert=False):
     completer.update_models(gen.model_manager.list_models())
     print(f">> {imported_name} successfully installed")
 
-def _pick_configuration_file(completer)->Path:
+def _pick_configuration_file(completer, checkpoint_path: Path)->Path:
     print(
-"""
-Please select the type of this model:
+f"""
+Please select the type of the model at checkpoint {checkpoint_path}:
 [1] A Stable Diffusion v1.x ckpt/safetensors model
 [2] A Stable Diffusion v1.x inpainting ckpt/safetensors model
-[3] A Stable Diffusion v2.x base model (512 pixels)
-[4] A Stable Diffusion v2.x v-predictive model (768 pixels)
+[3] A Stable Diffusion v2.x base model (512 pixels; there should be no 'parameterization:' line in its yaml file)
+[4] A Stable Diffusion v2.x v-predictive model (768 pixels; look for a 'parameterization: "v"' line in its yaml file)
 [5] Other (you will be prompted to enter the config file path)
 [Q] I have no idea! Skip the import.
 """)

ldm/invoke/config/model_install_backend.py (+40)
@@ -109,6 +109,7 @@ def install_requested_models(
                 model_manager.heuristic_import(
                     path_url_or_repo,
                     convert=convert_to_diffusers,
+                    config_file_callback=_pick_configuration_file,
                     commit_to_conf=config_file_path
                 )
             except KeyboardInterrupt:
@@ -138,6 +139,45 @@ def yes_or_no(prompt: str, default_yes=True):
     else:
         return response[0] in ("y", "Y")
 
+# -------------------------------------
+def _pick_configuration_file(checkpoint_path: Path)->Path:
+    print(
+"""
+Please select the type of this model:
+[1] A Stable Diffusion v1.x ckpt/safetensors model
+[2] A Stable Diffusion v1.x inpainting ckpt/safetensors model
+[3] A Stable Diffusion v2.x base model (512 pixels; no 'parameterization:' in its yaml file)
+[4] A Stable Diffusion v2.x v-predictive model (768 pixels; look for 'parameterization: "v"' in its yaml file)
+[5] Other (you will be prompted to enter the config file path)
+[Q] I have no idea! Skip the import.
+""")
+    choices = [
+        global_config_dir() / 'stable-diffusion' / x
+        for x in [
+            'v1-inference.yaml',
+            'v1-inpainting-inference.yaml',
+            'v2-inference.yaml',
+            'v2-inference-v.yaml',
+        ]
+    ]
+
+    ok = False
+    while not ok:
+        try:
+            choice = input('select 0-5, Q > ').strip()
+            if choice.startswith(('q','Q')):
+                return
+            if choice == '5':
+                choice = Path(input('Select config file for this model> ').strip()).absolute()
+                ok = choice.exists()
+            else:
+                choice = choices[int(choice)-1]
+                ok = True
+        except (ValueError, IndexError):
+            print(f'{choice} is not a valid choice')
+        except EOFError:
+            return
+    return choice
 
 # -------------------------------------
 def get_root(root: str = None) -> str:
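
For illustration, assuming global_config_dir() points at the runtime configs directory: entering 4 at the 'select 0-5, Q >' prompt above indexes choices[3] and returns global_config_dir() / 'stable-diffusion' / 'v2-inference-v.yaml'; entering 5 asks for an arbitrary config path and loops back to the selection prompt if that file does not exist; entering q (or hitting EOF) returns None, which makes heuristic_import() skip that checkpoint.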

ldm/invoke/model_manager.py (+10, -3)
@@ -19,7 +19,7 @@
 from enum import Enum
 from pathlib import Path
 from shutil import move, rmtree
-from typing import Any, Optional, Union
+from typing import Any, Optional, Union, Callable
 
 import safetensors
 import safetensors.torch
@@ -765,6 +765,7 @@ def heuristic_import(
         description: str = None,
         model_config_file: Path = None,
         commit_to_conf: Path = None,
+        config_file_callback: Callable[[Path],Path] = None,
     ) -> str:
         """
         Accept a string which could be:
@@ -838,7 +839,10 @@ def heuristic_import(
                 Path(thing).rglob("*.safetensors")
             ):
                 if model_name := self.heuristic_import(
-                    str(m), convert, commit_to_conf=commit_to_conf
+                    str(m),
+                    convert,
+                    commit_to_conf=commit_to_conf,
+                    config_file_callback=config_file_callback,
                 ):
                     print(f" >> {model_name} successfully imported")
             return model_name
@@ -901,11 +905,14 @@ def heuristic_import(
             print(
                 f"** {thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
             )
-            return
         else:
             print(
                 f"** {thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
             )
+
+        if not model_config_file and config_file_callback:
+            model_config_file = config_file_callback(model_path)
+        if not model_config_file:
             return
 
         if model_config_file.name.startswith('v2'):
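
Taken together, the model_manager.py changes mean heuristic_import() no longer gives up immediately when a legacy V2 checkpoint's parameterization cannot be determined: it first honors an explicitly passed model_config_file, then falls back to the callback, and only aborts if neither yields a config. Because config_file_callback is also threaded through the recursive directory-scan branch, a non-interactive caller could script the decision for a whole folder; a minimal sketch, under the assumption that every ambiguous checkpoint in the batch is v-prediction (all paths are illustrative):

    from pathlib import Path

    def assume_v_prediction(checkpoint_path: Path) -> Path:
        # Non-interactive policy: treat every ambiguous legacy V2 checkpoint as v-predict (768px).
        return Path("configs/stable-diffusion/v2-inference-v.yaml")  # hypothetical location

    # manager.heuristic_import(                         # 'manager' is a ModelManager instance (illustrative)
    #     "downloads/checkpoints/",                     # scanned recursively per the rglob() branch above
    #     convert=True,
    #     commit_to_conf=Path("configs/models.yaml"),   # hypothetical models config
    #     config_file_callback=assume_v_prediction,
    # )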