diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp
index 49b202fda..9481169af 100644
--- a/examples/cli/main.cpp
+++ b/examples/cli/main.cpp
@@ -580,7 +580,7 @@ struct SDContextParams {
          "--vae",
          "path to standalone vae model",
          &vae_path},
-        {"",
+        {"--tae",
          "--taesd",
          "path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)",
          &taesd_path},
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 1ef851247..d2f23ac37 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -400,8 +400,8 @@ class StableDiffusionGGML {
                                                            offload_params_to_cpu,
                                                            tensor_storage_map);
             diffusion_model = std::make_shared<MMDiTModel>(backend,
-                                                          offload_params_to_cpu,
-                                                          tensor_storage_map);
+                                                           offload_params_to_cpu,
+                                                           tensor_storage_map);
         } else if (sd_version_is_flux(version)) {
             bool is_chroma = false;
             for (auto pair : tensor_storage_map) {
@@ -461,10 +461,10 @@ class StableDiffusionGGML {
                                 1,
                                 true);
             diffusion_model = std::make_shared<WanModel>(backend,
-                                                        offload_params_to_cpu,
-                                                        tensor_storage_map,
-                                                        "model.diffusion_model",
-                                                        version);
+                                                         offload_params_to_cpu,
+                                                         tensor_storage_map,
+                                                         "model.diffusion_model",
+                                                         version);
             if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) {
                 high_noise_diffusion_model = std::make_shared<WanModel>(backend,
                                                                         offload_params_to_cpu,
@@ -564,14 +577,27 @@ class StableDiffusionGGML {
         }
         if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
-            first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
-                                                                    offload_params_to_cpu,
-                                                                    tensor_storage_map,
-                                                                    "first_stage_model",
-                                                                    vae_decode_only,
-                                                                    version);
-            first_stage_model->alloc_params_buffer();
-            first_stage_model->get_param_tensors(tensors, "first_stage_model");
+            if (!use_tiny_autoencoder) {
+                first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
+                                                                        offload_params_to_cpu,
+                                                                        tensor_storage_map,
+                                                                        "first_stage_model",
+                                                                        vae_decode_only,
+                                                                        version);
+                first_stage_model->alloc_params_buffer();
+                first_stage_model->get_param_tensors(tensors, "first_stage_model");
+            } else {
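+                // TAEHV replaces the full Wan VAE here; its checkpoints keep the
+                // decoder under a top-level "decoder" prefix, while the image
+                // TAESD branch below loads from "decoder.layers".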
+                tae_first_stage = std::make_shared<TinyVideoAutoEncoder>(vae_backend,
+                                                                         offload_params_to_cpu,
+                                                                         tensor_storage_map,
+                                                                         "decoder",
+                                                                         vae_decode_only,
+                                                                         version);
+                if (sd_ctx_params->vae_conv_direct) {
+                    LOG_INFO("Using Conv2d direct in the tae model");
+                    tae_first_stage->set_conv2d_direct_enabled(true);
+                }
+            }
         } else if (version == VERSION_CHROMA_RADIANCE) {
             first_stage_model = std::make_shared(vae_backend, offload_params_to_cpu);
@@ -598,14 +611,13 @@ class StableDiffusionGGML {
             }
             first_stage_model->alloc_params_buffer();
             first_stage_model->get_param_tensors(tensors, "first_stage_model");
-        }
-        if (use_tiny_autoencoder) {
-            tae_first_stage = std::make_shared<TinyAutoEncoder>(vae_backend,
-                                                                offload_params_to_cpu,
-                                                                tensor_storage_map,
-                                                                "decoder.layers",
-                                                                vae_decode_only,
-                                                                version);
+        } else if (use_tiny_autoencoder) {
+            tae_first_stage = std::make_shared<TinyImageAutoEncoder>(vae_backend,
+                                                                     offload_params_to_cpu,
+                                                                     tensor_storage_map,
+                                                                     "decoder.layers",
+                                                                     vae_decode_only,
+                                                                     version);
             if (sd_ctx_params->vae_conv_direct) {
                 LOG_INFO("Using Conv2d direct in the tae model");
                 tae_first_stage->set_conv2d_direct_enabled(true);
@@ -2297,6 +2312,10 @@ class StableDiffusionGGML {
             first_stage_model->free_compute_buffer();
             process_vae_output_tensor(result);
         } else {
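+            // Wan latents are laid out as (W, H, T, C); TAEHV runs frames as the
+            // conv batch dimension, so swap axes 2 and 3 to (W, H, C, T) first.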
+            if (sd_version_is_wan(version)) {
+                x = ggml_permute(work_ctx, x, 0, 1, 3, 2);
+            }
+
             if (vae_tiling_params.enabled && !decode_video) {
                 // split latent in 64x64 tiles and compute in several steps
                 auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
@@ -2307,6 +2326,7 @@ class StableDiffusionGGML {
                 tae_first_stage->compute(n_threads, x, true, &result);
             }
             tae_first_stage->free_compute_buffer();
+        }
 
         int64_t t1 = ggml_time_ms();
@@ -3817,7 +3837,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
     struct ggml_tensor* vid = sd_ctx->sd->decode_first_stage(work_ctx, final_latent, true);
     int64_t t5              = ggml_time_ms();
     LOG_INFO("decode_first_stage completed, taking %.2fs", (t5 - t4) * 1.0f / 1000);
-    if (sd_ctx->sd->free_params_immediately) {
+    if (sd_ctx->sd->free_params_immediately && !sd_ctx->sd->use_tiny_autoencoder) {
         sd_ctx->sd->first_stage_model->free_params_buffer();
     }
diff --git a/tae.hpp b/tae.hpp
index 7f3ca449a..c88dd7ce3 100644
--- a/tae.hpp
+++ b/tae.hpp
@@ -162,6 +162,242 @@ class TinyDecoder : public UnaryBlock {
     }
 };
 
+class TPool : public UnaryBlock {
+    int stride;
+
+public:
+    TPool(int channels, int stride)
+        : stride(stride) {
+        blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels * stride, channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
+    }
+
+    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
+        auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
+        auto h    = x;
+        if (stride != 1) {
+            h = ggml_reshape_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2] * stride, h->ne[3] / stride);
+        }
+        h = conv->forward(ctx, h);
+        return h;
+    }
+};
+
+class TGrow : public UnaryBlock {
+    int stride;
+
+public:
+    TGrow(int channels, int stride)
+        : stride(stride) {
+        blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels * stride, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
+    }
+
+    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
+        auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
+        auto h    = conv->forward(ctx, x);
+        if (stride != 1) {
+            h = ggml_reshape_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2] / stride, h->ne[3] * stride);
+        }
+        return h;
+    }
+};
+
+class MemBlock : public GGMLBlock {
+    bool has_skip_conv = false;
+
+public:
+    MemBlock(int channels, int out_channels)
+        : has_skip_conv(channels != out_channels) {
+        blocks["conv.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels * 2, out_channels, {3, 3}, {1, 1}, {1, 1}));
+        blocks["conv.2"] = std::shared_ptr<GGMLBlock>(new Conv2d(out_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
+        blocks["conv.4"] = std::shared_ptr<GGMLBlock>(new Conv2d(out_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
+        if (has_skip_conv) {
+            blocks["skip"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
+        }
+    }
+
+    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* past) {
+        // x: [n, channels, h, w]
+        auto conv0 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.0"]);
+        auto conv1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.2"]);
+        auto conv2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.4"]);
+
+        auto h = ggml_concat(ctx->ggml_ctx, x, past, 2);
+        h      = conv0->forward(ctx, h);
+        h      = ggml_relu_inplace(ctx->ggml_ctx, h);
+        h      = conv1->forward(ctx, h);
+        h      = ggml_relu_inplace(ctx->ggml_ctx, h);
+        h      = conv2->forward(ctx, h);
+
+        auto skip = x;
+        if (has_skip_conv) {
+            auto skip_conv = std::dynamic_pointer_cast<Conv2d>(blocks["skip"]);
+            skip           = skip_conv->forward(ctx, x);
+        }
+        h = ggml_add_inplace(ctx->ggml_ctx, h, skip);
+        h = ggml_relu_inplace(ctx->ggml_ctx, h);
+        return h;
+    }
+};
+
+class TinyVideoEncoder : public UnaryBlock {
+    int in_channels = 3;
+    int hidden      = 64;
+    int z_channels  = 4;
+    int num_blocks  = 3;
+    int num_layers  = 3;
+    int patch_size  = 1;
+
+public:
+    TinyVideoEncoder(int z_channels = 4, int patch_size = 1)
+        : z_channels(z_channels), patch_size(patch_size) {
+        int index                       = 0;
+        blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels * patch_size * patch_size, hidden, {3, 3}, {1, 1}, {1, 1}));
+        index++;  // nn.ReLU()
+        for (int i = 0; i < num_layers; i++) {
+            int stride                      = i == num_layers - 1 ? 1 : 2;
+            blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TPool(hidden, stride));
+            blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(hidden, hidden, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false));
+            for (int j = 0; j < num_blocks; j++) {
+                blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new MemBlock(hidden, hidden));
+            }
+        }
+        blocks[std::to_string(index)] = std::shared_ptr<GGMLBlock>(new Conv2d(hidden, z_channels, {3, 3}, {1, 1}, {1, 1}));
+    }
+
+    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) override {
+        auto first_conv = std::dynamic_pointer_cast<Conv2d>(blocks["0"]);
+        auto h          = first_conv->forward(ctx, z);
+        h               = ggml_relu_inplace(ctx->ggml_ctx, h);
+
+        int index = 2;
+        for (int i = 0; i < num_layers; i++) {
+            auto pool = std::dynamic_pointer_cast<TPool>(blocks[std::to_string(index++)]);
+            auto conv = std::dynamic_pointer_cast<Conv2d>(blocks[std::to_string(index++)]);
+
+            h = pool->forward(ctx, h);
+            h = conv->forward(ctx, h);
+            for (int j = 0; j < num_blocks; j++) {
+                auto block = std::dynamic_pointer_cast<MemBlock>(blocks[std::to_string(index++)]);
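+                // Temporal "memory": pad one zero frame at the front of the frame
+                // axis, then view the first T frames, pairing every frame with its
+                // predecessor (frame 0 sees zeros).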
+                auto mem   = ggml_pad_ext(ctx->ggml_ctx, h, 0, 0, 0, 0, 0, 0, 1, 0);
+                mem        = ggml_view_4d(ctx->ggml_ctx, mem, h->ne[0], h->ne[1], h->ne[2], h->ne[3], h->nb[1], h->nb[2], h->nb[3], 0);
+                h          = block->forward(ctx, h, mem);
+            }
+        }
+        auto last_conv = std::dynamic_pointer_cast<Conv2d>(blocks[std::to_string(index)]);
+        h              = last_conv->forward(ctx, h);
+        return h;
+    }
+};
+
+class TinyVideoDecoder : public UnaryBlock {
+    int z_channels               = 4;
+    int out_channels             = 3;
+    int num_blocks               = 3;
+    static const int num_layers  = 3;
+    int channels[num_layers + 1] = {256, 128, 64, 64};
+
+public:
+    TinyVideoDecoder(int z_channels = 4, int patch_size = 1)
+        : z_channels(z_channels) {
+        int index                       = 1;  // Clamp()
+        blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, channels[0], {3, 3}, {1, 1}, {1, 1}));
+        index++;  // nn.ReLU()
+        for (int i = 0; i < num_layers; i++) {
+            int stride = i == 0 ? 1 : 2;
+            for (int j = 0; j < num_blocks; j++) {
+                blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new MemBlock(channels[i], channels[i]));
+            }
+            index++;  // nn.Upsample()
+            blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TGrow(channels[i], stride));
+            blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels[i], channels[i + 1], {3, 3}, {1, 1}, {1, 1}, {1, 1}, false));
+        }
+        index++;  // nn.ReLU()
+        blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels[num_layers], out_channels * patch_size * patch_size, {3, 3}, {1, 1}, {1, 1}));
+    }
+
+    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) override {
+        auto first_conv = std::dynamic_pointer_cast<Conv2d>(blocks["1"]);
+
+        // Clamp()
+        auto h = ggml_scale_inplace(ctx->ggml_ctx,
+                                    ggml_tanh_inplace(ctx->ggml_ctx,
+                                                      ggml_scale(ctx->ggml_ctx, z, 1.0f / 3.0f)),
+                                    3.0f);
+
+        h         = first_conv->forward(ctx, h);
+        h         = ggml_relu_inplace(ctx->ggml_ctx, h);
+        int index = 3;
+        for (int i = 0; i < num_layers; i++) {
+            for (int j = 0; j < num_blocks; j++) {
+                auto block = std::dynamic_pointer_cast<MemBlock>(blocks[std::to_string(index++)]);
+                auto mem   = ggml_pad_ext(ctx->ggml_ctx, h, 0, 0, 0, 0, 0, 0, 1, 0);
+                mem        = ggml_view_4d(ctx->ggml_ctx, mem, h->ne[0], h->ne[1], h->ne[2], h->ne[3], h->nb[1], h->nb[2], h->nb[3], 0);
+                h          = block->forward(ctx, h, mem);
+            }
+            // upsample
+            index++;
+            h          = ggml_upscale(ctx->ggml_ctx, h, 2, GGML_SCALE_MODE_NEAREST);
+            auto block = std::dynamic_pointer_cast<UnaryBlock>(blocks[std::to_string(index++)]);
+            h          = block->forward(ctx, h);
+            block      = std::dynamic_pointer_cast<UnaryBlock>(blocks[std::to_string(index++)]);
+            h          = block->forward(ctx, h);
+        }
+        h = ggml_relu_inplace(ctx->ggml_ctx, h);
+
+        auto last_conv = std::dynamic_pointer_cast<Conv2d>(blocks[std::to_string(++index)]);
+        h              = last_conv->forward(ctx, h);
+
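+        // TGrow strides (1, 2, 2) give a 4x temporal upscale: T latent frames
+        // decode to 4*T outputs, and the first three warm-up frames are dropped
+        // below, leaving (T - 1) * 4 + 1 frames.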
+        // shape(W, H, 3, 3 + T) => shape(W, H, 3, T)
+        h = ggml_view_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2], h->ne[3] - 3, h->nb[1], h->nb[2], h->nb[3], 3 * h->nb[3]);
+        return h;
+    }
+};
+
+class TAEHV : public GGMLBlock {
+protected:
+    bool decode_only;
+    SDVersion version;
+
+public:
+    TAEHV(bool decode_only = true, SDVersion version = VERSION_WAN2)
+        : decode_only(decode_only), version(version) {
+        int z_channels = 16;
+        int patch      = 1;
+        if (version == VERSION_WAN2_2_TI2V) {
+            z_channels = 48;
+            patch      = 2;
+        }
+        blocks["decoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoDecoder(z_channels, patch));
+        if (!decode_only) {
+            blocks["encoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoEncoder(z_channels, patch));
+        }
+    }
+
+    struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
+        auto decoder = std::dynamic_pointer_cast<TinyVideoDecoder>(blocks["decoder"]);
+        auto result  = decoder->forward(ctx, z);
+        if (sd_version_is_wan(version)) {
+            // (W, H, C, T) -> (W, H, T, C)
+            result = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, result, 0, 1, 3, 2));
+        }
+        return result;
+    }
+
+    struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
+        auto encoder = std::dynamic_pointer_cast<TinyVideoEncoder>(blocks["encoder"]);
+        // (W, H, T, C) -> (W, H, C, T)
+        x                  = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2));
+        int64_t num_frames = x->ne[3];
+        if (num_frames % 4) {
+            // pad to multiple of 4 at the end
+            auto last_frame = ggml_view_4d(ctx->ggml_ctx, x, x->ne[0], x->ne[1], x->ne[2], 1, x->nb[1], x->nb[2], x->nb[3], (num_frames - 1) * x->nb[3]);
+            for (int i = 0; i < 4 - num_frames % 4; i++) {
+                x = ggml_concat(ctx->ggml_ctx, x, last_frame, 3);
+            }
+        }
+        x = encoder->forward(ctx, x);
+        x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2));
+        return x;
+    }
+};
+
 class TAESD : public GGMLBlock {
 protected:
     bool decode_only;
@@ -192,18 +428,30 @@ class TAESD : public GGMLBlock {
 };
 
 struct TinyAutoEncoder : public GGMLRunner {
+    TinyAutoEncoder(ggml_backend_t backend, bool offload_params_to_cpu)
+        : GGMLRunner(backend, offload_params_to_cpu) {}
+
+    virtual bool compute(const int n_threads,
+                         struct ggml_tensor* z,
+                         bool decode_graph,
+                         struct ggml_tensor** output,
+                         struct ggml_context* output_ctx = nullptr) = 0;
+
+    virtual bool load_from_file(const std::string& file_path, int n_threads) = 0;
+};
+
+struct TinyImageAutoEncoder : public TinyAutoEncoder {
     TAESD taesd;
     bool decode_only = false;
 
-    TinyAutoEncoder(ggml_backend_t backend,
-                    bool offload_params_to_cpu,
-                    const String2TensorStorage& tensor_storage_map,
-                    const std::string prefix,
-                    bool decoder_only = true,
-                    SDVersion version = VERSION_SD1)
+    TinyImageAutoEncoder(ggml_backend_t backend,
+                         bool offload_params_to_cpu,
+                         const String2TensorStorage& tensor_storage_map,
+                         const std::string prefix,
+                         bool decoder_only = true,
+                         SDVersion version = VERSION_SD1)
         : decode_only(decoder_only),
           taesd(decoder_only, version),
-          GGMLRunner(backend, offload_params_to_cpu) {
+          TinyAutoEncoder(backend, offload_params_to_cpu) {
         taesd.init(params_ctx, tensor_storage_map, prefix);
     }
@@ -260,4 +508,73 @@ struct TinyAutoEncoder : public GGMLRunner {
     }
 };
 
+struct TinyVideoAutoEncoder : public TinyAutoEncoder {
+    TAEHV taehv;
+    bool decode_only = false;
+
+    TinyVideoAutoEncoder(ggml_backend_t backend,
+                         bool offload_params_to_cpu,
+                         const String2TensorStorage& tensor_storage_map,
+                         const std::string prefix,
+                         bool decoder_only = true,
+                         SDVersion version = VERSION_WAN2)
+        : decode_only(decoder_only),
+          taehv(decoder_only, version),
+          TinyAutoEncoder(backend, offload_params_to_cpu) {
+        taehv.init(params_ctx, tensor_storage_map, prefix);
+    }
+
+    std::string get_desc() override {
+        return "taehv";
+    }
+
+    bool load_from_file(const std::string& file_path, int n_threads) {
+        LOG_INFO("loading taehv from '%s', decode_only = %s", file_path.c_str(), decode_only ? "true" : "false");
+        alloc_params_buffer();
+        std::map<std::string, struct ggml_tensor*> taehv_tensors;
+        taehv.get_param_tensors(taehv_tensors);
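+        // Decode-only contexts never instantiate the encoder, so its tensors are
+        // skipped when loading the checkpoint.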
+        std::set<std::string> ignore_tensors;
+        if (decode_only) {
+            ignore_tensors.insert("encoder.");
+        }
+
+        ModelLoader model_loader;
+        if (!model_loader.init_from_file(file_path)) {
+            LOG_ERROR("init taehv model loader from file failed: '%s'", file_path.c_str());
+            return false;
+        }
+
+        bool success = model_loader.load_tensors(taehv_tensors, ignore_tensors, n_threads);
+
+        if (!success) {
+            LOG_ERROR("load tae tensors from model loader failed");
+            return false;
+        }
+
+        LOG_INFO("taehv model loaded");
+        return success;
+    }
+
+    struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
+        struct ggml_cgraph* gf  = ggml_new_graph(compute_ctx);
+        z                       = to_backend(z);
+        auto runner_ctx         = get_context();
+        struct ggml_tensor* out = decode_graph ? taehv.decode(&runner_ctx, z) : taehv.encode(&runner_ctx, z);
+        ggml_build_forward_expand(gf, out);
+        return gf;
+    }
+
+    bool compute(const int n_threads,
+                 struct ggml_tensor* z,
+                 bool decode_graph,
+                 struct ggml_tensor** output,
+                 struct ggml_context* output_ctx = nullptr) {
+        auto get_graph = [&]() -> struct ggml_cgraph* {
+            return build_graph(z, decode_graph);
+        };
+
+        return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
+    }
+};
+
 #endif  // __TAE_HPP__
\ No newline at end of file