11#ifndef __ROPE_HPP__
22#define __ROPE_HPP__
33
4+ #include < algorithm>
5+ #include < cmath>
46#include < vector>
57#include " ggml_extend.hpp"
68
@@ -39,15 +41,20 @@ namespace Rope {
3941 return flat_vec;
4042 }
4143
42- __STATIC_INLINE__ std::vector<std::vector<float >> rope (const std::vector<float >& pos, int dim, int theta) {
44+ __STATIC_INLINE__ std::vector<std::vector<float >> rope (const std::vector<float >& pos,
45+ int dim,
46+ int theta,
47+ const std::vector<int >* wraps = nullptr ) {
4348 assert (dim % 2 == 0 );
4449 int half_dim = dim / 2 ;
4550
51+ std::vector<std::vector<float >> result (pos.size (), std::vector<float >(half_dim * 4 ));
52+
4653 std::vector<float > scale = linspace (0 .f , (dim * 1 .f - 2 ) / dim, half_dim);
4754
4855 std::vector<float > omega (half_dim);
4956 for (int i = 0 ; i < half_dim; ++i) {
50- omega[i] = 1.0 / std::pow (theta, scale[i]);
57+ omega[i] = 1 .0f / std::pow (theta, scale[i]);
5158 }
5259
5360 for (size_t i = 0 ; i < pos.size (); ++i) {
@@ -56,7 +63,13 @@ namespace Rope {
5663 float omega_val = omega[j];
5764 float original_angle = position * omega_val;
5865 float angle = original_angle;
59- if (sd_is_circular_padding_enabled ()) {
66+ int wrap = 0 ;
67+ if (wraps != nullptr && !wraps->empty ()) {
68+ size_t wrap_size = wraps->size ();
69+ size_t wrap_idx = wrap_size > 0 ? (i % wrap_size) : 0 ;
70+ wrap = (*wraps)[wrap_idx];
71+ }
72+ if (wrap > 0 ) {
6073 constexpr float TWO_PI = 6 .28318530717958647692f ;
6174 float wrap_f = static_cast <float >(wrap);
6275 float cycles = omega_val * wrap_f / TWO_PI;
@@ -80,6 +93,7 @@ namespace Rope {
8093 result[i][4 * j + 3 ] = cos_val;
8194 }
8295 }
96+
8397 return result;
8498 }
8599
@@ -134,7 +148,8 @@ namespace Rope {
134148 __STATIC_INLINE__ std::vector<float > embed_nd (const std::vector<std::vector<float >>& ids,
135149 int bs,
136150 int theta,
137- const std::vector<int >& axes_dim) {
151+ const std::vector<int >& axes_dim,
152+ const std::vector<std::vector<int >>* axes_wraps = nullptr ) {
138153 std::vector<std::vector<float >> trans_ids = transpose (ids);
139154 size_t pos_len = ids.size () / bs;
140155 int num_axes = axes_dim.size ();
@@ -149,7 +164,12 @@ namespace Rope {
149164 std::vector<std::vector<float >> emb (bs * pos_len, std::vector<float >(emb_dim * 2 * 2 , 0.0 ));
150165 int offset = 0 ;
151166 for (int i = 0 ; i < num_axes; ++i) {
152- std::vector<std::vector<float >> rope_emb = rope (trans_ids[i], axes_dim[i], theta); // [bs*pos_len, axes_dim[i]/2 * 2 * 2]
167+ const std::vector<int >* axis_wrap = nullptr ;
168+ if (axes_wraps != nullptr && i < (int )axes_wraps->size ()) {
169+ axis_wrap = &(*axes_wraps)[i];
170+ }
171+ std::vector<std::vector<float >> rope_emb =
172+ rope (trans_ids[i], axes_dim[i], theta, axis_wrap); // [bs*pos_len, axes_dim[i]/2 * 2 * 2]
153173 for (int b = 0 ; b < bs; ++b) {
154174 for (int j = 0 ; j < pos_len; ++j) {
155175 for (int k = 0 ; k < rope_emb[0 ].size (); ++k) {
@@ -264,7 +284,38 @@ namespace Rope {
264284 int theta,
265285 const std::vector<int >& axes_dim) {
266286 std::vector<std::vector<float >> ids = gen_qwen_image_ids (h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
267- return embed_nd (ids, bs, theta, axes_dim);
287+ std::vector<std::vector<int >> axes_wraps;
288+ if (sd_is_circular_padding_enabled () && bs > 0 && axes_dim.size () >= 3 ) {
289+ int pad_h = (patch_size - (h % patch_size)) % patch_size;
290+ int pad_w = (patch_size - (w % patch_size)) % patch_size;
291+ int h_len = (h + pad_h) / patch_size;
292+ int w_len = (w + pad_w) / patch_size;
293+ if (h_len > 0 && w_len > 0 ) {
294+ const size_t total_tokens = ids.size ();
295+ // Track per-token wrap lengths for the row/column axes so only spatial tokens become periodic.
296+ axes_wraps.assign (axes_dim.size (), std::vector<int >(total_tokens / bs, 0 ));
297+ size_t cursor = 0 ;
298+ for (ggml_tensor* ref : ref_latents) {
299+ if (ref == nullptr ) {
300+ continue ;
301+ }
302+ int ref_h = static_cast <int >(ref->ne [1 ]);
303+ int ref_w = static_cast <int >(ref->ne [0 ]);
304+ int ref_pad_h = (patch_size - (ref_h % patch_size)) % patch_size;
305+ int ref_pad_w = (patch_size - (ref_w % patch_size)) % patch_size;
306+ int ref_h_len = (ref_h + ref_pad_h) / patch_size;
307+ int ref_w_len = (ref_w + ref_pad_w) / patch_size;
308+ size_t ref_n_tokens = static_cast <size_t >(ref_h_len) * static_cast <size_t >(ref_w_len);
309+ for (size_t token_i = 0 ; token_i < ref_n_tokens; ++token_i) {
310+ axes_wraps[1 ][cursor + token_i] = ref_h_len;
311+ axes_wraps[2 ][cursor + token_i] = ref_w_len;
312+ }
313+ cursor += ref_n_tokens;
314+ }
315+ }
316+ }
317+ const std::vector<std::vector<int >>* wraps_ptr = axes_wraps.empty () ? nullptr : &axes_wraps;
318+ return embed_nd (ids, bs, theta, axes_dim, wraps_ptr);
268319 }
269320
270321 __STATIC_INLINE__ std::vector<std::vector<float >> gen_vid_ids (int t,
0 commit comments