Skip to content

Commit cbb261d

Browse files
committed
working simplified but still need wraps
1 parent ee0e82a commit cbb261d

File tree

2 files changed

+22
-16
lines changed

2 files changed

+22
-16
lines changed

ggml_extend.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,9 @@ __STATIC_INLINE__ struct ggml_tensor* sd_pad(struct ggml_context* ctx,
611611
if (sd_is_circular_padding_enabled()) {
612612
return ggml_pad_circular(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
613613
}
614-
return ggml_pad(ctx, a, p0, p1, p2, p3);
614+
else {
615+
return ggml_pad(ctx, a, p0, p1, p2, p3);
616+
}
615617
}
616618

617619
__STATIC_INLINE__ struct ggml_tensor* sd_pad_ext(struct ggml_context* ctx,
@@ -1030,7 +1032,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_2d(struct ggml_context* ctx,
10301032
if (scale != 1.f) {
10311033
x = ggml_scale(ctx, x, scale);
10321034
}
1033-
const bool use_circular = sd_is_circular_padding_enabled() && (p0 != 0 || p1 != 0);
1035+
const bool use_circular = sd_is_circular_padding_enabled();
1036+
LOG_DEBUG("use circular conv %d", use_circular ? 1 : 0);
10341037
const bool is_depthwise = (w->ne[2] == 1 && x->ne[2] == w->ne[3]);
10351038
if (direct) {
10361039
if (use_circular) {

rope.hpp

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,27 +63,22 @@ namespace Rope {
6363
float omega_val = omega[j];
6464
float original_angle = position * omega_val;
6565
float angle = original_angle;
66-
int wrap = 0;
66+
float wrap = 0;
6767
if (wraps != nullptr && !wraps->empty()) {
6868
size_t wrap_size = wraps->size();
69+
// mod batch size since we only store this for one item in the batch
6970
size_t wrap_idx = wrap_size > 0 ? (i % wrap_size) : 0;
7071
wrap = (*wraps)[wrap_idx];
7172
}
7273
if (wrap > 0) {
7374
constexpr float TWO_PI = 6.28318530717958647692f;
7475
float wrap_f = static_cast<float>(wrap);
7576
float cycles = omega_val * wrap_f / TWO_PI;
76-
float rounded = std::round(cycles); // closest periodic harmonic
77-
float periodic_omega = TWO_PI * rounded / wrap_f;
78-
float periodic_angle = position * periodic_omega;
79-
float rel_pos = std::fmod(position, wrap_f);
80-
if (rel_pos < 0.0f) {
81-
rel_pos += wrap_f;
82-
}
83-
float t = wrap_f > 0.0f ? rel_pos / wrap_f : 0.0f;
84-
float window = 0.5f - 0.5f * std::cos(TWO_PI * t); // 0 at edges, 1 in the middle
85-
window = std::clamp(window, 0.0f, 1.0f);
86-
angle = periodic_angle + window * (original_angle - periodic_angle);
77+
// closest periodic harmonic, necessary to ensure things neatly tile
78+
// without this round, things don't tile at the boundaries and you end up
79+
// with the model knowing what is "center"
80+
float rounded = std::round(cycles);
81+
angle = position * TWO_PI * rounded / wrap_f;
8782
}
8883
float sin_val = std::sin(angle);
8984
float cos_val = std::cos(angle);
@@ -282,7 +277,9 @@ namespace Rope {
282277
const std::vector<ggml_tensor*>& ref_latents,
283278
bool increase_ref_index,
284279
int theta,
285-
const std::vector<int>& axes_dim) {
280+
const std::vector<int>& axes_dim,
281+
bool circular = false) {
282+
circular = true;
286283
std::vector<std::vector<float>> ids = gen_qwen_image_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
287284
std::vector<std::vector<int>> axes_wraps;
288285
if (sd_is_circular_padding_enabled() && bs > 0 && axes_dim.size() >= 3) {
@@ -294,7 +291,13 @@ namespace Rope {
294291
const size_t total_tokens = ids.size();
295292
// Track per-token wrap lengths for the row/column axes so only spatial tokens become periodic.
296293
axes_wraps.assign(axes_dim.size(), std::vector<int>(total_tokens / bs, 0));
297-
size_t cursor = 0;
294+
size_t cursor = context_len; // ignore text tokens
295+
const size_t img_tokens = static_cast<size_t>(h_len) * static_cast<size_t>(w_len);
296+
for (size_t token_i = 0; token_i < img_tokens; ++token_i) {
297+
axes_wraps[1][cursor + token_i] = h_len;
298+
axes_wraps[2][cursor + token_i] = w_len;
299+
}
300+
cursor += img_tokens;
298301
for (ggml_tensor* ref : ref_latents) {
299302
if (ref == nullptr) {
300303
continue;

0 commit comments

Comments
 (0)