@@ -63,27 +63,22 @@ namespace Rope {
6363 float omega_val = omega[j];
6464 float original_angle = position * omega_val;
6565 float angle = original_angle;
66- int wrap = 0 ;
66+ float wrap = 0 ;
6767 if (wraps != nullptr && !wraps->empty ()) {
6868 size_t wrap_size = wraps->size ();
69+ // mod batch size since we only store this for one item in the batch
6970 size_t wrap_idx = wrap_size > 0 ? (i % wrap_size) : 0 ;
7071 wrap = (*wraps)[wrap_idx];
7172 }
7273 if (wrap > 0 ) {
7374 constexpr float TWO_PI = 6 .28318530717958647692f ;
7475 float wrap_f = static_cast <float >(wrap);
7576 float cycles = omega_val * wrap_f / TWO_PI;
76- float rounded = std::round (cycles); // closest periodic harmonic
77- float periodic_omega = TWO_PI * rounded / wrap_f;
78- float periodic_angle = position * periodic_omega;
79- float rel_pos = std::fmod (position, wrap_f);
80- if (rel_pos < 0 .0f ) {
81- rel_pos += wrap_f;
82- }
83- float t = wrap_f > 0 .0f ? rel_pos / wrap_f : 0 .0f ;
84- float window = 0 .5f - 0 .5f * std::cos (TWO_PI * t); // 0 at edges, 1 in the middle
85- window = std::clamp (window, 0 .0f , 1 .0f );
86- angle = periodic_angle + window * (original_angle - periodic_angle);
77+ // closest periodic harmonic, necessary to ensure things neatly tile
78+ // without this round, things don't tile at the boundaries and you end up
79+ // with the model knowing what is "center"
80+ float rounded = std::round (cycles);
81+ angle = position * TWO_PI * rounded / wrap_f;
8782 }
8883 float sin_val = std::sin (angle);
8984 float cos_val = std::cos (angle);
@@ -282,7 +277,9 @@ namespace Rope {
282277 const std::vector<ggml_tensor*>& ref_latents,
283278 bool increase_ref_index,
284279 int theta,
285- const std::vector<int >& axes_dim) {
280+ const std::vector<int >& axes_dim,
281+ bool circular = false ) {
282+ circular = true ;
286283 std::vector<std::vector<float >> ids = gen_qwen_image_ids (h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
287284 std::vector<std::vector<int >> axes_wraps;
288285 if (sd_is_circular_padding_enabled () && bs > 0 && axes_dim.size () >= 3 ) {
@@ -294,7 +291,13 @@ namespace Rope {
294291 const size_t total_tokens = ids.size ();
295292 // Track per-token wrap lengths for the row/column axes so only spatial tokens become periodic.
296293 axes_wraps.assign (axes_dim.size (), std::vector<int >(total_tokens / bs, 0 ));
297- size_t cursor = 0 ;
294+ size_t cursor = context_len; // ignore text tokens
295+ const size_t img_tokens = static_cast <size_t >(h_len) * static_cast <size_t >(w_len);
296+ for (size_t token_i = 0 ; token_i < img_tokens; ++token_i) {
297+ axes_wraps[1 ][cursor + token_i] = h_len;
298+ axes_wraps[2 ][cursor + token_i] = w_len;
299+ }
300+ cursor += img_tokens;
298301 for (ggml_tensor* ref : ref_latents) {
299302 if (ref == nullptr ) {
300303 continue ;
0 commit comments