Skip to content

Commit 0092711

Browse files
committed
cleaner implementation of tiling support in sd cpp
1 parent 6d85b94 commit 0092711

File tree

1 file changed

+27
-15
lines changed

1 file changed

+27
-15
lines changed

rope.hpp

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -50,24 +50,36 @@ namespace Rope {
5050
omega[i] = 1.0 / std::pow(theta, scale[i]);
5151
}
5252

53-
int pos_size = pos.size();
54-
std::vector<std::vector<float>> out(pos_size, std::vector<float>(half_dim));
55-
for (int i = 0; i < pos_size; ++i) {
53+
for (size_t i = 0; i < pos.size(); ++i) {
54+
float position = pos[i];
5655
for (int j = 0; j < half_dim; ++j) {
57-
out[i][j] = pos[i] * omega[j];
58-
}
59-
}
60-
61-
std::vector<std::vector<float>> result(pos_size, std::vector<float>(half_dim * 4));
62-
for (int i = 0; i < pos_size; ++i) {
63-
for (int j = 0; j < half_dim; ++j) {
64-
result[i][4 * j] = std::cos(out[i][j]);
65-
result[i][4 * j + 1] = -std::sin(out[i][j]);
66-
result[i][4 * j + 2] = std::sin(out[i][j]);
67-
result[i][4 * j + 3] = std::cos(out[i][j]);
56+
float omega_val = omega[j];
57+
float original_angle = position * omega_val;
58+
float angle = original_angle;
59+
if (sd_is_circular_padding_enabled()) {
60+
constexpr float TWO_PI = 6.28318530717958647692f;
61+
float wrap_f = static_cast<float>(wrap);
62+
float cycles = omega_val * wrap_f / TWO_PI;
63+
float rounded = std::round(cycles); // closest periodic harmonic
64+
float periodic_omega = TWO_PI * rounded / wrap_f;
65+
float periodic_angle = position * periodic_omega;
66+
float rel_pos = std::fmod(position, wrap_f);
67+
if (rel_pos < 0.0f) {
68+
rel_pos += wrap_f;
69+
}
70+
float t = wrap_f > 0.0f ? rel_pos / wrap_f : 0.0f;
71+
float window = 0.5f - 0.5f * std::cos(TWO_PI * t); // 0 at edges, 1 in the middle
72+
window = std::clamp(window, 0.0f, 1.0f);
73+
angle = periodic_angle + window * (original_angle - periodic_angle);
74+
}
75+
float sin_val = std::sin(angle);
76+
float cos_val = std::cos(angle);
77+
result[i][4 * j] = cos_val;
78+
result[i][4 * j + 1] = -sin_val;
79+
result[i][4 * j + 2] = sin_val;
80+
result[i][4 * j + 3] = cos_val;
6881
}
6982
}
70-
7183
return result;
7284
}
7385

0 commit comments

Comments
 (0)