Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions csrc/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4157,6 +4157,25 @@ PYBIND11_MODULE(flash_rt_kernels, m) {
py::arg("B"), py::arg("S"), py::arg("conv_dim"), py::arg("k"),
py::arg("apply_silu") = true, py::arg("stream") = 0);

m.def("causal_conv1d_qwen36_update_chunk_saves_bf16",
[](uintptr_t x, uintptr_t w, uintptr_t bias,
uintptr_t out, uintptr_t state,
uintptr_t state_steps, int64_t step_stride,
int B, int S, int conv_dim, int k, bool apply_silu,
uintptr_t stream) {
flash_rt::kernels::causal_conv1d_qwen36_update_chunk_saves_bf16(
to_ptr(x), to_ptr(w),
bias ? to_ptr(bias) : nullptr,
to_ptr(out), to_ptr(state),
to_ptr(state_steps), step_stride,
B, S, conv_dim, k, apply_silu, to_stream(stream));
},
py::arg("x"), py::arg("w"), py::arg("bias"),
py::arg("out"), py::arg("state"),
py::arg("state_steps"), py::arg("step_stride"),
py::arg("B"), py::arg("S"), py::arg("conv_dim"), py::arg("k"),
py::arg("apply_silu") = true, py::arg("stream") = 0);

m.def("causal_conv1d_qwen36_update_chunk_parallel_bf16",
[](uintptr_t x, uintptr_t w, uintptr_t bias,
uintptr_t out, uintptr_t state,
Expand Down Expand Up @@ -4876,6 +4895,31 @@ PYBIND11_MODULE(flash_rt_kernels, m) {
py::arg("a_stride"), py::arg("b_stride"),
py::arg("use_qk_l2norm") = true, py::arg("stream") = 0);

m.def("qwen36_gdn_chunk_from_conv_smem_strided_saves_bf16",
[](uintptr_t conv_out, uintptr_t a, uintptr_t b,
uintptr_t neg_exp_A_log, uintptr_t dt_bias,
uintptr_t state, uintptr_t state_steps, int64_t step_stride,
uintptr_t out,
int S, int num_v_heads, int a_stride, int b_stride,
bool use_qk_l2norm, uintptr_t stream) {
flash_rt::kernels::
qwen36_gdn_chunk_from_conv_smem_strided_saves_bf16(
to_ptr(conv_out), to_ptr(a), to_ptr(b),
reinterpret_cast<const float*>(neg_exp_A_log),
reinterpret_cast<const float*>(dt_bias),
to_ptr(state), to_ptr(state_steps), step_stride,
to_ptr(out),
S, num_v_heads, a_stride, b_stride,
use_qk_l2norm, to_stream(stream));
},
py::arg("conv_out"), py::arg("a"), py::arg("b"),
py::arg("neg_exp_A_log"), py::arg("dt_bias"),
py::arg("state"), py::arg("state_steps"),
py::arg("step_stride"), py::arg("out"),
py::arg("S"), py::arg("num_v_heads"),
py::arg("a_stride"), py::arg("b_stride"),
py::arg("use_qk_l2norm") = true, py::arg("stream") = 0);

m.def("qwen36_gdn_wy_norm_cumsum_bf16",
[](uintptr_t q16, uintptr_t k16, uintptr_t g,
uintptr_t q16_l2, uintptr_t k16_l2, uintptr_t g_cumsum,
Expand Down
99 changes: 99 additions & 0 deletions csrc/kernels/causal_conv1d_qwen36.cu
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,84 @@ __global__ void causal_conv1d_update_chunk_kernel(
}
}

// Per-step-checkpoint variant of the chunk kernel above: identical
// math (the carried window values are bf16-exact in fp32 registers),
// plus a bf16 dump of the post-shift state after every step into
// ``state_steps`` (step s at state_steps + s * step_stride). Slot s
// byte-matches the committed state of an S = s + 1 run, which is what
// the spec-decode partial-accept rollback copies.
__global__ void causal_conv1d_update_chunk_saves_kernel(
const __nv_bfloat16* __restrict__ x,
const __nv_bfloat16* __restrict__ w,
const __nv_bfloat16* __restrict__ bias,
__nv_bfloat16* __restrict__ out,
__nv_bfloat16* __restrict__ state,
__nv_bfloat16* __restrict__ state_steps,
int64_t step_stride,
int B, int S, int conv_dim, int k,
bool apply_silu)
{
const int c = blockIdx.x * kThreadsX + threadIdx.x;
const int b = blockIdx.y;
if (c >= conv_dim) return;

const int sk = k - 1;
const int state_base = (b * conv_dim + c) * sk;

float wv[kMaxK];
#pragma unroll
for (int i = 0; i < kMaxK; ++i) {
wv[i] = (i < k) ? static_cast<float>(w[c * k + i]) : 0.0f;
}

float sv[kMaxK];
#pragma unroll
for (int i = 0; i < kMaxK; ++i) {
sv[i] = (i < sk)
? static_cast<float>(state[state_base + i])
: 0.0f;
}

for (int s = 0; s < S; ++s) {
const float x_v = static_cast<float>(
x[(size_t)b * S * conv_dim + (size_t)s * conv_dim + c]);

float acc = (bias != nullptr) ? static_cast<float>(bias[c]) : 0.0f;
#pragma unroll
for (int i = 0; i < kMaxK; ++i) {
if (i < sk) acc = fmaf(sv[i], wv[i], acc);
}
acc = fmaf(x_v, wv[sk], acc);

if (apply_silu) acc = silu(acc);
out[(size_t)b * S * conv_dim + (size_t)s * conv_dim + c] =
__float2bfloat16(acc);

#pragma unroll
for (int i = 0; i < kMaxK - 1; ++i) {
if (i < sk - 1) sv[i] = sv[i + 1];
}
if (sk >= 1) {
sv[sk - 1] = x_v;
}

#pragma unroll
for (int i = 0; i < kMaxK; ++i) {
if (i < sk) {
state_steps[(size_t)s * step_stride + state_base + i] =
__float2bfloat16(sv[i]);
}
}
}

#pragma unroll
for (int i = 0; i < kMaxK; ++i) {
if (i < sk) {
state[state_base + i] = __float2bfloat16(sv[i]);
}
}
}

__global__ void causal_conv1d_update_chunk_parallel_kernel(
const __nv_bfloat16* __restrict__ x,
const __nv_bfloat16* __restrict__ w,
Expand Down Expand Up @@ -360,6 +438,27 @@ void causal_conv1d_qwen36_update_chunk_bf16(
B, S, conv_dim, k, apply_silu);
}

void causal_conv1d_qwen36_update_chunk_saves_bf16(
const void* x, const void* w, const void* bias,
void* out, void* state,
void* state_steps, int64_t step_stride,
int B, int S, int conv_dim, int k,
bool apply_silu,
cudaStream_t stream)
{
dim3 grid((conv_dim + kThreadsX - 1) / kThreadsX, B);
dim3 block(kThreadsX);
causal_conv1d_update_chunk_saves_kernel<<<grid, block, 0, stream>>>(
reinterpret_cast<const __nv_bfloat16*>(x),
reinterpret_cast<const __nv_bfloat16*>(w),
reinterpret_cast<const __nv_bfloat16*>(bias),
reinterpret_cast<__nv_bfloat16*>(out),
reinterpret_cast<__nv_bfloat16*>(state),
reinterpret_cast<__nv_bfloat16*>(state_steps),
step_stride,
B, S, conv_dim, k, apply_silu);
}

void causal_conv1d_qwen36_update_chunk_parallel_bf16(
const void* x, const void* w, const void* bias,
void* out, void* state,
Expand Down
15 changes: 15 additions & 0 deletions csrc/kernels/causal_conv1d_qwen36.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,21 @@ void causal_conv1d_qwen36_update_chunk_bf16(
bool apply_silu,
cudaStream_t stream);

// Chunk variant with per-step state checkpoints: dumps the post-step
// conv state to state_steps + s * step_stride for every step s, for
// the spec-decode partial-accept rollback.
void causal_conv1d_qwen36_update_chunk_saves_bf16(
const void* x,
const void* w,
const void* bias,
void* out,
void* state,
void* state_steps,
int64_t step_stride,
int B, int S, int conv_dim, int k,
bool apply_silu,
cudaStream_t stream);

// Parallel prefill variant: computes each (S, channel) output
// independently, then updates the final state in a second tiny kernel.
// This trades extra global loads for much higher S-dimension
Expand Down
172 changes: 172 additions & 0 deletions csrc/kernels/gated_deltanet_qwen36.cu
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,134 @@ __global__ void qwen36_gdn_chunk_from_conv_smem_kernel(
}
}

// Per-step-checkpoint variant of the chunk kernel above: identical
// math and rounding cadence (the state is rounded to bf16 after every
// step exactly as the original does between steps), plus a dump of
// each step's rounded state into ``state_steps`` (step s at
// state_steps + s * step_stride). Slot s byte-matches the committed
// state of an S = s + 1 run, which is what the spec-decode
// partial-accept rollback copies.
template <int HD>
__global__ void qwen36_gdn_chunk_from_conv_smem_saves_kernel(
const __nv_bfloat16* __restrict__ conv_out,
const __nv_bfloat16* __restrict__ a_in,
const __nv_bfloat16* __restrict__ b_in,
const float* __restrict__ neg_exp_A_log,
const float* __restrict__ dt_bias,
__nv_bfloat16* __restrict__ state,
__nv_bfloat16* __restrict__ state_steps,
int64_t step_stride,
__nv_bfloat16* __restrict__ out_,
int S,
int num_v_heads,
int a_stride,
int b_stride,
bool use_qk_l2norm)
{
static_assert(HD == 128, "HD must be 128 for Qwen3.6");
const int h = blockIdx.x;
const int b = blockIdx.y;
const int t = threadIdx.x;
if (t >= HD) return;

extern __shared__ float smem[];
float* state_s = smem;
float* qs = state_s + HD * HD;
float* ks = qs + HD;
float* scratch = ks + HD;

const size_t state_h_off =
(((size_t)b * num_v_heads + h)) * HD * HD;
#pragma unroll 16
for (int i = 0; i < HD; ++i) {
state_s[i * HD + t] = static_cast<float>(
state[state_h_off + (size_t)i * HD + t]);
}
__syncthreads();

const int src_h = h / 3;
for (int s = 0; s < S; ++s) {
const size_t row = static_cast<size_t>(s) * 10240;
const size_t out_off = ((size_t)s * num_v_heads + h) * HD + t;
qs[t] = static_cast<float>(conv_out[row + src_h * HD + t]);
ks[t] = static_cast<float>(conv_out[row + 2048 + src_h * HD + t]);
__syncthreads();

if (use_qk_l2norm) {
float q_sq = qs[t] * qs[t];
float k_sq = ks[t] * ks[t];
q_sq = block_reduce_sum<HD>(q_sq, scratch);
// See the non-saves kernel for why this barrier is required
// between the two block reductions sharing ``scratch``.
__syncthreads();
k_sq = block_reduce_sum<HD>(k_sq, scratch);
const float q_inv = rsqrtf(q_sq + kEps);
const float k_inv = rsqrtf(k_sq + kEps);
qs[t] *= q_inv;
ks[t] *= k_inv;
__syncthreads();
}

qs[t] *= rsqrtf(static_cast<float>(HD));
__syncthreads();

const float av =
static_cast<float>(a_in[s * a_stride + h]) + dt_bias[h];
const float sp = log1pf(__expf(av));
const float g_log = static_cast<float>(
__float2bfloat16(neg_exp_A_log[h] * sp));
const float g_t = __expf(g_log);
const float bv = static_cast<float>(b_in[s * b_stride + h]);
const float beta_t = static_cast<float>(
__float2bfloat16(1.0f / (1.0f + __expf(-bv))));

#pragma unroll 16
for (int i = 0; i < HD; ++i) {
state_s[i * HD + t] *= g_t;
}

float kv_mem = 0.0f;
#pragma unroll 16
for (int i = 0; i < HD; ++i) {
kv_mem = fmaf(state_s[i * HD + t], ks[i], kv_mem);
}

const float v_t =
static_cast<float>(conv_out[row + 4096 + h * HD + t]);
const float delta = (v_t - kv_mem) * beta_t;

#pragma unroll 16
for (int i = 0; i < HD; ++i) {
state_s[i * HD + t] =
fmaf(ks[i], delta, state_s[i * HD + t]);
}

float out_t = 0.0f;
#pragma unroll 16
for (int i = 0; i < HD; ++i) {
out_t = fmaf(state_s[i * HD + t], qs[i], out_t);
}
out_[out_off] = __float2bfloat16(out_t);

#pragma unroll 16
for (int i = 0; i < HD; ++i) {
const __nv_bfloat16 v =
__float2bfloat16(state_s[i * HD + t]);
state_steps[
(size_t)s * step_stride + state_h_off + (size_t)i * HD + t] =
v;
state_s[i * HD + t] = static_cast<float>(v);
}
__syncthreads();
}

#pragma unroll 16
for (int i = 0; i < HD; ++i) {
state[state_h_off + (size_t)i * HD + t] =
__float2bfloat16(state_s[i * HD + t]);
}
}

__global__ void qwen36_gdn_wy_norm_qk_kernel(
const __nv_bfloat16* __restrict__ q16,
const __nv_bfloat16* __restrict__ k16,
Expand Down Expand Up @@ -1469,6 +1597,50 @@ void qwen36_gdn_chunk_from_conv_smem_strided_bf16(
S, num_v_heads, a_stride, b_stride, use_qk_l2norm);
}

void qwen36_gdn_chunk_from_conv_smem_strided_saves_bf16(
const void* conv_out,
const void* a,
const void* b,
const float* neg_exp_A_log,
const float* dt_bias,
void* state,
void* state_steps,
int64_t step_stride,
void* out,
int S,
int num_v_heads,
int a_stride,
int b_stride,
bool use_qk_l2norm,
cudaStream_t stream)
{
if (S <= 0 || num_v_heads <= 0) return;
dim3 grid(num_v_heads, 1);
dim3 block(kHD);
constexpr size_t kSmemBytes =
(kHD * kHD + 2 * kHD + 32) * sizeof(float);
static bool attr_set = false;
if (!attr_set) {
cudaFuncSetAttribute(
qwen36_gdn_chunk_from_conv_smem_saves_kernel<kHD>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
static_cast<int>(kSmemBytes));
attr_set = true;
}
qwen36_gdn_chunk_from_conv_smem_saves_kernel<kHD><<<
grid, block, kSmemBytes, stream>>>(
reinterpret_cast<const __nv_bfloat16*>(conv_out),
reinterpret_cast<const __nv_bfloat16*>(a),
reinterpret_cast<const __nv_bfloat16*>(b),
neg_exp_A_log,
dt_bias,
reinterpret_cast<__nv_bfloat16*>(state),
reinterpret_cast<__nv_bfloat16*>(state_steps),
step_stride,
reinterpret_cast<__nv_bfloat16*>(out),
S, num_v_heads, a_stride, b_stride, use_qk_l2norm);
}

void gated_deltanet_chunk_smem_qwen36_bf16(
const void* q,
const void* k,
Expand Down
Loading