#define SD_UNUSED(x) (void)(x)
#endif

+ // process-wide toggle: when true, the sd_pad*/ggml_nn_conv_2d helpers use circular (wrap-around) padding
+ inline bool& sd_global_circular_padding_enabled() {
+     static bool enabled = false;
+     return enabled;
+ }
+
+ // wrapper around ggml_pad that switches to ggml_pad_circular when the global toggle is enabled
+ __STATIC_INLINE__ struct ggml_tensor* sd_pad(struct ggml_context* ctx,
+                                              struct ggml_tensor* a,
+                                              int p0,
+                                              int p1,
+                                              int p2,
+                                              int p3) {
+     if (sd_global_circular_padding_enabled()) {
+         return ggml_pad_circular(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
+     }
+     return ggml_pad(ctx, a, p0, p1, p2, p3);
+ }
+
+ __STATIC_INLINE__ struct ggml_tensor* sd_pad_ext(struct ggml_context* ctx,
+                                                  struct ggml_tensor* a,
+                                                  int lp0,
+                                                  int rp0,
+                                                  int lp1,
+                                                  int rp1,
+                                                  int lp2,
+                                                  int rp2,
+                                                  int lp3,
+                                                  int rp3) {
+     if (sd_global_circular_padding_enabled()) {
+         return ggml_pad_circular(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
+     }
+     return ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
+ }
+
__STATIC_INLINE__ void ggml_log_callback_default(ggml_log_level level, const char* text, void*) {
    switch (level) {
        case GGML_LOG_LEVEL_DEBUG:
@@ -986,10 +1019,24 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_2d(struct ggml_context* ctx,
    if (scale != 1.f) {
        x = ggml_scale(ctx, x, scale);
    }
+     const bool use_circular = sd_global_circular_padding_enabled() && (p0 != 0 || p1 != 0);
+     // depthwise case: kernel has a single input channel and one filter per input channel
+     const bool is_depthwise = (w->ne[2] == 1 && x->ne[2] == w->ne[3]);
    if (direct) {
-         x = ggml_conv_2d_direct(ctx, w, x, s0, s1, p0, p1, d0, d1);
+         if (use_circular) {
+             if (is_depthwise) {
+                 x = ggml_conv_2d_dw_direct_circular(ctx, w, x, s0, s1, p0, p1, d0, d1);
+             } else {
+                 x = ggml_conv_2d_direct_circular(ctx, w, x, s0, s1, p0, p1, d0, d1);
+             }
+         } else {
+             x = ggml_conv_2d_direct(ctx, w, x, s0, s1, p0, p1, d0, d1);
+         }
    } else {
-         x = ggml_conv_2d(ctx, w, x, s0, s1, p0, p1, d0, d1);
+         if (use_circular) {
+             x = ggml_conv_2d_circular(ctx, w, x, s0, s1, p0, p1, d0, d1);
+         } else {
+             x = ggml_conv_2d(ctx, w, x, s0, s1, p0, p1, d0, d1);
+         }
    }
    if (scale != 1.f) {
        x = ggml_scale(ctx, x, 1.f / scale);
@@ -1190,7 +1237,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*

    auto build_kqv = [&](ggml_tensor* q_in, ggml_tensor* k_in, ggml_tensor* v_in, ggml_tensor* mask_in) -> ggml_tensor* {
        if (kv_pad != 0) {
-             k_in = ggml_pad(ctx, k_in, 0, kv_pad, 0, 0);
+             k_in = sd_pad(ctx, k_in, 0, kv_pad, 0, 0);
        }
        if (kv_scale != 1.0f) {
            k_in = ggml_scale(ctx, k_in, kv_scale);
@@ -1200,7 +1247,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
        v_in = ggml_nn_cont(ctx, ggml_permute(ctx, v_in, 0, 2, 1, 3));
        v_in = ggml_reshape_3d(ctx, v_in, d_head, L_k, n_kv_head * N);
        if (kv_pad != 0) {
-             v_in = ggml_pad(ctx, v_in, 0, kv_pad, 0, 0);
+             v_in = sd_pad(ctx, v_in, 0, kv_pad, 0, 0);
        }
        if (kv_scale != 1.0f) {
            v_in = ggml_scale(ctx, v_in, kv_scale);
@@ -1223,7 +1270,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
            mask_pad = GGML_PAD(L_q, GGML_KQ_MASK_PAD) - mask_in->ne[1];
        }
        if (mask_pad > 0) {
-             mask_in = ggml_pad(ctx, mask_in, 0, mask_pad, 0, 0);
+             mask_in = sd_pad(ctx, mask_in, 0, mask_pad, 0, 0);
        }
        mask_in = ggml_cast(ctx, mask_in, GGML_TYPE_F16);
    }
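
For reference, a minimal usage sketch of the new toggle (the ctx and x names below are placeholders, not part of this patch): enabling the global flag makes every sd_pad and ggml_nn_conv_2d call site wrap around instead of zero-padding, e.g. when seamless/tileable outputs are wanted.

    // minimal sketch, assuming a valid ggml_context* ctx and an existing ggml_tensor* x
    sd_global_circular_padding_enabled() = true;               // all padding now wraps around
    struct ggml_tensor* padded = sd_pad(ctx, x, 1, 1, 0, 0);   // routes to ggml_pad_circular
    sd_global_circular_padding_enabled() = false;              // restore default zero padding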