Skip to content

Commit dd4ee77

Browse files
committed
Fix Flex 2 inpaint
1 parent d1d7420 commit dd4ee77

File tree

2 files changed

+32
-12
lines changed

2 files changed

+32
-12
lines changed

ggml_extend.hpp

Lines changed: 17 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -380,18 +380,31 @@ __STATIC_INLINE__ void sd_mask_to_tensor(const uint8_t* image_data,
380380

381381
__STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
382382
struct ggml_tensor* mask,
383-
struct ggml_tensor* output) {
383+
struct ggml_tensor* output,
384+
float masked_value = 0.5f) {
384385
int64_t width = output->ne[0];
385386
int64_t height = output->ne[1];
386387
int64_t channels = output->ne[2];
388+
for (int ix = 0; ix < mask->ne[0]; ix++) {
389+
for (int iy = 0; iy < mask->ne[1]; iy++) {
390+
float m = ggml_tensor_get_f32(mask, ix, iy);
391+
m = round(m); // inpaint models need binary masks
392+
ggml_tensor_set_f32(mask, m, ix, iy);
393+
}
394+
}
395+
float rescale_mx = mask->ne[0]/output->ne[0];
396+
float rescale_my = mask->ne[1]/output->ne[1];
387397
GGML_ASSERT(output->type == GGML_TYPE_F32);
388398
for (int ix = 0; ix < width; ix++) {
389399
for (int iy = 0; iy < height; iy++) {
390-
float m = ggml_tensor_get_f32(mask, ix, iy);
400+
int mx = (int)(ix * rescale_mx);
401+
int my = (int)(iy * rescale_my);
402+
float m = ggml_tensor_get_f32(mask, mx, my);
391403
m = round(m); // inpaint models need binary masks
392-
ggml_tensor_set_f32(mask, m, ix, iy);
404+
ggml_tensor_set_f32(mask, m, mx, my);
393405
for (int k = 0; k < channels; k++) {
394-
float value = (1 - m) * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5;
406+
float value = ggml_tensor_get_f32(image_data, ix, iy, k);
407+
value = (1 - m) * (value - masked_value) + masked_value;
395408
ggml_tensor_set_f32(output, value, ix, iy, k);
396409
}
397410
}

stable-diffusion.cpp

Lines changed: 15 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1868,16 +1868,23 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
18681868
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
18691869
mask_channels = 1 + init_latent->ne[2];
18701870
}
1871-
ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
1872-
// Restore init_img (encode_first_stage has side effects) TODO: remove the side effects?
1873-
sd_image_to_tensor(init_image.data, init_img);
1874-
sd_apply_mask(init_img, mask_img, masked_img);
18751871
ggml_tensor* masked_latent = NULL;
1876-
if (!sd_ctx->sd->use_tiny_autoencoder) {
1877-
ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
1878-
masked_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
1872+
if (sd_ctx->sd->version != VERSION_FLEX_2) {
1873+
// most inpaint models mask before vae
1874+
ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
1875+
// Restore init_img (encode_first_stage has side effects) TODO: remove the side effects?
1876+
sd_image_to_tensor(init_image.data, init_img);
1877+
sd_apply_mask(init_img, mask_img, masked_img);
1878+
if (!sd_ctx->sd->use_tiny_autoencoder) {
1879+
ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
1880+
masked_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
1881+
} else {
1882+
masked_latent = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
1883+
}
18791884
} else {
1880-
masked_latent = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
1885+
// mask after vae
1886+
masked_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
1887+
sd_apply_mask(init_latent, mask_img, masked_latent, 0.);
18811888
}
18821889
concat_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_latent->ne[0], masked_latent->ne[1], mask_channels + masked_latent->ne[2], 1);
18831890
for (int ix = 0; ix < masked_latent->ne[0]; ix++) {

0 commit comments

Comments
 (0)