292 changes: 213 additions & 79 deletions conditioner.hpp

Large diffs are not rendered by default.

71 changes: 53 additions & 18 deletions flux.hpp
@@ -88,14 +88,19 @@ namespace Flux {

public:
SelfAttention(int64_t dim,
int64_t num_heads = 8,
bool qkv_bias = false,
bool proj_bias = true)
int64_t num_heads = 8,
bool qkv_bias = false,
bool proj_bias = true,
bool diffusers_style = false)
: num_heads(num_heads) {
int64_t head_dim = dim / num_heads;
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
blocks["norm"] = std::shared_ptr<GGMLBlock>(new QKNorm(head_dim));
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim, proj_bias));
if (diffusers_style) {
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new SplitLinear(dim, {dim, dim, dim}, qkv_bias));
} else {
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
}
blocks["norm"] = std::shared_ptr<GGMLBlock>(new QKNorm(head_dim));
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim, proj_bias));
}

std::vector<struct ggml_tensor*> pre_attention(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
@@ -210,7 +215,8 @@ namespace Flux {
bool prune_mod = false,
bool share_modulation = false,
bool mlp_proj_bias = true,
bool use_mlp_silu_act = false)
bool use_mlp_silu_act = false,
bool diffusers_style = false)
: idx(idx), prune_mod(prune_mod), use_mlp_silu_act(use_mlp_silu_act) {
int64_t mlp_hidden_dim = hidden_size * mlp_ratio;
int64_t mlp_mult_factor = use_mlp_silu_act ? 2 : 1;
@@ -219,7 +225,7 @@
blocks["img_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
}
blocks["img_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias));
blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias, diffusers_style));

blocks["img_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["img_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias));
@@ -230,7 +236,7 @@
blocks["txt_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
}
blocks["txt_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["txt_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias));
blocks["txt_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias, diffusers_style));

blocks["txt_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["txt_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias));
@@ -383,6 +389,7 @@ namespace Flux {
int idx = 0;
bool use_mlp_silu_act;
int64_t mlp_mult_factor;
bool diffusers_style = false;

public:
SingleStreamBlock(int64_t hidden_size,
@@ -393,7 +400,8 @@
bool prune_mod = false,
bool share_modulation = false,
bool mlp_proj_bias = true,
bool use_mlp_silu_act = false)
bool use_mlp_silu_act = false,
bool diffusers_style = false)
: hidden_size(hidden_size), num_heads(num_heads), idx(idx), prune_mod(prune_mod), use_mlp_silu_act(use_mlp_silu_act) {
int64_t head_dim = hidden_size / num_heads;
float scale = qk_scale;
@@ -405,8 +413,11 @@
if (use_mlp_silu_act) {
mlp_mult_factor = 2;
}

blocks["linear1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias));
if (diffusers_style) {
blocks["linear1"] = std::shared_ptr<GGMLBlock>(new SplitLinear(hidden_size, {hidden_size, hidden_size, hidden_size, mlp_hidden_dim * mlp_mult_factor}, mlp_proj_bias));
} else {
blocks["linear1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias));
}
blocks["linear2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size + mlp_hidden_dim, hidden_size, mlp_proj_bias));
blocks["norm"] = std::shared_ptr<GGMLBlock>(new QKNorm(head_dim));
blocks["pre_norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
@@ -728,6 +739,7 @@ namespace Flux {
bool share_modulation = false;
bool use_mlp_silu_act = false;
float ref_index_scale = 1.f;
bool diffusers_style = false;
ChromaRadianceParams chroma_radiance_params;
};

@@ -770,7 +782,8 @@
params.is_chroma,
params.share_modulation,
!params.disable_bias,
params.use_mlp_silu_act);
params.use_mlp_silu_act,
params.diffusers_style);
}

for (int i = 0; i < params.depth_single_blocks; i++) {
@@ -782,7 +795,8 @@
params.is_chroma,
params.share_modulation,
!params.disable_bias,
params.use_mlp_silu_act);
params.use_mlp_silu_act,
params.diffusers_style);
}

if (params.version == VERSION_CHROMA_RADIANCE) {
Expand Down Expand Up @@ -829,6 +843,11 @@ namespace Flux {
int64_t C = x->ne[2];
int64_t H = x->ne[1];
int64_t W = x->ne[0];
if (params.patch_size == 1) {
x = ggml_reshape_3d(ctx, x, H * W, C, N); // [N, C, H*W]
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, H*W, C]
return x;
}
int64_t p = params.patch_size;
int64_t h = H / params.patch_size;
int64_t w = W / params.patch_size;
@@ -863,6 +882,12 @@
int64_t W = w * params.patch_size;
int64_t p = params.patch_size;

if (params.patch_size == 1) {
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, C, H*W]
x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, H, W]
return x;
}

GGML_ASSERT(C * p * p == x->ne[0]);

x = ggml_reshape_4d(ctx, x, p * p, C, w * h, N); // [N, h*w, C, p*p]
@@ -1222,6 +1247,10 @@ namespace Flux {
flux_params.share_modulation = true;
flux_params.ref_index_scale = 10.f;
flux_params.use_mlp_silu_act = true;
} else if (sd_version_is_longcat(version)) {
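                // LongCat checkpoints: 3584-wide text context, no vector (y) conditioning,
                // and pixel-level tokens (patch_size 1, so patchify/unpatchify only reshape).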
flux_params.context_in_dim = 3584;
flux_params.vec_in_dim = 0;
flux_params.patch_size = 1;
}
for (auto pair : tensor_storage_map) {
std::string tensor_name = pair.first;
@@ -1231,6 +1260,9 @@
// not schnell
flux_params.guidance_embed = true;
}
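            // A second split weight (".weight.1") on single_blocks.0.linear1 means the
            // checkpoint keeps qkv/linear1 as separate tensors (diffusers layout) rather
            // than one fused matrix, so the SplitLinear variants of the blocks are used.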
if (tensor_name.find("model.diffusion_model.single_blocks.0.linear1.weight.1") != std::string::npos) {
flux_params.diffusers_style = true;
}
if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
// Chroma
flux_params.is_chroma = true;
@@ -1260,6 +1292,10 @@ namespace Flux {
LOG_INFO("Flux guidance is disabled (Schnell mode)");
}

if (flux_params.diffusers_style) {
LOG_INFO("Using diffusers-style attention blocks");
}

flux = Flux(flux_params);
flux.init(params_ctx, tensor_storage_map, prefix);
}
@@ -1363,7 +1399,6 @@ namespace Flux {
for (int i = 0; i < ref_latents.size(); i++) {
ref_latents[i] = to_backend(ref_latents[i]);
}

pe_vec = Rope::gen_flux_pe(x->ne[1],
x->ne[0],
flux_params.patch_size,
@@ -1373,10 +1408,10 @@ namespace Flux {
sd_version_is_flux2(version) ? true : increase_ref_index,
flux_params.ref_index_scale,
flux_params.theta,
flux_params.axes_dim);
flux_params.axes_dim,
sd_version_is_longcat(version));
int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;
// LOG_DEBUG("pos_len %d", pos_len);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len);
// pe->data = pe_vec.data();
// print_ggml_tensor(pe);
// pe->data = nullptr;
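Note on the patch_size == 1 shortcut added to patchify/unpatchify above: with a 1x1 patch, packing pixels into p*p channel groups is unnecessary, so the tensor is only reshaped and permuted between [N, C, H, W] and [N, H*W, C]. A minimal standalone sketch of that layout change, written with plain loops instead of ggml ops (illustrative only, not part of this diff):

// patchify for patch_size == 1: [N, C, H, W] -> [N, H*W, C]
// (one token per pixel, C features per token)
#include <cstdint>
#include <vector>

std::vector<float> patchify_p1(const std::vector<float>& x,
                               int64_t N, int64_t C, int64_t H, int64_t W) {
    std::vector<float> out(N * H * W * C);
    for (int64_t n = 0; n < N; n++)
        for (int64_t c = 0; c < C; c++)
            for (int64_t i = 0; i < H * W; i++)
                out[(n * H * W + i) * C + c] = x[(n * C + c) * H * W + i];
    return out;
}

Unpatchify with patch_size == 1 is just the inverse permutation, which is why the new branch in unpatchify only permutes and reshapes back to [N, C, H, W].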
77 changes: 77 additions & 0 deletions ggml_extend.hpp
@@ -2159,6 +2159,83 @@ class Linear : public UnaryBlock {
}
};

class SplitLinear : public Linear {
protected:
int64_t in_features;
std::vector<int64_t> out_features_vec;
bool bias;
bool force_f32;
bool force_prec_f32;
float scale;
std::string prefix;

void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
this->prefix = prefix;
enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32);
if (in_features % ggml_blck_size(wtype) != 0 || force_f32) {
wtype = GGML_TYPE_F32;
}
params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features_vec[0]);
for (int i = 1; i < out_features_vec.size(); i++) {
// most likely same type as the first weight
params["weight." + std::to_string(i)] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features_vec[i]);
}
if (bias) {
enum ggml_type wtype = GGML_TYPE_F32;
params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_features_vec[0]);
for (int i = 1; i < out_features_vec.size(); i++) {
params["bias." + std::to_string(i)] = ggml_new_tensor_1d(ctx, wtype, out_features_vec[i]);
}
}
}

public:
SplitLinear(int64_t in_features,
std::vector<int64_t> out_features_vec,
bool bias = true,
bool force_f32 = false,
bool force_prec_f32 = false,
float scale = 1.f)
: Linear(in_features, out_features_vec[0], bias, force_f32, force_prec_f32, scale),
in_features(in_features),
out_features_vec(out_features_vec),
bias(bias),
force_f32(force_f32),
force_prec_f32(force_prec_f32),
scale(scale) {}

struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
struct ggml_tensor* w = params["weight"];
struct ggml_tensor* b = nullptr;
if (bias) {
b = params["bias"];
}
if (ctx->weight_adapter) {
// concat all weights and biases together so it runs in one linear layer
for (int i = 1; i < out_features_vec.size(); i++) {
w = ggml_concat(ctx->ggml_ctx, w, params["weight." + std::to_string(i)], 1);
if (bias) {
b = ggml_concat(ctx->ggml_ctx, b, params["bias." + std::to_string(i)], 0);
}
}
WeightAdapter::ForwardParams forward_params;
forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_LINEAR;
forward_params.linear.force_prec_f32 = force_prec_f32;
forward_params.linear.scale = scale;
return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params);
}
auto out = ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale);
for (int i = 1; i < out_features_vec.size(); i++) {
auto wi = params["weight." + std::to_string(i)];
auto bi = bias ? params["bias." + std::to_string(i)] : nullptr;
auto curr_out = ggml_ext_linear(ctx->ggml_ctx, x, wi, bi, force_prec_f32, scale);
out = ggml_concat(ctx->ggml_ctx, out, curr_out, 0);
}

return out;
}
};

__STATIC_INLINE__ bool support_get_rows(ggml_type wtype) {
std::set<ggml_type> allow_types = {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0};
if (allow_types.find(wtype) != allow_types.end()) {
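For context on the SplitLinear block above: concatenating the outputs of several linear layers that share one input is equivalent to a single fused linear layer whose weight and bias are the per-part weights and biases stacked along the output dimension, which is what the fused qkv/linear1 layers do for non-diffusers checkpoints. A minimal sketch of that equivalence, independent of ggml and purely illustrative:

// y = concat(W0*x + b0, W1*x + b1, ...) along the output axis equals one linear
// layer with W = concat(W0, W1, ...) and b = concat(b0, b1, ...).
// SplitLinear::forward() uses the per-part form in the normal path and the
// concatenated form in the LoRA (weight_adapter) path.
#include <vector>

std::vector<float> split_linear(const std::vector<std::vector<std::vector<float>>>& Ws,  // per part: [out_i][in]
                                const std::vector<std::vector<float>>& bs,               // per part: [out_i] (may be empty)
                                const std::vector<float>& x) {                           // [in]
    std::vector<float> y;
    for (size_t p = 0; p < Ws.size(); p++) {
        for (size_t o = 0; o < Ws[p].size(); o++) {
            float acc = bs.empty() ? 0.f : bs[p][o];
            for (size_t i = 0; i < x.size(); i++)
                acc += Ws[p][o][i] * x[i];
            y.push_back(acc);  // parts land in y in the same order as the fused layout
        }
    }
    return y;
}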
80 changes: 71 additions & 9 deletions latent-preview.h
@@ -91,6 +91,41 @@ const float flux_latent_rgb_proj[16][3] = {
{-0.111849f, -0.055589f, -0.032361f}};
float flux_latent_rgb_bias[3] = {0.024600f, -0.006937f, -0.008089f};

const float flux2_latent_rgb_proj[32][3] = {
{0.000736f, -0.008385f, -0.019710f},
{-0.001352f, -0.016392f, 0.020693f},
{-0.006376f, 0.002428f, 0.036736f},
{0.039384f, 0.074167f, 0.119789f},
{0.007464f, -0.005705f, -0.004734f},
{-0.004086f, 0.005287f, -0.000409f},
{-0.032835f, 0.050802f, -0.028120f},
{-0.003158f, -0.000835f, 0.000406f},
{-0.112840f, -0.084337f, -0.023083f},
{0.001462f, -0.006656f, 0.000549f},
{-0.009980f, -0.007480f, 0.009702f},
{0.032540f, 0.000214f, -0.061388f},
{0.011023f, 0.000694f, 0.007143f},
{-0.001468f, -0.006723f, -0.001678f},
{-0.005921f, -0.010320f, -0.003907f},
{-0.028434f, 0.027584f, 0.018457f},
{0.014349f, 0.011523f, 0.000441f},
{0.009874f, 0.003081f, 0.001507f},
{0.002218f, 0.005712f, 0.001563f},
{0.053010f, -0.019844f, 0.008683f},
{-0.002507f, 0.005384f, 0.000938f},
{-0.002177f, -0.011366f, 0.003559f},
{-0.000261f, 0.015121f, -0.003240f},
{-0.003944f, -0.002083f, 0.005043f},
{-0.009138f, 0.011336f, 0.003781f},
{0.011429f, 0.003985f, -0.003855f},
{0.010518f, -0.005586f, 0.010131f},
{0.007883f, 0.002912f, -0.001473f},
{-0.003318f, -0.003160f, 0.003684f},
{-0.034560f, -0.008740f, 0.012996f},
{0.000166f, 0.001079f, -0.012153f},
{0.017772f, 0.000937f, -0.011953f}};
float flux2_latent_rgb_bias[3] = {-0.028738f, -0.098463f, -0.107619f};

// This one was taken straight from
// https://github.com/Stability-AI/sd3.5/blob/8565799a3b41eb0c7ba976d18375f0f753f56402/sd3_impls.py#L288-L303
// (MiT Licence)
@@ -128,16 +163,43 @@ const float sd_latent_rgb_proj[4][3] = {
{-0.178022f, -0.200862f, -0.678514f}};
float sd_latent_rgb_bias[3] = {-0.017478f, -0.055834f, -0.105825f};

void preview_latent_video(uint8_t* buffer, struct ggml_tensor* latents, const float (*latent_rgb_proj)[3], const float latent_rgb_bias[3], int width, int height, int frames, int dim) {

void preview_latent_video(uint8_t* buffer, struct ggml_tensor* latents, const float (*latent_rgb_proj)[3], const float latent_rgb_bias[3], int patch_size) {
size_t buffer_head = 0;

uint32_t latent_width = latents->ne[0];
uint32_t latent_height = latents->ne[1];
uint32_t dim = latents->ne[ggml_n_dims(latents) - 1];
uint32_t frames = 1;
if (ggml_n_dims(latents) == 4) {
frames = latents->ne[2];
}

uint32_t rgb_width = latent_width * patch_size;
uint32_t rgb_height = latent_height * patch_size;

uint32_t unpatched_dim = dim / (patch_size * patch_size);

for (int k = 0; k < frames; k++) {
for (int j = 0; j < height; j++) {
for (int i = 0; i < width; i++) {
size_t latent_id = (i * latents->nb[0] + j * latents->nb[1] + k * latents->nb[2]);
for (int rgb_x = 0; rgb_x < rgb_width; rgb_x++) {
for (int rgb_y = 0; rgb_y < rgb_height; rgb_y++) {
int latent_x = rgb_x / patch_size;
int latent_y = rgb_y / patch_size;

int channel_offset = 0;
if (patch_size > 1) {
channel_offset = ((rgb_y % patch_size) * patch_size + (rgb_x % patch_size));
}

size_t latent_id = (latent_x * latents->nb[0] + latent_y * latents->nb[1] + k * latents->nb[2]);

// pixel_id addresses the output buffer directly, row-major within each frame
size_t pixel_id = k * rgb_width * rgb_height + rgb_y * rgb_width + rgb_x;

float r = 0, g = 0, b = 0;
if (latent_rgb_proj != nullptr) {
for (int d = 0; d < dim; d++) {
float value = *(float*)((char*)latents->data + latent_id + d * latents->nb[ggml_n_dims(latents) - 1]);
for (int d = 0; d < unpatched_dim; d++) {
float value = *(float*)((char*)latents->data + latent_id + (d * patch_size * patch_size + channel_offset) * latents->nb[ggml_n_dims(latents) - 1]);
r += value * latent_rgb_proj[d][0];
g += value * latent_rgb_proj[d][1];
b += value * latent_rgb_proj[d][2];
@@ -164,9 +226,9 @@ void preview_latent_video(uint8_t* buffer, struct ggml_tensor* latents, const fl
g = g >= 0 ? g <= 1 ? g : 1 : 0;
b = b >= 0 ? b <= 1 ? b : 1 : 0;

buffer[buffer_head++] = (uint8_t)(r * 255);
buffer[buffer_head++] = (uint8_t)(g * 255);
buffer[buffer_head++] = (uint8_t)(b * 255);
buffer[pixel_id * 3 + 0] = (uint8_t)(r * 255);
buffer[pixel_id * 3 + 1] = (uint8_t)(g * 255);
buffer[pixel_id * 3 + 2] = (uint8_t)(b * 255);
}
}
}
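A short worked example of the new preview indexing, assuming (as the code above does) that the packed latent stores the p*p sub-pixel positions of each unpatched channel contiguously along the channel axis: with patch_size = 2 and unpatched_dim = 32, the packed dim is 128, and RGB pixel (rgb_x, rgb_y) reads channel d*4 + (rgb_y % 2)*2 + (rgb_x % 2) of latent cell (rgb_x/2, rgb_y/2) for each projected channel d. The selection boils down to:

// Sketch of the channel selection used in preview_latent_video for patch_size p:
// each latent cell packs p*p spatial positions per unpatched channel d, and the
// preview reads the one matching the pixel's sub-patch offset.
inline int packed_channel(int d, int rgb_x, int rgb_y, int p) {
    int channel_offset = (rgb_y % p) * p + (rgb_x % p);
    return d * p * p + channel_offset;  // index along the packed latent channel dim
}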