
Commit 476df8e

global bool
1 parent 0585e26 commit 476df8e

File tree

10 files changed: +85 additions, -23 deletions

common.hpp

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ class DownSampleBlock : public GGMLBlock {
         if (vae_downsample) {
             auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
 
-            x = ggml_pad(ctx, x, 1, 1, 0, 0);
+            x = sd_pad(ctx, x, 1, 1, 0, 0);
             x = conv->forward(ctx, x);
         } else {
             auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["op"]);

examples/cli/main.cpp

Lines changed: 5 additions & 0 deletions
@@ -113,6 +113,7 @@ struct SDParams {
     bool diffusion_flash_attn = false;
     bool diffusion_conv_direct = false;
     bool vae_conv_direct = false;
+    bool circular_pad = false;
     bool canny_preprocess = false;
     bool color = false;
     int upscale_repeats = 1;
@@ -183,6 +184,7 @@ void print_params(SDParams params) {
     printf(" diffusion flash attention: %s\n", params.diffusion_flash_attn ? "true" : "false");
     printf(" diffusion Conv2d direct: %s\n", params.diffusion_conv_direct ? "true" : "false");
     printf(" vae_conv_direct: %s\n", params.vae_conv_direct ? "true" : "false");
+    printf(" circular padding: %s\n", params.circular_pad ? "true" : "false");
     printf(" control_strength: %.2f\n", params.control_strength);
     printf(" prompt: %s\n", params.prompt.c_str());
     printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
@@ -304,6 +306,7 @@ void print_usage(int argc, const char* argv[]) {
     printf(" This might crash if it is not supported by the backend.\n");
     printf(" --vae-conv-direct use Conv2d direct in the vae model (should improve the performance)\n");
     printf(" This might crash if it is not supported by the backend.\n");
+    printf(" --circular use circular padding for convolutions and pad ops\n");
     printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n");
     printf(" --canny apply canny preprocessor (edge detection)\n");
     printf(" --color colors the logging tags according to level\n");
@@ -573,6 +576,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
     {"", "--diffusion-fa", "", true, &params.diffusion_flash_attn},
     {"", "--diffusion-conv-direct", "", true, &params.diffusion_conv_direct},
     {"", "--vae-conv-direct", "", true, &params.vae_conv_direct},
+    {"", "--circular", "", true, &params.circular_pad},
     {"", "--canny", "", true, &params.canny_preprocess},
     {"-v", "--verbose", "", true, &params.verbose},
     {"", "--color", "", true, &params.color},
@@ -1386,6 +1390,7 @@ int main(int argc, const char* argv[]) {
     params.diffusion_flash_attn,
     params.diffusion_conv_direct,
     params.vae_conv_direct,
+    params.circular_pad,
     params.force_sdxl_vae_conv_scale,
     params.chroma_use_dit_mask,
     params.chroma_use_t5_mask,
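With these hunks the feature is a single opt-in CLI switch. A hypothetical invocation (model path and prompt are illustrative; -m and -p are the existing model/prompt flags):

    ./sd -m model.safetensors -p "seamless brick texture" --circular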

flux.hpp

Lines changed: 6 additions & 6 deletions
@@ -696,7 +696,7 @@ namespace Flux {
             vec = approx->forward(ctx, vec); // [344, N, hidden_size]
 
             if (y != NULL) {
-                txt_img_mask = ggml_pad(ctx, y, img->ne[1], 0, 0, 0);
+                txt_img_mask = sd_pad(ctx, y, img->ne[1], 0, 0, 0);
             }
         } else {
             auto time_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["time_in"]);
@@ -759,7 +759,7 @@ namespace Flux {
         int64_t patch_size = 2;
         int pad_h = (patch_size - H % patch_size) % patch_size;
         int pad_w = (patch_size - W % patch_size) % patch_size;
-        x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w]
+        x = sd_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w]
 
         // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
         auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size]
@@ -815,9 +815,9 @@ namespace Flux {
         ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
         ggml_tensor* control = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));
 
-        masked = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0);
-        mask = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0);
-        control = ggml_pad(ctx, control, pad_w, pad_h, 0, 0);
+        masked = sd_pad(ctx, masked, pad_w, pad_h, 0, 0);
+        mask = sd_pad(ctx, mask, pad_w, pad_h, 0, 0);
+        control = sd_pad(ctx, control, pad_w, pad_h, 0, 0);
 
         masked = patchify(ctx, masked, patch_size);
         mask = patchify(ctx, mask, patch_size);
@@ -827,7 +827,7 @@ namespace Flux {
         } else if (params.version == VERSION_FLUX_CONTROLS) {
             GGML_ASSERT(c_concat != NULL);
 
-            ggml_tensor* control = ggml_pad(ctx, c_concat, pad_w, pad_h, 0, 0);
+            ggml_tensor* control = sd_pad(ctx, c_concat, pad_w, pad_h, 0, 0);
             control = patchify(ctx, control, patch_size);
             img = ggml_concat(ctx, img, control, 0);
         }

ggml

Submodule ggml updated 393 files

ggml_extend.hpp

Lines changed: 52 additions & 5 deletions
@@ -60,6 +60,39 @@
 #define SD_UNUSED(x) (void)(x)
 #endif
 
+inline bool& sd_global_circular_padding_enabled() {
+    static bool enabled = false;
+    return enabled;
+}
+
+__STATIC_INLINE__ struct ggml_tensor* sd_pad(struct ggml_context* ctx,
+                                             struct ggml_tensor* a,
+                                             int p0,
+                                             int p1,
+                                             int p2,
+                                             int p3) {
+    if (sd_global_circular_padding_enabled()) {
+        return ggml_pad_circular(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
+    }
+    return ggml_pad(ctx, a, p0, p1, p2, p3);
+}
+
+__STATIC_INLINE__ struct ggml_tensor* sd_pad_ext(struct ggml_context* ctx,
+                                                 struct ggml_tensor* a,
+                                                 int lp0,
+                                                 int rp0,
+                                                 int lp1,
+                                                 int rp1,
+                                                 int lp2,
+                                                 int rp2,
+                                                 int lp3,
+                                                 int rp3) {
+    if (sd_global_circular_padding_enabled()) {
+        return ggml_pad_circular(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
+    }
+    return ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
+}
+
 __STATIC_INLINE__ void ggml_log_callback_default(ggml_log_level level, const char* text, void*) {
     switch (level) {
         case GGML_LOG_LEVEL_DEBUG:
@@ -986,10 +1019,24 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_2d(struct ggml_context* ctx,
     if (scale != 1.f) {
         x = ggml_scale(ctx, x, scale);
     }
+    const bool use_circular = sd_global_circular_padding_enabled() && (p0 != 0 || p1 != 0);
+    const bool is_depthwise = (w->ne[2] == 1 && x->ne[2] == w->ne[3]);
     if (direct) {
-        x = ggml_conv_2d_direct(ctx, w, x, s0, s1, p0, p1, d0, d1);
+        if (use_circular) {
+            if (is_depthwise) {
+                x = ggml_conv_2d_dw_direct_circular(ctx, w, x, s0, s1, p0, p1, d0, d1);
+            } else {
+                x = ggml_conv_2d_direct_circular(ctx, w, x, s0, s1, p0, p1, d0, d1);
+            }
+        } else {
+            x = ggml_conv_2d_direct(ctx, w, x, s0, s1, p0, p1, d0, d1);
+        }
     } else {
-        x = ggml_conv_2d(ctx, w, x, s0, s1, p0, p1, d0, d1);
+        if (use_circular) {
+            x = ggml_conv_2d_circular(ctx, w, x, s0, s1, p0, p1, d0, d1);
+        } else {
+            x = ggml_conv_2d(ctx, w, x, s0, s1, p0, p1, d0, d1);
+        }
     }
     if (scale != 1.f) {
         x = ggml_scale(ctx, x, 1.f / scale);
@@ -1190,7 +1237,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
 
     auto build_kqv = [&](ggml_tensor* q_in, ggml_tensor* k_in, ggml_tensor* v_in, ggml_tensor* mask_in) -> ggml_tensor* {
         if (kv_pad != 0) {
-            k_in = ggml_pad(ctx, k_in, 0, kv_pad, 0, 0);
+            k_in = sd_pad(ctx, k_in, 0, kv_pad, 0, 0);
         }
         if (kv_scale != 1.0f) {
             k_in = ggml_scale(ctx, k_in, kv_scale);
@@ -1200,7 +1247,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
         v_in = ggml_nn_cont(ctx, ggml_permute(ctx, v_in, 0, 2, 1, 3));
         v_in = ggml_reshape_3d(ctx, v_in, d_head, L_k, n_kv_head * N);
         if (kv_pad != 0) {
-            v_in = ggml_pad(ctx, v_in, 0, kv_pad, 0, 0);
+            v_in = sd_pad(ctx, v_in, 0, kv_pad, 0, 0);
         }
         if (kv_scale != 1.0f) {
             v_in = ggml_scale(ctx, v_in, kv_scale);
@@ -1223,7 +1270,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
             mask_pad = GGML_PAD(L_q, GGML_KQ_MASK_PAD) - mask_in->ne[1];
         }
         if (mask_pad > 0) {
-            mask_in = ggml_pad(ctx, mask_in, 0, mask_pad, 0, 0);
+            mask_in = sd_pad(ctx, mask_in, 0, mask_pad, 0, 0);
         }
         mask_in = ggml_cast(ctx, mask_in, GGML_TYPE_F16);
     }
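The wrapper pair above is the core of the change: the function-local static returned by reference from sd_global_circular_padding_enabled() is the "global bool" of the commit title, and every pad site in the model code now routes through sd_pad/sd_pad_ext, which dispatch on it. A minimal sketch of the intended use, assuming ctx and t stand in for an already-initialized ggml context and tensor (both hypothetical here):

    // Hypothetical caller: flip the global, then build pad ops as usual.
    sd_global_circular_padding_enabled() = true;   // later pads wrap around
    ggml_tensor* p = sd_pad(ctx, t, 1, 1, 0, 0);   // lowers to ggml_pad_circular
    sd_global_circular_padding_enabled() = false;  // restore zero padding
    ggml_tensor* q = sd_pad(ctx, t, 1, 1, 0, 0);   // lowers to ggml_pad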

mmdit.hpp

Lines changed: 2 additions & 2 deletions
@@ -80,7 +80,7 @@ struct PatchEmbed : public GGMLBlock {
             int64_t H = x->ne[1];
             int pad_h = (patch_size - H % patch_size) % patch_size;
             int pad_w = (patch_size - W % patch_size) % patch_size;
-            x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // TODO: reflect pad mode
+            x = sd_pad(ctx, x, pad_w, pad_h, 0, 0); // TODO: reflect pad mode
         }
         x = proj->forward(ctx, x);
 
@@ -997,4 +997,4 @@ struct MMDiTRunner : public GGMLRunner {
     }
 };
 
-#endif
\ No newline at end of file
+#endif

qwen_image.hpp

Lines changed: 2 additions & 2 deletions
@@ -363,7 +363,7 @@ namespace Qwen {
 
         int pad_h = (params.patch_size - H % params.patch_size) % params.patch_size;
         int pad_w = (params.patch_size - W % params.patch_size) % params.patch_size;
-        x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w]
+        x = sd_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w]
         return x;
     }
 
@@ -691,4 +691,4 @@ namespace Qwen {
 
 } // namespace name
 
-#endif // __QWEN_IMAGE_HPP__
\ No newline at end of file
+#endif // __QWEN_IMAGE_HPP__

stable-diffusion.cpp

Lines changed: 9 additions & 0 deletions
@@ -114,6 +114,7 @@ class StableDiffusionGGML {
     bool use_tiny_autoencoder = false;
     sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0, 0};
     bool offload_params_to_cpu = false;
+    bool circular_pad = false;
     bool stacked_id = false;
 
     bool is_using_v_parameterization = false;
@@ -187,6 +188,11 @@ class StableDiffusionGGML {
         taesd_path = SAFE_STR(sd_ctx_params->taesd_path);
         use_tiny_autoencoder = taesd_path.size() > 0;
         offload_params_to_cpu = sd_ctx_params->offload_params_to_cpu;
+        circular_pad = sd_ctx_params->circular_pad;
+        sd_global_circular_padding_enabled() = circular_pad;
+        if (circular_pad) {
+            LOG_INFO("Using circular padding for convolutions");
+        }
 
         if (sd_ctx_params->rng_type == STD_DEFAULT_RNG) {
             rng = std::make_shared<STDDefaultRNG>();
@@ -1820,6 +1826,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
     sd_ctx_params->keep_control_net_on_cpu = false;
     sd_ctx_params->keep_vae_on_cpu = false;
     sd_ctx_params->diffusion_flash_attn = false;
+    sd_ctx_params->circular_pad = false;
     sd_ctx_params->chroma_use_dit_mask = true;
     sd_ctx_params->chroma_use_t5_mask = false;
     sd_ctx_params->chroma_t5_mask_pad = 1;
@@ -1860,6 +1867,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
            "keep_control_net_on_cpu: %s\n"
            "keep_vae_on_cpu: %s\n"
            "diffusion_flash_attn: %s\n"
+           "circular_pad: %s\n"
            "chroma_use_dit_mask: %s\n"
            "chroma_use_t5_mask: %s\n"
            "chroma_t5_mask_pad: %d\n",
@@ -1889,6 +1897,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
            BOOL_STR(sd_ctx_params->keep_control_net_on_cpu),
            BOOL_STR(sd_ctx_params->keep_vae_on_cpu),
            BOOL_STR(sd_ctx_params->diffusion_flash_attn),
+           BOOL_STR(sd_ctx_params->circular_pad),
            BOOL_STR(sd_ctx_params->chroma_use_dit_mask),
            BOOL_STR(sd_ctx_params->chroma_use_t5_mask),
            sd_ctx_params->chroma_t5_mask_pad);

stable-diffusion.h

Lines changed: 1 addition & 0 deletions
@@ -164,6 +164,7 @@ typedef struct {
     bool diffusion_flash_attn;
     bool diffusion_conv_direct;
     bool vae_conv_direct;
+    bool circular_pad;
     bool force_sdxl_vae_conv_scale;
     bool chroma_use_dit_mask;
     bool chroma_use_t5_mask;
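For library users, the header change means the toggle rides along in sd_ctx_params_t rather than through a separate setter; as the stable-diffusion.cpp hunks above show, the context constructor copies it into the process-wide flag. A sketch of the intended setup, assuming the params struct is then passed to context creation in the usual way:

    sd_ctx_params_t params;
    sd_ctx_params_init(&params);  // circular_pad defaults to false
    params.circular_pad = true;   // constructor sets the global padding flag and logs it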

wan.hpp

Lines changed: 6 additions & 6 deletions
@@ -73,7 +73,7 @@ namespace WAN {
             lp2 -= (int)cache_x->ne[2];
         }
 
-        x = ggml_pad_ext(ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, 0, 0);
+        x = sd_pad_ext(ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, 0, 0);
         return ggml_nn_conv_3d(ctx, x, w, b, in_channels,
                                std::get<2>(stride), std::get<1>(stride), std::get<0>(stride),
                                0, 0, 0,
@@ -172,7 +172,7 @@ namespace WAN {
                 2);
         }
         if (chunk_idx == 1 && cache_x->ne[2] < 2) { // Rep
-            cache_x = ggml_pad_ext(ctx, cache_x, 0, 0, 0, 0, (int)cache_x->ne[2], 0, 0, 0);
+            cache_x = sd_pad_ext(ctx, cache_x, 0, 0, 0, 0, (int)cache_x->ne[2], 0, 0, 0);
             // aka cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device),cache_x],dim=2)
         }
         if (chunk_idx == 1) {
@@ -198,9 +198,9 @@ namespace WAN {
         } else if (mode == "upsample3d") {
             x = ggml_upscale(ctx, x, 2, GGML_SCALE_MODE_NEAREST);
         } else if (mode == "downsample2d") {
-            x = ggml_pad(ctx, x, 1, 1, 0, 0);
+            x = sd_pad(ctx, x, 1, 1, 0, 0);
         } else if (mode == "downsample3d") {
-            x = ggml_pad(ctx, x, 1, 1, 0, 0);
+            x = sd_pad(ctx, x, 1, 1, 0, 0);
         }
         x = resample_1->forward(ctx, x);
         x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2)); // (c, t, h, w)
@@ -260,7 +260,7 @@ namespace WAN {
 
         int64_t pad_t = (factor_t - T % factor_t) % factor_t;
 
-        x = ggml_pad_ext(ctx, x, 0, 0, 0, 0, pad_t, 0, 0, 0);
+        x = sd_pad_ext(ctx, x, 0, 0, 0, 0, pad_t, 0, 0, 0);
         T = x->ne[2];
 
         x = ggml_reshape_4d(ctx, x, W * H, factor_t, T / factor_t, C); // [C, T/factor_t, factor_t, H*W]
@@ -1838,7 +1838,7 @@ namespace WAN {
         int pad_t = (std::get<0>(params.patch_size) - T % std::get<0>(params.patch_size)) % std::get<0>(params.patch_size);
         int pad_h = (std::get<1>(params.patch_size) - H % std::get<1>(params.patch_size)) % std::get<1>(params.patch_size);
         int pad_w = (std::get<2>(params.patch_size) - W % std::get<2>(params.patch_size)) % std::get<2>(params.patch_size);
-        x = ggml_pad(ctx, x, pad_w, pad_h, pad_t, 0); // [N*C, T + pad_t, H + pad_h, W + pad_w]
+        x = sd_pad(ctx, x, pad_w, pad_h, pad_t, 0); // [N*C, T + pad_t, H + pad_h, W + pad_w]
 
         return x;
     }

0 commit comments