
Commit ee0e82a

cleaned rope
1 parent 0092711 commit ee0e82a

File tree

3 files changed: +101 -48 lines

ggml_extend.hpp

Lines changed: 43 additions & 41 deletions
@@ -61,47 +61,6 @@
 #define SD_UNUSED(x) (void)(x)
 #endif
 
-inline std::atomic<bool>& sd_circular_padding_flag() {
-    static std::atomic<bool> enabled{false};
-    return enabled;
-}
-
-inline void sd_set_circular_padding_enabled(bool enabled) {
-    sd_circular_padding_flag().store(enabled, std::memory_order_relaxed);
-}
-
-inline bool sd_is_circular_padding_enabled() {
-    return sd_circular_padding_flag().load(std::memory_order_relaxed);
-}
-
-__STATIC_INLINE__ struct ggml_tensor* sd_pad(struct ggml_context* ctx,
-                                             struct ggml_tensor* a,
-                                             int p0,
-                                             int p1,
-                                             int p2,
-                                             int p3) {
-    if (sd_is_circular_padding_enabled()) {
-        return ggml_pad_circular(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
-    }
-    return ggml_pad(ctx, a, p0, p1, p2, p3);
-}
-
-__STATIC_INLINE__ struct ggml_tensor* sd_pad_ext(struct ggml_context* ctx,
-                                                 struct ggml_tensor* a,
-                                                 int lp0,
-                                                 int rp0,
-                                                 int lp1,
-                                                 int rp1,
-                                                 int lp2,
-                                                 int rp2,
-                                                 int lp3,
-                                                 int rp3) {
-    if (sd_is_circular_padding_enabled()) {
-        return ggml_pad_circular(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
-    }
-    return ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
-}
-
 __STATIC_INLINE__ void ggml_log_callback_default(ggml_log_level level, const char* text, void*) {
     switch (level) {
         case GGML_LOG_LEVEL_DEBUG:
@@ -628,6 +587,49 @@ __STATIC_INLINE__ void ggml_tensor_clamp(struct ggml_tensor* src, float min, flo
     }
 }
 
+
+
+inline std::atomic<bool>& sd_circular_padding_flag() {
+    static std::atomic<bool> enabled{false};
+    return enabled;
+}
+
+inline void sd_set_circular_padding_enabled(bool enabled) {
+    sd_circular_padding_flag().store(enabled, std::memory_order_relaxed);
+}
+
+inline bool sd_is_circular_padding_enabled() {
+    return sd_circular_padding_flag().load(std::memory_order_relaxed);
+}
+
+__STATIC_INLINE__ struct ggml_tensor* sd_pad(struct ggml_context* ctx,
+                                             struct ggml_tensor* a,
+                                             int p0,
+                                             int p1,
+                                             int p2,
+                                             int p3) {
+    if (sd_is_circular_padding_enabled()) {
+        return ggml_pad_circular(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
+    }
+    return ggml_pad(ctx, a, p0, p1, p2, p3);
+}
+
+__STATIC_INLINE__ struct ggml_tensor* sd_pad_ext(struct ggml_context* ctx,
+                                                 struct ggml_tensor* a,
+                                                 int lp0,
+                                                 int rp0,
+                                                 int lp1,
+                                                 int rp1,
+                                                 int lp2,
+                                                 int rp2,
+                                                 int lp3,
+                                                 int rp3) {
+    if (sd_is_circular_padding_enabled()) {
+        return ggml_pad_circular(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
+    }
+    return ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
+}
+
 __STATIC_INLINE__ struct ggml_tensor* ggml_tensor_concat(struct ggml_context* ctx,
                                                          struct ggml_tensor* a,
                                                          struct ggml_tensor* b,
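
Note: the move is behavior-preserving; the helpers are deleted near the top of the header and re-added verbatim after the tensor utilities. Every SD pad site still dispatches on a single process-wide atomic flag, so circular (wrap-around) padding can be toggled at runtime without touching call sites. A minimal usage sketch; ctx and x are hypothetical stand-ins for a live ggml context and input tensor:

    // Opt in once, then pad as usual. While the flag is set, sd_pad
    // forwards to ggml_pad_circular; otherwise it falls back to ggml_pad.
    sd_set_circular_padding_enabled(true);
    struct ggml_tensor* padded = sd_pad(ctx, x, /*p0=*/1, /*p1=*/1, /*p2=*/0, /*p3=*/0);
    // Equivalent while enabled: ggml_pad_circular(ctx, x, 0, 1, 0, 1, 0, 0, 0, 0);

std::memory_order_relaxed is sufficient here because the flag is a standalone configuration toggle; no other memory is published through it.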

rope.hpp

Lines changed: 57 additions & 6 deletions
@@ -1,6 +1,8 @@
 #ifndef __ROPE_HPP__
 #define __ROPE_HPP__
 
+#include <algorithm>
+#include <cmath>
 #include <vector>
 #include "ggml_extend.hpp"
 
@@ -39,15 +41,20 @@ namespace Rope {
         return flat_vec;
     }
 
-    __STATIC_INLINE__ std::vector<std::vector<float>> rope(const std::vector<float>& pos, int dim, int theta) {
+    __STATIC_INLINE__ std::vector<std::vector<float>> rope(const std::vector<float>& pos,
+                                                           int dim,
+                                                           int theta,
+                                                           const std::vector<int>* wraps = nullptr) {
         assert(dim % 2 == 0);
         int half_dim = dim / 2;
 
+        std::vector<std::vector<float>> result(pos.size(), std::vector<float>(half_dim * 4));
+
         std::vector<float> scale = linspace(0.f, (dim * 1.f - 2) / dim, half_dim);
 
         std::vector<float> omega(half_dim);
         for (int i = 0; i < half_dim; ++i) {
-            omega[i] = 1.0 / std::pow(theta, scale[i]);
+            omega[i] = 1.0f / std::pow(theta, scale[i]);
         }
 
         for (size_t i = 0; i < pos.size(); ++i) {
@@ -56,7 +63,13 @@
                 float omega_val = omega[j];
                 float original_angle = position * omega_val;
                 float angle = original_angle;
-                if (sd_is_circular_padding_enabled()) {
+                int wrap = 0;
+                if (wraps != nullptr && !wraps->empty()) {
+                    size_t wrap_size = wraps->size();
+                    size_t wrap_idx = wrap_size > 0 ? (i % wrap_size) : 0;
+                    wrap = (*wraps)[wrap_idx];
+                }
+                if (wrap > 0) {
                     constexpr float TWO_PI = 6.28318530717958647692f;
                     float wrap_f = static_cast<float>(wrap);
                     float cycles = omega_val * wrap_f / TWO_PI;
@@ -80,6 +93,7 @@
                 result[i][4 * j + 3] = cos_val;
             }
         }
+
         return result;
     }
 
@@ -134,7 +148,8 @@
     __STATIC_INLINE__ std::vector<float> embed_nd(const std::vector<std::vector<float>>& ids,
                                                   int bs,
                                                   int theta,
-                                                  const std::vector<int>& axes_dim) {
+                                                  const std::vector<int>& axes_dim,
+                                                  const std::vector<std::vector<int>>* axes_wraps = nullptr) {
         std::vector<std::vector<float>> trans_ids = transpose(ids);
         size_t pos_len = ids.size() / bs;
         int num_axes = axes_dim.size();
@@ -149,7 +164,12 @@
         std::vector<std::vector<float>> emb(bs * pos_len, std::vector<float>(emb_dim * 2 * 2, 0.0));
         int offset = 0;
         for (int i = 0; i < num_axes; ++i) {
-            std::vector<std::vector<float>> rope_emb = rope(trans_ids[i], axes_dim[i], theta);  // [bs*pos_len, axes_dim[i]/2 * 2 * 2]
+            const std::vector<int>* axis_wrap = nullptr;
+            if (axes_wraps != nullptr && i < (int)axes_wraps->size()) {
+                axis_wrap = &(*axes_wraps)[i];
+            }
+            std::vector<std::vector<float>> rope_emb =
+                rope(trans_ids[i], axes_dim[i], theta, axis_wrap);  // [bs*pos_len, axes_dim[i]/2 * 2 * 2]
             for (int b = 0; b < bs; ++b) {
                 for (int j = 0; j < pos_len; ++j) {
                     for (int k = 0; k < rope_emb[0].size(); ++k) {
@@ -264,7 +284,38 @@
                                                       int theta,
                                                       const std::vector<int>& axes_dim) {
         std::vector<std::vector<float>> ids = gen_qwen_image_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
-        return embed_nd(ids, bs, theta, axes_dim);
+        std::vector<std::vector<int>> axes_wraps;
+        if (sd_is_circular_padding_enabled() && bs > 0 && axes_dim.size() >= 3) {
+            int pad_h = (patch_size - (h % patch_size)) % patch_size;
+            int pad_w = (patch_size - (w % patch_size)) % patch_size;
+            int h_len = (h + pad_h) / patch_size;
+            int w_len = (w + pad_w) / patch_size;
+            if (h_len > 0 && w_len > 0) {
+                const size_t total_tokens = ids.size();
+                // Track per-token wrap lengths for the row/column axes so only spatial tokens become periodic.
+                axes_wraps.assign(axes_dim.size(), std::vector<int>(total_tokens / bs, 0));
+                size_t cursor = 0;
+                for (ggml_tensor* ref : ref_latents) {
+                    if (ref == nullptr) {
+                        continue;
+                    }
+                    int ref_h = static_cast<int>(ref->ne[1]);
+                    int ref_w = static_cast<int>(ref->ne[0]);
+                    int ref_pad_h = (patch_size - (ref_h % patch_size)) % patch_size;
+                    int ref_pad_w = (patch_size - (ref_w % patch_size)) % patch_size;
+                    int ref_h_len = (ref_h + ref_pad_h) / patch_size;
+                    int ref_w_len = (ref_w + ref_pad_w) / patch_size;
+                    size_t ref_n_tokens = static_cast<size_t>(ref_h_len) * static_cast<size_t>(ref_w_len);
+                    for (size_t token_i = 0; token_i < ref_n_tokens; ++token_i) {
+                        axes_wraps[1][cursor + token_i] = ref_h_len;
+                        axes_wraps[2][cursor + token_i] = ref_w_len;
+                    }
+                    cursor += ref_n_tokens;
+                }
+            }
+        }
+        const std::vector<std::vector<int>>* wraps_ptr = axes_wraps.empty() ? nullptr : &axes_wraps;
+        return embed_nd(ids, bs, theta, axes_dim, wraps_ptr);
     }
 
     __STATIC_INLINE__ std::vector<std::vector<float>> gen_vid_ids(int t,
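
Net effect of the rope.hpp changes: rope() no longer reads the global circular-padding flag itself; it takes an optional per-token wrap-length vector, embed_nd() threads one such vector per axis, and the Qwen-image path builds those vectors from the padded, patchified latent grids (per the in-diff comment, so only spatial tokens become periodic). The visible context line (float cycles = omega_val * wrap_f / TWO_PI;) together with the new <algorithm>/<cmath> includes suggests each frequency is snapped to complete a whole number of rotations over the wrap length; the snapping lines fall outside the hunk, so the following is a hedged sketch of that idea (snap_omega_periodic is a hypothetical name), not the committed code:

    #include <algorithm>
    #include <cmath>

    // Hypothetical sketch: force omega to complete an integer number of
    // full rotations across `wrap` positions. Then the rotation at
    // position p matches the one at p + wrap, so the RoPE embedding
    // tiles seamlessly along that axis.
    inline float snap_omega_periodic(float omega, int wrap) {
        constexpr float TWO_PI = 6.28318530717958647692f;
        if (wrap <= 0) {
            return omega;  // non-periodic token: keep the frequency as-is
        }
        float cycles  = omega * static_cast<float>(wrap) / TWO_PI;  // rotations per wrap
        float snapped = std::max(1.0f, std::round(cycles));         // whole rotations, at least one
        return snapped * TWO_PI / static_cast<float>(wrap);
    }

With a snapped omega, position * omega is congruent modulo 2*pi at positions 0 and wrap, which is exactly the continuity circular padding needs at the seam.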
