Skip to content

Commit e82eb35

Browse files
committed
Fix Flex 2 inpaint
1 parent b2c4087 commit e82eb35

File tree

2 files changed

+32
-12
lines changed

2 files changed

+32
-12
lines changed

ggml_extend.hpp

Lines changed: 17 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -375,18 +375,31 @@ __STATIC_INLINE__ void sd_mask_to_tensor(const uint8_t* image_data,
375375

376376
__STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
377377
struct ggml_tensor* mask,
378-
struct ggml_tensor* output) {
378+
struct ggml_tensor* output,
379+
float masked_value = 0.5f) {
379380
int64_t width = output->ne[0];
380381
int64_t height = output->ne[1];
381382
int64_t channels = output->ne[2];
383+
for (int ix = 0; ix < mask->ne[0]; ix++) {
384+
for (int iy = 0; iy < mask->ne[1]; iy++) {
385+
float m = ggml_tensor_get_f32(mask, ix, iy);
386+
m = round(m); // inpaint models need binary masks
387+
ggml_tensor_set_f32(mask, m, ix, iy);
388+
}
389+
}
390+
float rescale_mx = mask->ne[0]/output->ne[0];
391+
float rescale_my = mask->ne[1]/output->ne[1];
382392
GGML_ASSERT(output->type == GGML_TYPE_F32);
383393
for (int ix = 0; ix < width; ix++) {
384394
for (int iy = 0; iy < height; iy++) {
385-
float m = ggml_tensor_get_f32(mask, ix, iy);
395+
int mx = (int)(ix * rescale_mx);
396+
int my = (int)(iy * rescale_my);
397+
float m = ggml_tensor_get_f32(mask, mx, my);
386398
m = round(m); // inpaint models need binary masks
387-
ggml_tensor_set_f32(mask, m, ix, iy);
399+
ggml_tensor_set_f32(mask, m, mx, my);
388400
for (int k = 0; k < channels; k++) {
389-
float value = (1 - m) * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5;
401+
float value = ggml_tensor_get_f32(image_data, ix, iy, k);
402+
value = (1 - m) * (value - masked_value) + masked_value;
390403
ggml_tensor_set_f32(output, value, ix, iy, k);
391404
}
392405
}

stable-diffusion.cpp

Lines changed: 15 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1821,16 +1821,23 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
18211821
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
18221822
mask_channels = 1 + init_latent->ne[2];
18231823
}
1824-
ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
1825-
// Restore init_img (encode_first_stage has side effects) TODO: remove the side effects?
1826-
sd_image_to_tensor(init_image.data, init_img);
1827-
sd_apply_mask(init_img, mask_img, masked_img);
18281824
ggml_tensor* masked_latent = NULL;
1829-
if (!sd_ctx->sd->use_tiny_autoencoder) {
1830-
ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
1831-
masked_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
1825+
if (sd_ctx->sd->version != VERSION_FLEX_2) {
1826+
// most inpaint models mask before vae
1827+
ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
1828+
// Restore init_img (encode_first_stage has side effects) TODO: remove the side effects?
1829+
sd_image_to_tensor(init_image.data, init_img);
1830+
sd_apply_mask(init_img, mask_img, masked_img);
1831+
if (!sd_ctx->sd->use_tiny_autoencoder) {
1832+
ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
1833+
masked_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
1834+
} else {
1835+
masked_latent = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
1836+
}
18321837
} else {
1833-
masked_latent = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
1838+
// mask after vae
1839+
masked_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
1840+
sd_apply_mask(init_latent, mask_img, masked_latent, 0.);
18341841
}
18351842
concat_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_latent->ne[0], masked_latent->ne[1], mask_channels + masked_latent->ne[2], 1);
18361843
for (int ix = 0; ix < masked_latent->ne[0]; ix++) {

0 commit comments

Comments
 (0)