@@ -44,7 +44,7 @@ namespace Rope {
4444 __STATIC_INLINE__ std::vector<std::vector<float >> rope (const std::vector<float >& pos,
4545 int dim,
4646 int theta,
47- const std::vector<int >* wraps = nullptr ) {
47+ const std::vector<int >* wrap_dims = nullptr ) {
4848 assert (dim % 2 == 0 );
4949 int half_dim = dim / 2 ;
5050
@@ -63,16 +63,16 @@ namespace Rope {
6363 float omega_val = omega[j];
6464 float original_angle = position * omega_val;
6565 float angle = original_angle;
66- float wrap = 0 ;
67- if (wraps != nullptr && !wraps ->empty ()) {
68- size_t wrap_size = wraps ->size ();
66+ int wrap_dim = 0 ;
67+ if (wrap_dims != nullptr && !wrap_dims ->empty ()) {
68+ size_t wrap_size = wrap_dims ->size ();
6969 // mod batch size since we only store this for one item in the batch
7070 size_t wrap_idx = wrap_size > 0 ? (i % wrap_size) : 0 ;
71- wrap = (*wraps )[wrap_idx];
71+ wrap_dim = (*wrap_dims )[wrap_idx];
7272 }
73- if (wrap > 0 ) {
73+ if (wrap_dim > 0 ) {
7474 constexpr float TWO_PI = 6 .28318530717958647692f ;
75- float wrap_f = static_cast <float >(wrap );
75+ float wrap_f = static_cast <float >(wrap_dim );
7676 float cycles = omega_val * wrap_f / TWO_PI;
7777 // closest periodic harmonic, necessary to ensure things neatly tile
7878 // without this round, things don't tile at the boundaries and you end up
@@ -144,7 +144,7 @@ namespace Rope {
144144 int bs,
145145 int theta,
146146 const std::vector<int >& axes_dim,
147- const std::vector<std::vector<int >>* axes_wraps = nullptr ) {
147+ const std::vector<std::vector<int >>* wrap_dims = nullptr ) {
148148 std::vector<std::vector<float >> trans_ids = transpose (ids);
149149 size_t pos_len = ids.size () / bs;
150150 int num_axes = axes_dim.size ();
@@ -159,12 +159,12 @@ namespace Rope {
159159 std::vector<std::vector<float >> emb (bs * pos_len, std::vector<float >(emb_dim * 2 * 2 , 0.0 ));
160160 int offset = 0 ;
161161 for (int i = 0 ; i < num_axes; ++i) {
162- const std::vector<int >* axis_wrap = nullptr ;
163- if (axes_wraps != nullptr && i < (int )axes_wraps ->size ()) {
164- axis_wrap = &(*axes_wraps )[i];
162+ const std::vector<int >* axis_wrap_dims = nullptr ;
163+ if (wrap_dims != nullptr && i < (int )wrap_dims ->size ()) {
164+ axis_wrap_dims = &(*wrap_dims )[i];
165165 }
166166 std::vector<std::vector<float >> rope_emb =
167- rope (trans_ids[i], axes_dim[i], theta, axis_wrap ); // [bs*pos_len, axes_dim[i]/2 * 2 * 2]
167+ rope (trans_ids[i], axes_dim[i], theta, axis_wrap_dims ); // [bs*pos_len, axes_dim[i]/2 * 2 * 2]
168168 for (int b = 0 ; b < bs; ++b) {
169169 for (int j = 0 ; j < pos_len; ++j) {
170170 for (int k = 0 ; k < rope_emb[0 ].size (); ++k) {
@@ -277,11 +277,10 @@ namespace Rope {
277277 const std::vector<ggml_tensor*>& ref_latents,
278278 bool increase_ref_index,
279279 int theta,
280- const std::vector<int >& axes_dim,
281- bool circular = false ) {
282- circular = true ;
280+ const std::vector<int >& axes_dim) {
283281 std::vector<std::vector<float >> ids = gen_qwen_image_ids (h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
284- std::vector<std::vector<int >> axes_wraps;
282+ std::vector<std::vector<int >> wrap_dims;
283+ // This logic simply stores the (pad and patch_adjusted) sizes of images so we can make sure rope correctly tiles
285284 if (sd_is_circular_padding_enabled () && bs > 0 && axes_dim.size () >= 3 ) {
286285 int pad_h = (patch_size - (h % patch_size)) % patch_size;
287286 int pad_w = (patch_size - (w % patch_size)) % patch_size;
@@ -290,14 +289,15 @@ namespace Rope {
290289 if (h_len > 0 && w_len > 0 ) {
291290 const size_t total_tokens = ids.size ();
292291 // Track per-token wrap lengths for the row/column axes so only spatial tokens become periodic.
293- axes_wraps .assign (axes_dim.size (), std::vector<int >(total_tokens / bs, 0 ));
292+ wrap_dims .assign (axes_dim.size (), std::vector<int >(total_tokens / bs, 0 ));
294293 size_t cursor = context_len; // ignore text tokens
295294 const size_t img_tokens = static_cast <size_t >(h_len) * static_cast <size_t >(w_len);
296295 for (size_t token_i = 0 ; token_i < img_tokens; ++token_i) {
297- axes_wraps [1 ][cursor + token_i] = h_len;
298- axes_wraps [2 ][cursor + token_i] = w_len;
296+ wrap_dims [1 ][cursor + token_i] = h_len;
297+ wrap_dims [2 ][cursor + token_i] = w_len;
299298 }
300299 cursor += img_tokens;
300+ // For each reference image, store wrap sizes as well
301301 for (ggml_tensor* ref : ref_latents) {
302302 if (ref == nullptr ) {
303303 continue ;
@@ -310,14 +310,14 @@ namespace Rope {
310310 int ref_w_len = (ref_w + ref_pad_w) / patch_size;
311311 size_t ref_n_tokens = static_cast <size_t >(ref_h_len) * static_cast <size_t >(ref_w_len);
312312 for (size_t token_i = 0 ; token_i < ref_n_tokens; ++token_i) {
313- axes_wraps [1 ][cursor + token_i] = ref_h_len;
314- axes_wraps [2 ][cursor + token_i] = ref_w_len;
313+ wrap_dims [1 ][cursor + token_i] = ref_h_len;
314+ wrap_dims [2 ][cursor + token_i] = ref_w_len;
315315 }
316316 cursor += ref_n_tokens;
317317 }
318318 }
319319 }
320- const std::vector<std::vector<int >>* wraps_ptr = axes_wraps .empty () ? nullptr : &axes_wraps ;
320+ const std::vector<std::vector<int >>* wraps_ptr = wrap_dims .empty () ? nullptr : &wrap_dims ;
321321 return embed_nd (ids, bs, theta, axes_dim, wraps_ptr);
322322 }
323323
0 commit comments