Skip to content

Commit 2f7d66d

Browse files
committed
Add CosXL support (broken)
1 parent 10c6501 commit 2f7d66d

File tree

2 files changed

+77
-39
lines changed

2 files changed

+77
-39
lines changed

denoiser.hpp

Lines changed: 66 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,38 @@ struct CompVisVDenoiser : public CompVisDenoiser {
346346
}
347347
};
348348

349+
struct CompVisEDMVDenoiser : public CompVisVDenoiser {
350+
float sigmas[TIMESTEPS];
351+
float min_sigma = 0.002;
352+
float max_sigma = 120.0;
353+
float sigma_data = 1;
354+
std::shared_ptr<SigmaSchedule> schedule = std::make_shared<ExponentialSchedule>();
355+
356+
CompVisEDMVDenoiser(float min_sigma = 0.002, float max_sigma = 120.0, float sigma_data = 1) : min_sigma(min_sigma), max_sigma(max_sigma), sigma_data(sigma_data) {
357+
}
358+
359+
float t_to_sigma(float t) {
360+
return std::exp(t * 4);
361+
}
362+
363+
float sigma_to_t(float s) {
364+
return 0.25 * std::log(s);
365+
}
366+
367+
float sigma_min() {
368+
return min_sigma;
369+
}
370+
371+
float sigma_max() {
372+
return max_sigma;
373+
}
374+
375+
std::vector<float> get_sigmas(uint32_t n) {
376+
auto bound_t_to_sigma = std::bind(&Denoiser::t_to_sigma, this, std::placeholders::_1);
377+
return schedule->get_sigmas(n, sigma_min(), sigma_max(), bound_t_to_sigma);
378+
}
379+
};
380+
349381
float time_snr_shift(float alpha, float t) {
350382
if (alpha == 1.0f) {
351383
return t;
@@ -1019,7 +1051,7 @@ static void sample_k_diffusion(sample_method_t method,
10191051
// also needed to invert the behavior of CompVisDenoiser
10201052
// (k-diffusion's LMSDiscreteScheduler)
10211053
float beta_start = 0.00085f;
1022-
float beta_end = 0.0120f;
1054+
float beta_end = 0.0120f;
10231055
std::vector<double> alphas_cumprod;
10241056
std::vector<double> compvis_sigmas;
10251057

@@ -1030,8 +1062,9 @@ static void sample_k_diffusion(sample_method_t method,
10301062
(i == 0 ? 1.0f : alphas_cumprod[i - 1]) *
10311063
(1.0f -
10321064
std::pow(sqrtf(beta_start) +
1033-
(sqrtf(beta_end) - sqrtf(beta_start)) *
1034-
((float)i / (TIMESTEPS - 1)), 2));
1065+
(sqrtf(beta_end) - sqrtf(beta_start)) *
1066+
((float)i / (TIMESTEPS - 1)),
1067+
2));
10351068
compvis_sigmas[i] =
10361069
std::sqrt((1 - alphas_cumprod[i]) /
10371070
alphas_cumprod[i]);
@@ -1061,7 +1094,8 @@ static void sample_k_diffusion(sample_method_t method,
10611094
// - pred_prev_sample -> "x_t-1"
10621095
int timestep =
10631096
roundf(TIMESTEPS -
1064-
i * ((float)TIMESTEPS / steps)) - 1;
1097+
i * ((float)TIMESTEPS / steps)) -
1098+
1;
10651099
// 1. get previous step value (=t-1)
10661100
int prev_timestep = timestep - TIMESTEPS / steps;
10671101
// The sigma here is chosen to cause the
@@ -1086,10 +1120,9 @@ static void sample_k_diffusion(sample_method_t method,
10861120
float* vec_x = (float*)x->data;
10871121
for (int j = 0; j < ggml_nelements(x); j++) {
10881122
vec_x[j] *= std::sqrt(sigma * sigma + 1) /
1089-
sigma;
1123+
sigma;
10901124
}
1091-
}
1092-
else {
1125+
} else {
10931126
// For the subsequent steps after the first one,
10941127
// at this point x = latents or x = sample, and
10951128
// needs to be prescaled with x <- sample / c_in
@@ -1127,9 +1160,8 @@ static void sample_k_diffusion(sample_method_t method,
11271160
float alpha_prod_t = alphas_cumprod[timestep];
11281161
// Note final_alpha_cumprod = alphas_cumprod[0] due to
11291162
// trailing timestep spacing
1130-
float alpha_prod_t_prev = prev_timestep >= 0 ?
1131-
alphas_cumprod[prev_timestep] : alphas_cumprod[0];
1132-
float beta_prod_t = 1 - alpha_prod_t;
1163+
float alpha_prod_t_prev = prev_timestep >= 0 ? alphas_cumprod[prev_timestep] : alphas_cumprod[0];
1164+
float beta_prod_t = 1 - alpha_prod_t;
11331165
// 3. compute predicted original sample from predicted
11341166
// noise also called "predicted x_0" of formula (12)
11351167
// from https://arxiv.org/pdf/2010.02502.pdf
@@ -1145,7 +1177,7 @@ static void sample_k_diffusion(sample_method_t method,
11451177
vec_pred_original_sample[j] =
11461178
(vec_x[j] / std::sqrt(sigma * sigma + 1) -
11471179
std::sqrt(beta_prod_t) *
1148-
vec_model_output[j]) *
1180+
vec_model_output[j]) *
11491181
(1 / std::sqrt(alpha_prod_t));
11501182
}
11511183
}
@@ -1159,8 +1191,8 @@ static void sample_k_diffusion(sample_method_t method,
11591191
// sigma_t = sqrt((1 - alpha_t-1)/(1 - alpha_t)) *
11601192
// sqrt(1 - alpha_t/alpha_t-1)
11611193
float beta_prod_t_prev = 1 - alpha_prod_t_prev;
1162-
float variance = (beta_prod_t_prev / beta_prod_t) *
1163-
(1 - alpha_prod_t / alpha_prod_t_prev);
1194+
float variance = (beta_prod_t_prev / beta_prod_t) *
1195+
(1 - alpha_prod_t / alpha_prod_t_prev);
11641196
float std_dev_t = eta * std::sqrt(variance);
11651197
// 6. compute "direction pointing to x_t" of formula
11661198
// (12) from https://arxiv.org/pdf/2010.02502.pdf
@@ -1179,8 +1211,8 @@ static void sample_k_diffusion(sample_method_t method,
11791211
std::pow(std_dev_t, 2)) *
11801212
vec_model_output[j];
11811213
vec_x[j] = std::sqrt(alpha_prod_t_prev) *
1182-
vec_pred_original_sample[j] +
1183-
pred_sample_direction;
1214+
vec_pred_original_sample[j] +
1215+
pred_sample_direction;
11841216
}
11851217
}
11861218
if (eta > 0) {
@@ -1208,7 +1240,7 @@ static void sample_k_diffusion(sample_method_t method,
12081240
// by Semi-Linear Consistency Function with Trajectory
12091241
// Mapping", arXiv:2402.19159 [cs.CV]
12101242
float beta_start = 0.00085f;
1211-
float beta_end = 0.0120f;
1243+
float beta_end = 0.0120f;
12121244
std::vector<double> alphas_cumprod;
12131245
std::vector<double> compvis_sigmas;
12141246

@@ -1219,8 +1251,9 @@ static void sample_k_diffusion(sample_method_t method,
12191251
(i == 0 ? 1.0f : alphas_cumprod[i - 1]) *
12201252
(1.0f -
12211253
std::pow(sqrtf(beta_start) +
1222-
(sqrtf(beta_end) - sqrtf(beta_start)) *
1223-
((float)i / (TIMESTEPS - 1)), 2));
1254+
(sqrtf(beta_end) - sqrtf(beta_start)) *
1255+
((float)i / (TIMESTEPS - 1)),
1256+
2));
12241257
compvis_sigmas[i] =
12251258
std::sqrt((1 - alphas_cumprod[i]) /
12261259
alphas_cumprod[i]);
@@ -1235,13 +1268,10 @@ static void sample_k_diffusion(sample_method_t method,
12351268
for (int i = 0; i < steps; i++) {
12361269
// Analytic form for TCD timesteps
12371270
int timestep = TIMESTEPS - 1 -
1238-
(TIMESTEPS / original_steps) *
1239-
(int)floor(i * ((float)original_steps / steps));
1271+
(TIMESTEPS / original_steps) *
1272+
(int)floor(i * ((float)original_steps / steps));
12401273
// 1. get previous step value
1241-
int prev_timestep = i >= steps - 1 ? 0 :
1242-
TIMESTEPS - 1 - (TIMESTEPS / original_steps) *
1243-
(int)floor((i + 1) *
1244-
((float)original_steps / steps));
1274+
int prev_timestep = i >= steps - 1 ? 0 : TIMESTEPS - 1 - (TIMESTEPS / original_steps) * (int)floor((i + 1) * ((float)original_steps / steps));
12451275
// Here timestep_s is tau_n' in Algorithm 4. The _s
12461276
// notation appears to be that from C. Lu,
12471277
// "DPM-Solver: A Fast ODE Solver for Diffusion
@@ -1258,10 +1288,9 @@ static void sample_k_diffusion(sample_method_t method,
12581288
float* vec_x = (float*)x->data;
12591289
for (int j = 0; j < ggml_nelements(x); j++) {
12601290
vec_x[j] *= std::sqrt(sigma * sigma + 1) /
1261-
sigma;
1291+
sigma;
12621292
}
1263-
}
1264-
else {
1293+
} else {
12651294
float* vec_x = (float*)x->data;
12661295
for (int j = 0; j < ggml_nelements(x); j++) {
12671296
vec_x[j] *= std::sqrt(sigma * sigma + 1);
@@ -1294,15 +1323,14 @@ static void sample_k_diffusion(sample_method_t method,
12941323
// DPM-Solver. In fact, we have alpha_{t_n} =
12951324
// \sqrt{\hat{alpha_n}}, [...]"
12961325
float alpha_prod_t = alphas_cumprod[timestep];
1297-
float beta_prod_t = 1 - alpha_prod_t;
1326+
float beta_prod_t = 1 - alpha_prod_t;
12981327
// Note final_alpha_cumprod = alphas_cumprod[0] since
12991328
// TCD is always "trailing"
1300-
float alpha_prod_t_prev = prev_timestep >= 0 ?
1301-
alphas_cumprod[prev_timestep] : alphas_cumprod[0];
1329+
float alpha_prod_t_prev = prev_timestep >= 0 ? alphas_cumprod[prev_timestep] : alphas_cumprod[0];
13021330
// The subscript _s are the only portion in this
13031331
// section (2) unique to TCD
13041332
float alpha_prod_s = alphas_cumprod[timestep_s];
1305-
float beta_prod_s = 1 - alpha_prod_s;
1333+
float beta_prod_s = 1 - alpha_prod_s;
13061334
// 3. Compute the predicted noised sample x_s based on
13071335
// the model parameterization
13081336
//
@@ -1317,7 +1345,7 @@ static void sample_k_diffusion(sample_method_t method,
13171345
vec_pred_original_sample[j] =
13181346
(vec_x[j] / std::sqrt(sigma * sigma + 1) -
13191347
std::sqrt(beta_prod_t) *
1320-
vec_model_output[j]) *
1348+
vec_model_output[j]) *
13211349
(1 / std::sqrt(alpha_prod_t));
13221350
}
13231351
}
@@ -1339,9 +1367,9 @@ static void sample_k_diffusion(sample_method_t method,
13391367
// pred_epsilon = model_output
13401368
vec_x[j] =
13411369
std::sqrt(alpha_prod_s) *
1342-
vec_pred_original_sample[j] +
1370+
vec_pred_original_sample[j] +
13431371
std::sqrt(beta_prod_s) *
1344-
vec_model_output[j];
1372+
vec_model_output[j];
13451373
}
13461374
}
13471375
// 4. Sample and inject noise z ~ N(0, I) for
@@ -1357,7 +1385,7 @@ static void sample_k_diffusion(sample_method_t method,
13571385
// In this case, x is still pred_noised_sample,
13581386
// continue in-place
13591387
ggml_tensor_set_f32_randn(noise, rng);
1360-
float* vec_x = (float*)x->data;
1388+
float* vec_x = (float*)x->data;
13611389
float* vec_noise = (float*)noise->data;
13621390
for (int j = 0; j < ggml_nelements(x); j++) {
13631391
// Corresponding to (35) in Zheng et
@@ -1366,10 +1394,10 @@ static void sample_k_diffusion(sample_method_t method,
13661394
vec_x[j] =
13671395
std::sqrt(alpha_prod_t_prev /
13681396
alpha_prod_s) *
1369-
vec_x[j] +
1397+
vec_x[j] +
13701398
std::sqrt(1 - alpha_prod_t_prev /
1371-
alpha_prod_s) *
1372-
vec_noise[j];
1399+
alpha_prod_s) *
1400+
vec_noise[j];
13731401
}
13741402
}
13751403
}

stable-diffusion.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -522,12 +522,19 @@ class StableDiffusionGGML {
522522
LOG_INFO("loading model from '%s' completed, taking %.2fs", model_path.c_str(), (t1 - t0) * 1.0f / 1000);
523523

524524
// check is_using_v_parameterization_for_sd2
525-
bool is_using_v_parameterization = false;
525+
bool is_using_v_parameterization = false;
526+
bool is_using_edm_parameterization = false;
527+
526528
if (sd_version_is_sd2(version)) {
527529
if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) {
528530
is_using_v_parameterization = true;
529531
}
530532
} else if (sd_version_is_sdxl(version)) {
533+
if (model_loader.tensor_storages_types.find("edm_vpred.sigma_max") != model_loader.tensor_storages_types.end()) {
534+
// CosXL models
535+
// TODO: get sigma_min and sigma_max values from file
536+
is_using_edm_parameterization = true;
537+
}
531538
if (model_loader.tensor_storages_types.find("v_pred") != model_loader.tensor_storages_types.end()) {
532539
is_using_v_parameterization = true;
533540
}
@@ -552,6 +559,9 @@ class StableDiffusionGGML {
552559
} else if (is_using_v_parameterization) {
553560
LOG_INFO("running in v-prediction mode");
554561
denoiser = std::make_shared<CompVisVDenoiser>();
562+
} else if (is_using_edm_parameterization) {
563+
LOG_INFO("running in edm mode");
564+
denoiser = std::make_shared<CompVisEDMVDenoiser>();
555565
} else {
556566
LOG_INFO("running in eps-prediction mode");
557567
}

0 commit comments

Comments
 (0)