@@ -1821,16 +1821,23 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
1821
1821
} else if (sd_ctx->sd ->version == VERSION_FLEX_2) {
1822
1822
mask_channels = 1 + init_latent->ne [2 ];
1823
1823
}
1824
- ggml_tensor* masked_img = ggml_new_tensor_4d (work_ctx, GGML_TYPE_F32, width, height, 3 , 1 );
1825
- // Restore init_img (encode_first_stage has side effects) TODO: remove the side effects?
1826
- sd_image_to_tensor (init_image.data , init_img);
1827
- sd_apply_mask (init_img, mask_img, masked_img);
1828
1824
ggml_tensor* masked_latent = NULL ;
1829
- if (!sd_ctx->sd ->use_tiny_autoencoder ) {
1830
- ggml_tensor* moments = sd_ctx->sd ->encode_first_stage (work_ctx, masked_img);
1831
- masked_latent = sd_ctx->sd ->get_first_stage_encoding (work_ctx, moments);
1825
+ if (sd_ctx->sd ->version != VERSION_FLEX_2) {
1826
+ // most inpaint models mask before vae
1827
+ ggml_tensor* masked_img = ggml_new_tensor_4d (work_ctx, GGML_TYPE_F32, width, height, 3 , 1 );
1828
+ // Restore init_img (encode_first_stage has side effects) TODO: remove the side effects?
1829
+ sd_image_to_tensor (init_image.data , init_img);
1830
+ sd_apply_mask (init_img, mask_img, masked_img);
1831
+ if (!sd_ctx->sd ->use_tiny_autoencoder ) {
1832
+ ggml_tensor* moments = sd_ctx->sd ->encode_first_stage (work_ctx, masked_img);
1833
+ masked_latent = sd_ctx->sd ->get_first_stage_encoding (work_ctx, moments);
1834
+ } else {
1835
+ masked_latent = sd_ctx->sd ->encode_first_stage (work_ctx, masked_img);
1836
+ }
1832
1837
} else {
1833
- masked_latent = sd_ctx->sd ->encode_first_stage (work_ctx, masked_img);
1838
+ // mask after vae
1839
+ masked_latent = ggml_new_tensor_4d (work_ctx, GGML_TYPE_F32, init_latent->ne [0 ], init_latent->ne [1 ], init_latent->ne [2 ], 1 );
1840
+ sd_apply_mask (init_latent, mask_img, masked_latent, 0 .);
1834
1841
}
1835
1842
concat_latent = ggml_new_tensor_4d (work_ctx, GGML_TYPE_F32, masked_latent->ne [0 ], masked_latent->ne [1 ], mask_channels + masked_latent->ne [2 ], 1 );
1836
1843
for (int ix = 0 ; ix < masked_latent->ne [0 ]; ix++) {
0 commit comments