@@ -1868,16 +1868,23 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
1868
1868
} else if (sd_ctx->sd ->version == VERSION_FLEX_2) {
1869
1869
mask_channels = 1 + init_latent->ne [2 ];
1870
1870
}
1871
- ggml_tensor* masked_img = ggml_new_tensor_4d (work_ctx, GGML_TYPE_F32, width, height, 3 , 1 );
1872
- // Restore init_img (encode_first_stage has side effects) TODO: remove the side effects?
1873
- sd_image_to_tensor (init_image.data , init_img);
1874
- sd_apply_mask (init_img, mask_img, masked_img);
1875
1871
ggml_tensor* masked_latent = NULL ;
1876
- if (!sd_ctx->sd ->use_tiny_autoencoder ) {
1877
- ggml_tensor* moments = sd_ctx->sd ->encode_first_stage (work_ctx, masked_img);
1878
- masked_latent = sd_ctx->sd ->get_first_stage_encoding (work_ctx, moments);
1872
+ if (sd_ctx->sd ->version != VERSION_FLEX_2) {
1873
+ // most inpaint models mask before vae
1874
+ ggml_tensor* masked_img = ggml_new_tensor_4d (work_ctx, GGML_TYPE_F32, width, height, 3 , 1 );
1875
+ // Restore init_img (encode_first_stage has side effects) TODO: remove the side effects?
1876
+ sd_image_to_tensor (init_image.data , init_img);
1877
+ sd_apply_mask (init_img, mask_img, masked_img);
1878
+ if (!sd_ctx->sd ->use_tiny_autoencoder ) {
1879
+ ggml_tensor* moments = sd_ctx->sd ->encode_first_stage (work_ctx, masked_img);
1880
+ masked_latent = sd_ctx->sd ->get_first_stage_encoding (work_ctx, moments);
1881
+ } else {
1882
+ masked_latent = sd_ctx->sd ->encode_first_stage (work_ctx, masked_img);
1883
+ }
1879
1884
} else {
1880
- masked_latent = sd_ctx->sd ->encode_first_stage (work_ctx, masked_img);
1885
+ // mask after vae
1886
+ masked_latent = ggml_new_tensor_4d (work_ctx, GGML_TYPE_F32, init_latent->ne [0 ], init_latent->ne [1 ], init_latent->ne [2 ], 1 );
1887
+ sd_apply_mask (init_latent, mask_img, masked_latent, 0 .);
1881
1888
}
1882
1889
concat_latent = ggml_new_tensor_4d (work_ctx, GGML_TYPE_F32, masked_latent->ne [0 ], masked_latent->ne [1 ], mask_channels + masked_latent->ne [2 ], 1 );
1883
1890
for (int ix = 0 ; ix < masked_latent->ne [0 ]; ix++) {
0 commit comments