Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ struct SDCliParams {
return -1;
}
const char* preview = argv[index];
int preview_found = -1;
int preview_found = -1;
for (int m = 0; m < PREVIEW_COUNT; m++) {
if (!strcmp(preview, previews_str[m])) {
preview_found = m;
Expand Down Expand Up @@ -515,7 +515,7 @@ struct SDContextParams {
bool chroma_use_t5_mask = false;
int chroma_t5_mask_pad = 1;

prediction_t prediction = DEFAULT_PRED;
prediction_t prediction = PREDICTION_COUNT;
lora_apply_mode_t lora_apply_mode = LORA_APPLY_AUTO;

sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
Expand Down
160 changes: 65 additions & 95 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -707,7 +707,7 @@ class StableDiffusionGGML {
return false;
}

// LOG_DEBUG("model size = %.2fMB", total_size / 1024.0 / 1024.0);
LOG_DEBUG("finished loaded file");

{
size_t clip_params_mem_size = cond_stage_model->get_params_buffer_size();
Expand Down Expand Up @@ -782,8 +782,59 @@ class StableDiffusionGGML {
ggml_backend_is_cpu(clip_backend) ? "RAM" : "VRAM");
}

if (sd_ctx_params->prediction != DEFAULT_PRED) {
switch (sd_ctx_params->prediction) {
// init denoiser
{
prediction_t pred_type = sd_ctx_params->prediction;
float flow_shift = sd_ctx_params->flow_shift;

if (pred_type == PREDICTION_COUNT) {
if (sd_version_is_sd2(version)) {
// check is_using_v_parameterization_for_sd2
if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) {
pred_type = V_PRED;
} else {
pred_type = EPS_PRED;
}
} else if (sd_version_is_sdxl(version)) {
if (tensor_storage_map.find("edm_vpred.sigma_max") != tensor_storage_map.end()) {
// CosXL models
// TODO: get sigma_min and sigma_max values from file
pred_type = EDM_V_PRED;
} else if (tensor_storage_map.find("v_pred") != tensor_storage_map.end()) {
pred_type = V_PRED;
} else {
pred_type = EPS_PRED;
}
} else if (sd_version_is_sd3(version) ||
sd_version_is_wan(version) ||
sd_version_is_qwen_image(version) ||
sd_version_is_z_image(version)) {
pred_type = FLOW_PRED;
if (flow_shift == INFINITY) {
if (sd_version_is_wan(version)) {
flow_shift = 5.f;
} else {
flow_shift = 3.f;
}
}
} else if (sd_version_is_flux(version)) {
pred_type = FLUX_FLOW_PRED;
if (flow_shift == INFINITY) {
flow_shift = 1.0f; // TODO: validate
for (const auto& [name, tensor_storage] : tensor_storage_map) {
if (starts_with(name, "model.diffusion_model.guidance_in.in_layer.weight")) {
flow_shift = 1.15f;
}
}
}
} else if (sd_version_is_flux2(version)) {
pred_type = FLUX2_FLOW_PRED;
} else {
pred_type = EPS_PRED;
}
}

switch (pred_type) {
case EPS_PRED:
LOG_INFO("running in eps-prediction mode");
break;
Expand All @@ -795,22 +846,14 @@ class StableDiffusionGGML {
LOG_INFO("running in v-prediction EDM mode");
denoiser = std::make_shared<EDMVDenoiser>();
break;
case SD3_FLOW_PRED: {
case FLOW_PRED: {
LOG_INFO("running in FLOW mode");
float shift = sd_ctx_params->flow_shift;
if (shift == INFINITY) {
shift = 3.0;
}
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
denoiser = std::make_shared<DiscreteFlowDenoiser>(flow_shift);
break;
}
case FLUX_FLOW_PRED: {
LOG_INFO("running in Flux FLOW mode");
float shift = sd_ctx_params->flow_shift;
if (shift == INFINITY) {
shift = 3.0;
}
denoiser = std::make_shared<FluxFlowDenoiser>(shift);
denoiser = std::make_shared<FluxFlowDenoiser>(flow_shift);
break;
}
case FLUX2_FLOW_PRED: {
Expand All @@ -819,93 +862,21 @@ class StableDiffusionGGML {
break;
}
default: {
LOG_ERROR("Unknown parametrization %i", sd_ctx_params->prediction);
LOG_ERROR("Unknown predition type %i", pred_type);
ggml_free(ctx);
return false;
}
}
} else {
if (sd_version_is_sd2(version)) {
// check is_using_v_parameterization_for_sd2
if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) {
is_using_v_parameterization = true;
}
} else if (sd_version_is_sdxl(version)) {
if (tensor_storage_map.find("edm_vpred.sigma_max") != tensor_storage_map.end()) {
// CosXL models
// TODO: get sigma_min and sigma_max values from file
is_using_edm_v_parameterization = true;
}
if (tensor_storage_map.find("v_pred") != tensor_storage_map.end()) {
is_using_v_parameterization = true;
}
} else if (version == VERSION_SVD) {
// TODO: V_PREDICTION_EDM
is_using_v_parameterization = true;
}

if (sd_version_is_sd3(version)) {
LOG_INFO("running in FLOW mode");
float shift = sd_ctx_params->flow_shift;
if (shift == INFINITY) {
shift = 3.0;
}
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
} else if (sd_version_is_flux(version)) {
LOG_INFO("running in Flux FLOW mode");
float shift = sd_ctx_params->flow_shift;
if (shift == INFINITY) {
shift = 1.0f; // TODO: validate
for (const auto& [name, tensor_storage] : tensor_storage_map) {
if (starts_with(name, "model.diffusion_model.guidance_in.in_layer.weight")) {
shift = 1.15f;
}
}
auto comp_vis_denoiser = std::dynamic_pointer_cast<CompVisDenoiser>(denoiser);
if (comp_vis_denoiser) {
for (int i = 0; i < TIMESTEPS; i++) {
comp_vis_denoiser->sigmas[i] = std::sqrt((1 - ((float*)alphas_cumprod_tensor->data)[i]) / ((float*)alphas_cumprod_tensor->data)[i]);
comp_vis_denoiser->log_sigmas[i] = std::log(comp_vis_denoiser->sigmas[i]);
}
denoiser = std::make_shared<FluxFlowDenoiser>(shift);
} else if (sd_version_is_flux2(version)) {
LOG_INFO("running in Flux2 FLOW mode");
denoiser = std::make_shared<Flux2FlowDenoiser>();
} else if (sd_version_is_wan(version)) {
LOG_INFO("running in FLOW mode");
float shift = sd_ctx_params->flow_shift;
if (shift == INFINITY) {
shift = 5.0;
}
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
} else if (sd_version_is_qwen_image(version)) {
LOG_INFO("running in FLOW mode");
float shift = sd_ctx_params->flow_shift;
if (shift == INFINITY) {
shift = 3.0;
}
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
} else if (sd_version_is_z_image(version)) {
LOG_INFO("running in FLOW mode");
float shift = sd_ctx_params->flow_shift;
if (shift == INFINITY) {
shift = 3.0f;
}
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
} else if (is_using_v_parameterization) {
LOG_INFO("running in v-prediction mode");
denoiser = std::make_shared<CompVisVDenoiser>();
} else if (is_using_edm_v_parameterization) {
LOG_INFO("running in v-prediction EDM mode");
denoiser = std::make_shared<EDMVDenoiser>();
} else {
LOG_INFO("running in eps-prediction mode");
}
}

auto comp_vis_denoiser = std::dynamic_pointer_cast<CompVisDenoiser>(denoiser);
if (comp_vis_denoiser) {
for (int i = 0; i < TIMESTEPS; i++) {
comp_vis_denoiser->sigmas[i] = std::sqrt((1 - ((float*)alphas_cumprod_tensor->data)[i]) / ((float*)alphas_cumprod_tensor->data)[i]);
comp_vis_denoiser->log_sigmas[i] = std::log(comp_vis_denoiser->sigmas[i]);
}
}

LOG_DEBUG("finished loaded file");
ggml_free(ctx);
use_tiny_autoencoder = use_tiny_autoencoder && !sd_ctx_params->tae_preview_only;
return true;
Expand Down Expand Up @@ -2426,7 +2397,6 @@ enum scheduler_t str_to_scheduler(const char* str) {
}

const char* prediction_to_str[] = {
"default",
"eps",
"v",
"edm_v",
Expand Down Expand Up @@ -2512,7 +2482,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
sd_ctx_params->wtype = SD_TYPE_COUNT;
sd_ctx_params->rng_type = CUDA_RNG;
sd_ctx_params->sampler_rng_type = RNG_TYPE_COUNT;
sd_ctx_params->prediction = DEFAULT_PRED;
sd_ctx_params->prediction = PREDICTION_COUNT;
sd_ctx_params->lora_apply_mode = LORA_APPLY_AUTO;
sd_ctx_params->offload_params_to_cpu = false;
sd_ctx_params->keep_clip_on_cpu = false;
Expand Down
3 changes: 1 addition & 2 deletions stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,10 @@ enum scheduler_t {
};

enum prediction_t {
DEFAULT_PRED,
EPS_PRED,
V_PRED,
EDM_V_PRED,
SD3_FLOW_PRED,
FLOW_PRED,
FLUX_FLOW_PRED,
FLUX2_FLOW_PRED,
PREDICTION_COUNT
Expand Down
Loading