Skip to content

Support for Flux Controls + Flex.2 #692

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 43 additions & 37 deletions examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,16 @@ struct SDParams {

std::string prompt;
std::string negative_prompt;
float min_cfg = 1.0f;
float cfg_scale = 7.0f;
float guidance = 3.5f;
float eta = 0.f;
float style_ratio = 20.f;
int clip_skip = -1; // <= 0 represents unspecified
int width = 512;
int height = 512;
int batch_count = 1;
float min_cfg = 1.0f;
float cfg_scale = 7.0f;
float img_cfg_scale = INFINITY;
float guidance = 3.5f;
float eta = 0.f;
float style_ratio = 20.f;
int clip_skip = -1; // <= 0 represents unspecified
int width = 512;
int height = 512;
int batch_count = 1;

int video_frames = 6;
int motion_bucket_id = 127;
Expand Down Expand Up @@ -176,6 +177,7 @@ void print_params(SDParams params) {
printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
printf(" min_cfg: %.2f\n", params.min_cfg);
printf(" cfg_scale: %.2f\n", params.cfg_scale);
printf(" img_cfg_scale: %.2f\n", params.img_cfg_scale);
printf(" slg_scale: %.2f\n", params.slg_scale);
printf(" guidance: %.2f\n", params.guidance);
printf(" eta: %.2f\n", params.eta);
Expand Down Expand Up @@ -234,7 +236,8 @@ void print_usage(int argc, const char* argv[]) {
printf(" -p, --prompt [PROMPT] the prompt to render\n");
printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n");
printf(" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n");
printf(" --guidance SCALE guidance scale for img2img (default: 3.5)\n");
printf(" --img-cfg-scale SCALE image guidance scale for inpaint or instruct-pix2pix models: (default: same as --cfg-scale)\n");
printf(" --guidance SCALE distilled guidance scale for models with guidance input (default: 3.5)\n");
printf(" --slg-scale SCALE skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n");
printf(" 0 means disabled, a value of 2.5 is nice for sd3.5 medium\n");
printf(" --eta SCALE eta in DDIM, only for DDIM and TCD: (default: 0)\n");
Expand Down Expand Up @@ -470,6 +473,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
break;
}
params.cfg_scale = std::stof(argv[i]);
} else if (arg == "--img-cfg-scale") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.img_cfg_scale = std::stof(argv[i]);
} else if (arg == "--guidance") {
if (++i >= argc) {
invalid_arg = true;
Expand Down Expand Up @@ -755,6 +764,10 @@ void parse_args(int argc, const char** argv, SDParams& params) {
params.output_path = "output.gguf";
}
}

if (!isfinite(params.img_cfg_scale)) {
params.img_cfg_scale = params.cfg_scale;
}
}

static std::string sd_basename(const std::string& path) {
Expand Down Expand Up @@ -849,6 +862,18 @@ int main(int argc, const char* argv[]) {

parse_args(argc, argv, params);

sd_guidance_params_t guidance_params = {params.cfg_scale,
params.img_cfg_scale,
params.min_cfg,
params.guidance,
{
params.skip_layers.data(),
params.skip_layers.size(),
params.skip_layer_start,
params.skip_layer_end,
params.slg_scale,
}};

sd_set_log_callback(sd_log_cb, (void*)&params);

if (params.verbose) {
Expand Down Expand Up @@ -1000,7 +1025,7 @@ int main(int argc, const char* argv[]) {
}

sd_image_t* control_image = NULL;
if (params.controlnet_path.size() > 0 && params.control_image_path.size() > 0) {
if (params.control_image_path.size() > 0) {
int c = 0;
control_image_buffer = stbi_load(params.control_image_path.c_str(), &params.width, &params.height, &c, 3);
if (control_image_buffer == NULL) {
Expand Down Expand Up @@ -1041,8 +1066,7 @@ int main(int argc, const char* argv[]) {
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,
params.cfg_scale,
params.guidance,
guidance_params,
params.eta,
params.width,
params.height,
Expand All @@ -1054,12 +1078,7 @@ int main(int argc, const char* argv[]) {
params.control_strength,
params.style_ratio,
params.normalize_input,
params.input_id_images_path.c_str(),
params.skip_layers.data(),
params.skip_layers.size(),
params.slg_scale,
params.skip_layer_start,
params.skip_layer_end);
params.input_id_images_path.c_str());
} else if (params.mode == IMG2IMG || params.mode == IMG2VID) {
sd_image_t input_image = {(uint32_t)params.width,
(uint32_t)params.height,
Expand All @@ -1075,8 +1094,7 @@ int main(int argc, const char* argv[]) {
params.motion_bucket_id,
params.fps,
params.augmentation_level,
params.min_cfg,
params.cfg_scale,
guidance_params,
params.sample_method,
params.sample_steps,
params.strength,
Expand Down Expand Up @@ -1109,8 +1127,7 @@ int main(int argc, const char* argv[]) {
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,
params.cfg_scale,
params.guidance,
guidance_params,
params.eta,
params.width,
params.height,
Expand All @@ -1123,12 +1140,7 @@ int main(int argc, const char* argv[]) {
params.control_strength,
params.style_ratio,
params.normalize_input,
params.input_id_images_path.c_str(),
params.skip_layers.data(),
params.skip_layers.size(),
params.slg_scale,
params.skip_layer_start,
params.skip_layer_end);
params.input_id_images_path.c_str());
}
} else { // EDIT
results = edit(sd_ctx,
Expand All @@ -1137,25 +1149,19 @@ int main(int argc, const char* argv[]) {
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,
params.cfg_scale,
params.guidance,
guidance_params,
params.eta,
params.width,
params.height,
params.sample_method,
params.sample_steps,
params.strength,
params.seed,
params.batch_count,
control_image,
params.control_strength,
params.style_ratio,
params.normalize_input,
params.skip_layers.data(),
params.skip_layers.size(),
params.slg_scale,
params.skip_layer_start,
params.skip_layer_end);
params.input_id_images_path.c_str());
}

if (results == NULL) {
Expand Down
38 changes: 34 additions & 4 deletions flux.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -984,7 +984,8 @@ namespace Flux {
struct ggml_tensor* pe,
struct ggml_tensor* mod_index_arange = NULL,
std::vector<ggml_tensor*> ref_latents = {},
std::vector<int> skip_layers = {}) {
std::vector<int> skip_layers = {},
SDVersion version = VERSION_FLUX) {
// Forward pass of DiT.
// x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
// timestep: (N,) tensor of diffusion timesteps
Expand All @@ -1007,14 +1008,38 @@ namespace Flux {
auto img = process_img(ctx, x);
uint64_t img_tokens = img->ne[1];

if (c_concat != NULL) {
if (version == VERSION_FLUX_FILL) {
GGML_ASSERT(c_concat != NULL);
ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);

masked = process_img(ctx, masked);
mask = process_img(ctx, mask);

img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0);
} else if (version == VERSION_FLEX_2) {
GGML_ASSERT(c_concat != NULL);
ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
ggml_tensor* control = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));

masked = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0);
mask = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0);
control = ggml_pad(ctx, control, pad_w, pad_h, 0, 0);

masked = patchify(ctx, masked, patch_size);
mask = patchify(ctx, mask, patch_size);
control = patchify(ctx, control, patch_size);

img = ggml_concat(ctx, img, ggml_concat(ctx, ggml_concat(ctx, masked, mask, 0), control, 0), 0);
} else if (version == VERSION_FLUX_CONTROLS) {
GGML_ASSERT(c_concat != NULL);

ggml_tensor* control = ggml_pad(ctx, c_concat, pad_w, pad_h, 0, 0);

control = patchify(ctx, control, patch_size);

img = ggml_concat(ctx, img, control, 0);
}

if (ref_latents.size() > 0) {
Expand Down Expand Up @@ -1055,13 +1080,17 @@ namespace Flux {
SDVersion version = VERSION_FLUX,
bool flash_attn = false,
bool use_mask = false)
: GGMLRunner(backend), use_mask(use_mask) {
: GGMLRunner(backend), version(version), use_mask(use_mask) {
flux_params.flash_attn = flash_attn;
flux_params.guidance_embed = false;
flux_params.depth = 0;
flux_params.depth_single_blocks = 0;
if (version == VERSION_FLUX_FILL) {
flux_params.in_channels = 384;
} else if (version == VERSION_FLUX_CONTROLS) {
flux_params.in_channels = 128;
} else if (version == VERSION_FLEX_2) {
flux_params.in_channels = 196;
}
for (auto pair : tensor_types) {
std::string tensor_name = pair.first;
Expand Down Expand Up @@ -1171,7 +1200,8 @@ namespace Flux {
pe,
mod_index_arange,
ref_latents,
skip_layers);
skip_layers,
version);

ggml_build_forward_expand(gf, out);

Expand Down
21 changes: 17 additions & 4 deletions ggml_extend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,18 +380,31 @@ __STATIC_INLINE__ void sd_mask_to_tensor(const uint8_t* image_data,

__STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
struct ggml_tensor* mask,
struct ggml_tensor* output) {
struct ggml_tensor* output,
float masked_value = 0.5f) {
int64_t width = output->ne[0];
int64_t height = output->ne[1];
int64_t channels = output->ne[2];
for (int ix = 0; ix < mask->ne[0]; ix++) {
for (int iy = 0; iy < mask->ne[1]; iy++) {
float m = ggml_tensor_get_f32(mask, ix, iy);
m = round(m); // inpaint models need binary masks
ggml_tensor_set_f32(mask, m, ix, iy);
}
}
float rescale_mx = mask->ne[0]/output->ne[0];
float rescale_my = mask->ne[1]/output->ne[1];
GGML_ASSERT(output->type == GGML_TYPE_F32);
for (int ix = 0; ix < width; ix++) {
for (int iy = 0; iy < height; iy++) {
float m = ggml_tensor_get_f32(mask, ix, iy);
int mx = (int)(ix * rescale_mx);
int my = (int)(iy * rescale_my);
float m = ggml_tensor_get_f32(mask, mx, my);
m = round(m); // inpaint models need binary masks
ggml_tensor_set_f32(mask, m, ix, iy);
ggml_tensor_set_f32(mask, m, mx, my);
for (int k = 0; k < channels; k++) {
float value = (1 - m) * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5;
float value = ggml_tensor_get_f32(image_data, ix, iy, k);
value = (1 - m) * (value - masked_value) + masked_value;
ggml_tensor_set_f32(output, value, ix, iy, k);
}
}
Expand Down
16 changes: 14 additions & 2 deletions model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1673,25 +1673,37 @@ SDVersion ModelLoader::get_sd_version() {
}
}
bool is_inpaint = input_block_weight.ne[2] == 9;
bool is_ip2p = input_block_weight.ne[2] == 8;
if (is_xl) {
if (is_inpaint) {
return VERSION_SDXL_INPAINT;
}
if (is_ip2p) {
return VERSION_SDXL_PIX2PIX;
}
return VERSION_SDXL;
}

if (is_flux) {
is_inpaint = input_block_weight.ne[0] == 384;
if (is_inpaint) {
if (input_block_weight.ne[0] == 384) {
return VERSION_FLUX_FILL;
}
if (input_block_weight.ne[0] == 128) {
return VERSION_FLUX_CONTROLS;
}
if(input_block_weight.ne[0] == 196){
return VERSION_FLEX_2;
}
return VERSION_FLUX;
}

if (token_embedding_weight.ne[0] == 768) {
if (is_inpaint) {
return VERSION_SD1_INPAINT;
}
if (is_ip2p) {
return VERSION_SD1_PIX2PIX;
}
return VERSION_SD1;
} else if (token_embedding_weight.ne[0] == 1024) {
if (is_inpaint) {
Expand Down
Loading
Loading