Skip to content

Commit d48544e

Browse files
committed
Update.
1 parent 0dbb51a commit d48544e

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,12 @@ def build(self, input_shape):
5050

5151
def generate_preprocess(self, x):
5252
token_ids = {}
53-
token_ids["clip_l"] = self.clip_l_preprocessor(x)["token_ids"]
54-
token_ids["clip_g"] = self.clip_g_preprocessor(x)["token_ids"]
53+
token_ids["clip_l"] = self.clip_l_preprocessor(
54+
{"prompts": x, "images": None}
55+
)["token_ids"]
56+
token_ids["clip_g"] = self.clip_g_preprocessor(
57+
{"prompts": x, "images": None}
58+
)["token_ids"]
5559
if self.t5_preprocessor is not None:
5660
token_ids["t5"] = self.t5_preprocessor(x)["token_ids"]
5761
return token_ids

tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,9 @@ def validate_output(preset, keras_model, keras_preprocessor, output_dir):
649649
config = PRESET_MAP[preset]
650650
dtype = config["dtype"]
651651
hf_repo_id = config["root"].replace("hf://", "", 1)
652+
if preset == "stable_diffusion_3_medium":
653+
hf_repo_id += "-diffusers"
654+
652655
if dtype == "float16":
653656
torch_dtype = torch.float16
654657
elif dtype == "bfloat16":

0 commit comments

Comments
 (0)