Skip to content

Commit 8d7f679

Browse files
committed
Further clean of rope
1 parent cbb261d commit 8d7f679

File tree

1 file changed

+22
-22
lines changed

1 file changed

+22
-22
lines changed

rope.hpp

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ namespace Rope {
4444
__STATIC_INLINE__ std::vector<std::vector<float>> rope(const std::vector<float>& pos,
4545
int dim,
4646
int theta,
47-
const std::vector<int>* wraps = nullptr) {
47+
const std::vector<int>* wrap_dims = nullptr) {
4848
assert(dim % 2 == 0);
4949
int half_dim = dim / 2;
5050

@@ -63,16 +63,16 @@ namespace Rope {
6363
float omega_val = omega[j];
6464
float original_angle = position * omega_val;
6565
float angle = original_angle;
66-
float wrap = 0;
67-
if (wraps != nullptr && !wraps->empty()) {
68-
size_t wrap_size = wraps->size();
66+
int wrap_dim = 0;
67+
if (wrap_dims != nullptr && !wrap_dims->empty()) {
68+
size_t wrap_size = wrap_dims->size();
6969
// mod batch size since we only store this for one item in the batch
7070
size_t wrap_idx = wrap_size > 0 ? (i % wrap_size) : 0;
71-
wrap = (*wraps)[wrap_idx];
71+
wrap_dim = (*wrap_dims)[wrap_idx];
7272
}
73-
if (wrap > 0) {
73+
if (wrap_dim > 0) {
7474
constexpr float TWO_PI = 6.28318530717958647692f;
75-
float wrap_f = static_cast<float>(wrap);
75+
float wrap_f = static_cast<float>(wrap_dim);
7676
float cycles = omega_val * wrap_f / TWO_PI;
7777
// closest periodic harmonic, necessary to ensure things neatly tile
7878
// without this round, things don't tile at the boundaries and you end up
@@ -144,7 +144,7 @@ namespace Rope {
144144
int bs,
145145
int theta,
146146
const std::vector<int>& axes_dim,
147-
const std::vector<std::vector<int>>* axes_wraps = nullptr) {
147+
const std::vector<std::vector<int>>* wrap_dims = nullptr) {
148148
std::vector<std::vector<float>> trans_ids = transpose(ids);
149149
size_t pos_len = ids.size() / bs;
150150
int num_axes = axes_dim.size();
@@ -159,12 +159,12 @@ namespace Rope {
159159
std::vector<std::vector<float>> emb(bs * pos_len, std::vector<float>(emb_dim * 2 * 2, 0.0));
160160
int offset = 0;
161161
for (int i = 0; i < num_axes; ++i) {
162-
const std::vector<int>* axis_wrap = nullptr;
163-
if (axes_wraps != nullptr && i < (int)axes_wraps->size()) {
164-
axis_wrap = &(*axes_wraps)[i];
162+
const std::vector<int>* axis_wrap_dims = nullptr;
163+
if (wrap_dims != nullptr && i < (int)wrap_dims->size()) {
164+
axis_wrap_dims = &(*wrap_dims)[i];
165165
}
166166
std::vector<std::vector<float>> rope_emb =
167-
rope(trans_ids[i], axes_dim[i], theta, axis_wrap); // [bs*pos_len, axes_dim[i]/2 * 2 * 2]
167+
rope(trans_ids[i], axes_dim[i], theta, axis_wrap_dims); // [bs*pos_len, axes_dim[i]/2 * 2 * 2]
168168
for (int b = 0; b < bs; ++b) {
169169
for (int j = 0; j < pos_len; ++j) {
170170
for (int k = 0; k < rope_emb[0].size(); ++k) {
@@ -277,11 +277,10 @@ namespace Rope {
277277
const std::vector<ggml_tensor*>& ref_latents,
278278
bool increase_ref_index,
279279
int theta,
280-
const std::vector<int>& axes_dim,
281-
bool circular = false) {
282-
circular = true;
280+
const std::vector<int>& axes_dim) {
283281
std::vector<std::vector<float>> ids = gen_qwen_image_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
284-
std::vector<std::vector<int>> axes_wraps;
282+
std::vector<std::vector<int>> wrap_dims;
283+
// This logic simply stores the (pad and patch_adjusted) sizes of images so we can make sure rope correctly tiles
285284
if (sd_is_circular_padding_enabled() && bs > 0 && axes_dim.size() >= 3) {
286285
int pad_h = (patch_size - (h % patch_size)) % patch_size;
287286
int pad_w = (patch_size - (w % patch_size)) % patch_size;
@@ -290,14 +289,15 @@ namespace Rope {
290289
if (h_len > 0 && w_len > 0) {
291290
const size_t total_tokens = ids.size();
292291
// Track per-token wrap lengths for the row/column axes so only spatial tokens become periodic.
293-
axes_wraps.assign(axes_dim.size(), std::vector<int>(total_tokens / bs, 0));
292+
wrap_dims.assign(axes_dim.size(), std::vector<int>(total_tokens / bs, 0));
294293
size_t cursor = context_len; // ignore text tokens
295294
const size_t img_tokens = static_cast<size_t>(h_len) * static_cast<size_t>(w_len);
296295
for (size_t token_i = 0; token_i < img_tokens; ++token_i) {
297-
axes_wraps[1][cursor + token_i] = h_len;
298-
axes_wraps[2][cursor + token_i] = w_len;
296+
wrap_dims[1][cursor + token_i] = h_len;
297+
wrap_dims[2][cursor + token_i] = w_len;
299298
}
300299
cursor += img_tokens;
300+
// For each reference image, store wrap sizes as well
301301
for (ggml_tensor* ref : ref_latents) {
302302
if (ref == nullptr) {
303303
continue;
@@ -310,14 +310,14 @@ namespace Rope {
310310
int ref_w_len = (ref_w + ref_pad_w) / patch_size;
311311
size_t ref_n_tokens = static_cast<size_t>(ref_h_len) * static_cast<size_t>(ref_w_len);
312312
for (size_t token_i = 0; token_i < ref_n_tokens; ++token_i) {
313-
axes_wraps[1][cursor + token_i] = ref_h_len;
314-
axes_wraps[2][cursor + token_i] = ref_w_len;
313+
wrap_dims[1][cursor + token_i] = ref_h_len;
314+
wrap_dims[2][cursor + token_i] = ref_w_len;
315315
}
316316
cursor += ref_n_tokens;
317317
}
318318
}
319319
}
320-
const std::vector<std::vector<int>>* wraps_ptr = axes_wraps.empty() ? nullptr : &axes_wraps;
320+
const std::vector<std::vector<int>>* wraps_ptr = wrap_dims.empty() ? nullptr : &wrap_dims;
321321
return embed_nd(ids, bs, theta, axes_dim, wraps_ptr);
322322
}
323323

0 commit comments

Comments
 (0)