Skip to content

Commit 63a6df3

Browse files
committed
Instruct-p2p support
1 parent 10c6501 commit 63a6df3

File tree

4 files changed

+24
-1
lines changed

4 files changed

+24
-1
lines changed

model.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1539,6 +1539,7 @@ SDVersion ModelLoader::get_sd_version() {
15391539
}
15401540
}
15411541
bool is_inpaint = input_block_weight.ne[2] == 9;
1542+
bool is_ip2p = input_block_weight.ne[2] == 8;
15421543
if (is_xl) {
15431544
if (is_inpaint) {
15441545
return VERSION_SDXL_INPAINT;
@@ -1558,6 +1559,9 @@ SDVersion ModelLoader::get_sd_version() {
15581559
if (is_inpaint) {
15591560
return VERSION_SD1_INPAINT;
15601561
}
1562+
if(is_ip2p) {
1563+
return VERSION_INSTRUCT_PIX2PIX;
1564+
}
15611565
return VERSION_SD1;
15621566
} else if (token_embedding_weight.ne[0] == 1024) {
15631567
if (is_inpaint) {

model.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
enum SDVersion {
2222
VERSION_SD1,
2323
VERSION_SD1_INPAINT,
24+
VERSION_INSTRUCT_PIX2PIX,
2425
VERSION_SD2,
2526
VERSION_SD2_INPAINT,
2627
VERSION_SDXL,
@@ -47,7 +48,7 @@ static inline bool sd_version_is_sd3(SDVersion version) {
4748
}
4849

4950
static inline bool sd_version_is_sd1(SDVersion version) {
50-
if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT) {
51+
if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_INSTRUCT_PIX2PIX) {
5152
return true;
5253
}
5354
return false;

stable-diffusion.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
const char* model_version_to_str[] = {
2828
"SD 1.x",
2929
"SD 1.x Inpaint",
30+
"Instruct-Pix2Pix",
3031
"SD 2.x",
3132
"SD 2.x Inpaint",
3233
"SDXL",
@@ -1430,9 +1431,16 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
14301431
}
14311432
cond.c_concat = masked_image;
14321433
uncond.c_concat = masked_image;
1434+
// noise_mask = masked_image;
1435+
} else if (sd_ctx->sd->version == VERSION_INSTRUCT_PIX2PIX) {
1436+
cond.c_concat = masked_image;
1437+
auto empty_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_image->ne[0], masked_image->ne[1], masked_image->ne[2], masked_image->ne[3]);
1438+
ggml_set_f32(empty_img, 0);
1439+
uncond.c_concat = empty_img;
14331440
} else {
14341441
noise_mask = masked_image;
14351442
}
1443+
14361444
for (int b = 0; b < batch_count; b++) {
14371445
int64_t sampling_start = ggml_time_ms();
14381446
int64_t cur_seed = seed + b;
@@ -1745,6 +1753,14 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
17451753
}
17461754
}
17471755
}
1756+
} else if (sd_ctx->sd->version == VERSION_INSTRUCT_PIX2PIX) {
1757+
// Not actually masked; we're just hijacking the masked_image variable since it will be used the same way
1758+
if (!sd_ctx->sd->use_tiny_autoencoder) {
1759+
ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
1760+
masked_image = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
1761+
} else {
1762+
masked_image = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
1763+
}
17481764
} else {
17491765
// LOG_WARN("Inpainting with a base model is not great");
17501766
masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 1, 1);

unet.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,8 @@ class UnetModelBlock : public GGMLBlock {
207207
}
208208
if (sd_version_is_inpaint(version)) {
209209
in_channels = 9;
210+
} else if (version == VERSION_INSTRUCT_PIX2PIX) {
211+
in_channels = 8;
210212
}
211213

212214
// dims is always 2

0 commit comments

Comments
 (0)