Skip to content

Commit d1d7420

Browse files
committed
support for flux controls
1 parent fb604b7 commit d1d7420

File tree

4 files changed

+55
-32
lines changed

4 files changed

+55
-32
lines changed

flux.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,6 +1032,14 @@ namespace Flux {
10321032
control = patchify(ctx, control, patch_size);
10331033

10341034
img = ggml_concat(ctx, img, ggml_concat(ctx, ggml_concat(ctx, masked, mask, 0), control, 0), 0);
1035+
} else if (version == VERSION_FLUX_CONTROLS) {
1036+
GGML_ASSERT(c_concat != NULL);
1037+
1038+
ggml_tensor* control = ggml_pad(ctx, c_concat, pad_w, pad_h, 0, 0);
1039+
1040+
control = patchify(ctx, control, patch_size);
1041+
1042+
img = ggml_concat(ctx, img, control, 0);
10351043
}
10361044

10371045
if (ref_latents.size() > 0) {
@@ -1079,6 +1087,8 @@ namespace Flux {
10791087
flux_params.depth_single_blocks = 0;
10801088
if (version == VERSION_FLUX_FILL) {
10811089
flux_params.in_channels = 384;
1090+
} else if (version == VERSION_FLUX_CONTROLS) {
1091+
flux_params.in_channels = 128;
10821092
} else if (version == VERSION_FLEX_2) {
10831093
flux_params.in_channels = 196;
10841094
}

model.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1685,10 +1685,12 @@ SDVersion ModelLoader::get_sd_version() {
16851685
}
16861686

16871687
if (is_flux) {
1688-
is_inpaint = input_block_weight.ne[0] == 384;
1689-
if (is_inpaint) {
1688+
if (input_block_weight.ne[0] == 384) {
16901689
return VERSION_FLUX_FILL;
16911690
}
1691+
if (input_block_weight.ne[0] == 128) {
1692+
return VERSION_FLUX_CONTROLS;
1693+
}
16921694
if(input_block_weight.ne[0] == 196){
16931695
return VERSION_FLEX_2;
16941696
}

model.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,13 @@ enum SDVersion {
3131
VERSION_SD3,
3232
VERSION_FLUX,
3333
VERSION_FLUX_FILL,
34+
VERSION_FLUX_CONTROLS,
3435
VERSION_FLEX_2,
3536
VERSION_COUNT,
3637
};
3738

3839
static inline bool sd_version_is_flux(SDVersion version) {
39-
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2 ) {
40+
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2 ) {
4041
return true;
4142
}
4243
return false;
@@ -70,15 +71,16 @@ static inline bool sd_version_is_sdxl(SDVersion version) {
7071
return false;
7172
}
7273

73-
static inline bool sd_version_is_inpaint(SDVersion version) {
74-
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2) {
74+
75+
// Returns true when `version` satisfies sd_version_is_flux() or
// sd_version_is_sd3(); returns false for every other version.
static inline bool sd_version_is_dit(SDVersion version) {
76+
if (sd_version_is_flux(version) || sd_version_is_sd3(version)) {
7577
return true;
7678
}
7779
return false;
7880
}
7981

80-
static inline bool sd_version_is_dit(SDVersion version) {
81-
if (sd_version_is_flux(version) || sd_version_is_sd3(version)) {
82+
static inline bool sd_version_is_inpaint(SDVersion version) {
83+
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2) {
8284
return true;
8385
}
8486
return false;
@@ -88,8 +90,12 @@ static inline bool sd_version_is_edit(SDVersion version) {
8890
return version == VERSION_SD1_PIX2PIX || version == VERSION_SDXL_PIX2PIX;
8991
}
9092

93+
// Returns true when `version` is VERSION_FLUX_CONTROLS or VERSION_FLEX_2,
// the two versions this predicate classifies as control models; false otherwise.
static inline bool sd_version_is_control(SDVersion version) {
94+
return version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2;
95+
}
96+
9197
static bool sd_version_use_concat(SDVersion version) {
92-
return sd_version_is_edit(version) || sd_version_is_inpaint(version);
98+
return sd_version_is_edit(version) || sd_version_is_inpaint(version)|| sd_version_is_control(version);
9399
}
94100

95101
enum PMVersion {

stable-diffusion.cpp

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ class StableDiffusionGGML {
314314
// TODO: shift_factor
315315
}
316316

317-
if(version == VERSION_FLEX_2){
317+
if (sd_version_is_control(version)) {
318318
// Might need vae encode for control cond
319319
vae_decode_only = false;
320320
}
@@ -840,7 +840,7 @@ class StableDiffusionGGML {
840840
int start_merge_step,
841841
SDCondition id_cond,
842842
std::vector<ggml_tensor*> ref_latents = {},
843-
ggml_tensor* denoise_mask = nullptr) {
843+
ggml_tensor* denoise_mask = nullptr) {
844844
std::vector<int> skip_layers(guidance.slg.layers, guidance.slg.layers + guidance.slg.layer_count);
845845

846846
// TODO (Pix2Pix): separate image guidance params (right now it's reusing distilled guidance)
@@ -1512,6 +1512,17 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
15121512
int W = width / 8;
15131513
int H = height / 8;
15141514
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
1515+
1516+
struct ggml_tensor* control_latent = NULL;
1517+
if (sd_version_is_control(sd_ctx->sd->version) && image_hint != NULL) {
1518+
if (!sd_ctx->sd->use_tiny_autoencoder) {
1519+
struct ggml_tensor* control_moments = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
1520+
control_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, control_moments);
1521+
} else {
1522+
control_latent = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
1523+
}
1524+
}
1525+
15151526
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
15161527
int64_t mask_channels = 1;
15171528
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
@@ -1544,50 +1555,44 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
15441555
}
15451556
}
15461557
}
1547-
if (sd_ctx->sd->version == VERSION_FLEX_2 && image_hint != NULL && sd_ctx->sd->control_net == NULL) {
1558+
1559+
if (sd_ctx->sd->version == VERSION_FLEX_2 && control_latent != NULL && sd_ctx->sd->control_net == NULL) {
15481560
bool no_inpaint = concat_latent == NULL;
15491561
if (no_inpaint) {
15501562
concat_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
15511563
}
15521564
// fill in the control image here
1553-
struct ggml_tensor* control_latents = NULL;
1554-
if (!sd_ctx->sd->use_tiny_autoencoder) {
1555-
struct ggml_tensor* control_moments = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
1556-
control_latents = sd_ctx->sd->get_first_stage_encoding(work_ctx, control_moments);
1557-
} else {
1558-
control_latents = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
1559-
}
1560-
for (int64_t x = 0; x < concat_latent->ne[0]; x++) {
1561-
for (int64_t y = 0; y < concat_latent->ne[1]; y++) {
1565+
for (int64_t x = 0; x < control_latent->ne[0]; x++) {
1566+
for (int64_t y = 0; y < control_latent->ne[1]; y++) {
15621567
if (no_inpaint) {
1563-
for (int64_t c = 0; c < concat_latent->ne[2] - control_latents->ne[2]; c++) {
1568+
for (int64_t c = 0; c < concat_latent->ne[2] - control_latent->ne[2]; c++) {
15641569
// 0x16,1x1,0x16
15651570
ggml_tensor_set_f32(concat_latent, c == init_latent->ne[2], x, y, c);
15661571
}
15671572
}
1568-
for (int64_t c = 0; c < control_latents->ne[2]; c++) {
1569-
float v = ggml_tensor_get_f32(control_latents, x, y, c);
1570-
ggml_tensor_set_f32(concat_latent, v, x, y, concat_latent->ne[2] - control_latents->ne[2] + c);
1573+
for (int64_t c = 0; c < control_latent->ne[2]; c++) {
1574+
float v = ggml_tensor_get_f32(control_latent, x, y, c);
1575+
ggml_tensor_set_f32(concat_latent, v, x, y, concat_latent->ne[2] - control_latent->ne[2] + c);
15711576
}
15721577
}
15731578
}
1574-
// Disable controlnet
1575-
image_hint = NULL;
15761579
} else if (concat_latent == NULL) {
15771580
concat_latent = empty_latent;
15781581
}
15791582
cond.c_concat = concat_latent;
15801583
uncond.c_concat = empty_latent;
15811584
denoise_mask = NULL;
1582-
} else if (sd_version_is_edit(sd_ctx->sd->version)) {
1585+
} else if (sd_version_is_edit(sd_ctx->sd->version) || sd_version_is_control(sd_ctx->sd->version)) {
15831586
auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], init_latent->ne[3]);
15841587
ggml_set_f32(empty_latent, 0);
15851588
uncond.c_concat = empty_latent;
1589+
if (sd_version_is_control(sd_ctx->sd->version) && control_latent != NULL && sd_ctx->sd->control_net == NULL) {
1590+
concat_latent = control_latent;
1591+
}
15861592
if (concat_latent == NULL) {
15871593
concat_latent = empty_latent;
15881594
}
1589-
cond.c_concat = concat_latent;
1590-
1595+
cond.c_concat = concat_latent;
15911596
}
15921597
for (int b = 0; b < batch_count; b++) {
15931598
int64_t sampling_start = ggml_time_ms();
@@ -1870,7 +1875,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
18701875
ggml_tensor* masked_latent = NULL;
18711876
if (!sd_ctx->sd->use_tiny_autoencoder) {
18721877
ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
1873-
masked_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
1878+
masked_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
18741879
} else {
18751880
masked_latent = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
18761881
}
@@ -1941,8 +1946,8 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
19411946
} else {
19421947
concat_latent = init_latent;
19431948
}
1944-
}
1945-
1949+
}
1950+
19461951
{
19471952
// LOG_WARN("Inpainting with a base model is not great");
19481953
denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 1, 1);

0 commit comments

Comments
 (0)