Skip to content

Commit 9251756

Browse files
authored
feat: add CosXL support (#683)
1 parent ecf5db9 commit 9251756

File tree

2 files changed

+53
-20
lines changed

2 files changed

+53
-20
lines changed

denoiser.hpp

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -168,24 +168,21 @@ struct AYSSchedule : SigmaSchedule {
168168
std::vector<float> inputs;
169169
std::vector<float> results(n + 1);
170170

171-
switch (version) {
172-
case VERSION_SD2: /* fallthrough */
173-
LOG_WARN("AYS not designed for SD2.X models");
174-
case VERSION_SD1:
175-
LOG_INFO("AYS using SD1.5 noise levels");
176-
inputs = noise_levels[0];
177-
break;
178-
case VERSION_SDXL:
179-
LOG_INFO("AYS using SDXL noise levels");
180-
inputs = noise_levels[1];
181-
break;
182-
case VERSION_SVD:
183-
LOG_INFO("AYS using SVD noise levels");
184-
inputs = noise_levels[2];
185-
break;
186-
default:
187-
LOG_ERROR("Version not compatable with AYS scheduler");
188-
return results;
171+
if (sd_version_is_sd2((SDVersion)version)) {
172+
LOG_WARN("AYS not designed for SD2.X models");
173+
} /* fallthrough */
174+
else if (sd_version_is_sd1((SDVersion)version)) {
175+
LOG_INFO("AYS using SD1.5 noise levels");
176+
inputs = noise_levels[0];
177+
} else if (sd_version_is_sdxl((SDVersion)version)) {
178+
LOG_INFO("AYS using SDXL noise levels");
179+
inputs = noise_levels[1];
180+
} else if (version == VERSION_SVD) {
181+
LOG_INFO("AYS using SVD noise levels");
182+
inputs = noise_levels[2];
183+
} else {
184+
LOG_ERROR("Version not compatable with AYS scheduler");
185+
return results;
189186
}
190187

191188
/* Stretches those pre-calculated reference levels out to the desired
@@ -346,6 +343,31 @@ struct CompVisVDenoiser : public CompVisDenoiser {
346343
}
347344
};
348345

346+
struct EDMVDenoiser : public CompVisVDenoiser {
347+
float min_sigma = 0.002;
348+
float max_sigma = 120.0;
349+
350+
EDMVDenoiser(float min_sigma = 0.002, float max_sigma = 120.0) : min_sigma(min_sigma), max_sigma(max_sigma) {
351+
schedule = std::make_shared<ExponentialSchedule>();
352+
}
353+
354+
float t_to_sigma(float t) {
355+
return std::exp(t * 4/(float)TIMESTEPS);
356+
}
357+
358+
float sigma_to_t(float s) {
359+
return 0.25 * std::log(s);
360+
}
361+
362+
float sigma_min() {
363+
return min_sigma;
364+
}
365+
366+
float sigma_max() {
367+
return max_sigma;
368+
}
369+
};
370+
349371
float time_snr_shift(float alpha, float t) {
350372
if (alpha == 1.0f) {
351373
return t;

stable-diffusion.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@ class StableDiffusionGGML {
103103
bool vae_tiling = false;
104104
bool stacked_id = false;
105105

106+
bool is_using_v_parameterization = false;
107+
bool is_using_edm_v_parameterization = false;
108+
106109
std::map<std::string, struct ggml_tensor*> tensors;
107110

108111
std::string lora_model_dir;
@@ -543,12 +546,17 @@ class StableDiffusionGGML {
543546
LOG_INFO("loading model from '%s' completed, taking %.2fs", model_path.c_str(), (t1 - t0) * 1.0f / 1000);
544547

545548
// check is_using_v_parameterization_for_sd2
546-
bool is_using_v_parameterization = false;
549+
547550
if (sd_version_is_sd2(version)) {
548551
if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) {
549552
is_using_v_parameterization = true;
550553
}
551554
} else if (sd_version_is_sdxl(version)) {
555+
if (model_loader.tensor_storages_types.find("edm_vpred.sigma_max") != model_loader.tensor_storages_types.end()) {
556+
// CosXL models
557+
// TODO: get sigma_min and sigma_max values from file
558+
is_using_edm_v_parameterization = true;
559+
}
552560
if (model_loader.tensor_storages_types.find("v_pred") != model_loader.tensor_storages_types.end()) {
553561
is_using_v_parameterization = true;
554562
}
@@ -573,6 +581,9 @@ class StableDiffusionGGML {
573581
} else if (is_using_v_parameterization) {
574582
LOG_INFO("running in v-prediction mode");
575583
denoiser = std::make_shared<CompVisVDenoiser>();
584+
} else if (is_using_edm_v_parameterization) {
585+
LOG_INFO("running in v-prediction EDM mode");
586+
denoiser = std::make_shared<EDMVDenoiser>();
576587
} else {
577588
LOG_INFO("running in eps-prediction mode");
578589
}
@@ -1396,7 +1407,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
13961407
SDCondition uncond;
13971408
if (cfg_scale != 1.0) {
13981409
bool force_zero_embeddings = false;
1399-
if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0) {
1410+
if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0 && !sd_ctx->sd->is_using_edm_v_parameterization) {
14001411
force_zero_embeddings = true;
14011412
}
14021413
uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,

0 commit comments

Comments
 (0)