@@ -103,6 +103,9 @@ class StableDiffusionGGML {
103
103
bool vae_tiling = false ;
104
104
bool stacked_id = false ;
105
105
106
+ bool is_using_v_parameterization = false ;
107
+ bool is_using_edm_v_parameterization = false ;
108
+
106
109
std::map<std::string, struct ggml_tensor *> tensors;
107
110
108
111
std::string lora_model_dir;
@@ -543,12 +546,17 @@ class StableDiffusionGGML {
543
546
LOG_INFO (" loading model from '%s' completed, taking %.2fs" , model_path.c_str (), (t1 - t0) * 1 .0f / 1000 );
544
547
545
548
// check is_using_v_parameterization_for_sd2
546
- bool is_using_v_parameterization = false ;
549
+
547
550
if (sd_version_is_sd2 (version)) {
548
551
if (is_using_v_parameterization_for_sd2 (ctx, sd_version_is_inpaint (version))) {
549
552
is_using_v_parameterization = true ;
550
553
}
551
554
} else if (sd_version_is_sdxl (version)) {
555
+ if (model_loader.tensor_storages_types .find (" edm_vpred.sigma_max" ) != model_loader.tensor_storages_types .end ()) {
556
+ // CosXL models
557
+ // TODO: get sigma_min and sigma_max values from file
558
+ is_using_edm_v_parameterization = true ;
559
+ }
552
560
if (model_loader.tensor_storages_types .find (" v_pred" ) != model_loader.tensor_storages_types .end ()) {
553
561
is_using_v_parameterization = true ;
554
562
}
@@ -573,6 +581,9 @@ class StableDiffusionGGML {
573
581
} else if (is_using_v_parameterization) {
574
582
LOG_INFO (" running in v-prediction mode" );
575
583
denoiser = std::make_shared<CompVisVDenoiser>();
584
+ } else if (is_using_edm_v_parameterization) {
585
+ LOG_INFO (" running in v-prediction EDM mode" );
586
+ denoiser = std::make_shared<EDMVDenoiser>();
576
587
} else {
577
588
LOG_INFO (" running in eps-prediction mode" );
578
589
}
@@ -1396,7 +1407,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
1396
1407
SDCondition uncond;
1397
1408
if (cfg_scale != 1.0 ) {
1398
1409
bool force_zero_embeddings = false ;
1399
- if (sd_version_is_sdxl (sd_ctx->sd ->version ) && negative_prompt.size () == 0 ) {
1410
+ if (sd_version_is_sdxl (sd_ctx->sd ->version ) && negative_prompt.size () == 0 && !sd_ctx-> sd -> is_using_edm_v_parameterization ) {
1400
1411
force_zero_embeddings = true ;
1401
1412
}
1402
1413
uncond = sd_ctx->sd ->cond_stage_model ->get_learned_condition (work_ctx,
0 commit comments