Title
Rewrite the SM100 KDA forward path in CuteDSL
Summary
The current SM100 (Blackwell) KDA forward path still depends on CUDA/CUTLASS C++ kernels under csrc/kda/sm100 and the SM100 API wrapper in csrc/api/kda_sm100.cu. We should rewrite this forward path in CuteDSL so the Blackwell implementation is easier to maintain, iterate on, and optimize alongside the rest of the CuteDSL kernel stack.
Motivation
- Reduce the amount of architecture-specific CUDA/CUTLASS template code in the SM100 forward path.
- Bring the SM100 KDA forward implementation closer to the rest of the CuteDSL-based operator stack in cula/ops.
- Improve iteration speed for kernel changes, scheduling experiments, and debugging.
- Reduce the number of prepackaged kernel variants that need to be built and shipped, since CuteDSL kernels can be compiled just in time instead of relying on a large ahead-of-time combination matrix.
- Make it easier to keep behavior aligned across modular KDA forward components and any future fused Blackwell forward path.
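The variant-count argument can be made concrete with a minimal Python sketch. The specialization axes below are hypothetical: they loosely mirror the flags listed under Scope, and the real compile-time parameters of the SM100 kernels may differ. get_kernel is a placeholder standing in for a CuteDSL JIT compile step, not an existing function in this repository. The sketch compares how many variants an exhaustive ahead-of-time build would need with how many compilations a cached JIT path performs for a workload that touches only two configurations.

```python
import functools

# Hypothetical compile-time specialization axes, loosely mirroring the flags in
# Scope; the real SM100 kernel template parameters may differ.
SPECIALIZATION_AXES = {
    "varlen": (False, True),
    "has_initial_state": (False, True),
    "output_final_state": (False, True),
    "use_qk_l2norm_in_kernel": (False, True),
    "use_gate_in_kernel": (False, True),
    "safe_gate": (False, True),
    "beta_dtype": ("fp32", "bf16"),
}


def aot_variant_count(axes):
    """Kernels an exhaustive ahead-of-time build would have to ship."""
    count = 1
    for values in axes.values():
        count *= len(values)
    return count


COMPILE_LOG = []  # records every (simulated) compilation


@functools.lru_cache(maxsize=None)
def get_kernel(config):
    """Placeholder for a JIT compile step; runs once per distinct config."""
    COMPILE_LOG.append(config)
    return f"kda_fwd_sm100<{dict(config)}>"


def config_key(**overrides):
    """Hashable config with every axis defaulting to its first value."""
    cfg = {name: values[0] for name, values in SPECIALIZATION_AXES.items()}
    cfg.update(overrides)
    return tuple(sorted(cfg.items()))


# A workload that only uses two configurations triggers only two compilations.
for _ in range(3):
    get_kernel(config_key())
    get_kernel(config_key(varlen=True, safe_gate=True))

print(aot_variant_count(SPECIALIZATION_AXES), len(COMPILE_LOG))  # 128 2
```

With seven two-valued axes, an exhaustive AOT build needs 2^7 = 128 variants, while the cached JIT path compiles only the configurations that are actually requested.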
Current State
Relevant code paths today:
- Public SM100 modular KDA forward entry is documented as from cula.kda import chunk_kda.
- SM100 API wrapper lives in csrc/api/kda_sm100.cu.
- SM100 forward kernels currently live in csrc/kda/sm100, including:
  - kda_fwd_sm100.cu
  - kda_fwd_intra_kernel_sm100.hpp
  - kda_fwd_intra_mainloop_sm100.hpp
  - kda_fwd_common.cuh
- There is also a Blackwell fused forward Python entry in cula/kda/blackwell_fused_fwd.py, but the device dispatch in cula/utils.py still marks Blackwell fused prefill as not yet available.
Proposal
Implement the SM100 KDA forward path in CuteDSL and make it the primary Blackwell implementation, while preserving the existing public Python API and numerics contract.
At minimum, this should cover the current SM100 forward compute path exposed to users, including the forward sub-kernels that are still implemented in CUDA/CUTLASS C++ today.
Scope
In scope
- Rewrite the SM100 KDA forward kernels from csrc/kda/sm100 into CuteDSL-based implementations.
- Keep the existing user-facing API unchanged wherever possible.
- Preserve support for:
  - fixed-length and varlen inputs
  - initial_state
  - output_final_state
  - use_qk_l2norm_in_kernel
  - use_gate_in_kernel
  - safe_gate
  - lower_bound
- Wire the new CuteDSL implementation into the Blackwell dispatch path.
- Add or update tests and benchmarks so the CuteDSL path is directly validated against the existing reference behavior.
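A minimal sketch of how the Blackwell dispatch could prefer the CuteDSL path while keeping any temporary fallback isolated behind one explicit switch. Everything here is hypothetical: Backend, select_kda_fwd_backend, cutedsl_supported, and fallback_enabled are illustrative names, not the current contents of cula/utils.py, and treating compute capability major version 10 as "SM100" is an assumption.

```python
from enum import Enum
from typing import Tuple


class Backend(Enum):
    CUTEDSL = "cutedsl"                    # new primary Blackwell path
    CUTLASS_FALLBACK = "cutlass_fallback"  # temporary, isolated legacy path


def select_kda_fwd_backend(
    capability: Tuple[int, int],
    cutedsl_supported: bool,
    fallback_enabled: bool = False,
) -> Backend:
    """Pick the KDA forward backend on Blackwell (hypothetical helper).

    capability: (major, minor) compute capability; major == 10 is assumed to mean SM100.
    cutedsl_supported: whether the CuteDSL path covers the requested configuration.
    fallback_enabled: explicit opt-in for the legacy CUTLASS/C++ path.
    """
    major, minor = capability
    if major != 10:
        raise ValueError(f"not an SM100-class device: sm_{major}{minor}")
    if cutedsl_supported:
        return Backend.CUTEDSL
    if fallback_enabled:
        return Backend.CUTLASS_FALLBACK
    raise NotImplementedError(
        "configuration not yet supported by the CuteDSL SM100 forward path; "
        "enable the fallback explicitly to use the legacy kernels"
    )
```

Keeping the fallback behind a single opt-in flag, rather than silently falling through to it, makes the legacy path easy to find and delete once the CuteDSL path covers every supported configuration.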
Acceptance Criteria
- Blackwell KDA forward no longer depends on the current CUDA/CUTLASS SM100 forward implementation as the primary path.
- Existing public entry points continue to work without API changes.
- Numerical results match the current reference path within the project's existing tolerances.
- Performance is at least on par with the current CUTLASS/C++ SM100 implementation for the benchmark configurations that represent the supported use cases.
- Fixed-length and varlen coverage both pass.
- Benchmark coverage exists for the new CuteDSL path on SM100.
- Any temporary fallback path is clearly isolated and documented.
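The tolerance check itself can stay simple. The sketch below uses a relative RMS error ratio, one common style for comparing low-precision kernel outputs against a reference. Whether the project's existing tests use this metric, and which thresholds they apply, is not stated here, so err_ratio, check_parity, and the ratio_tol value are assumptions. It operates on flat Python lists to stay dependency-free; a real test would apply the same formula to tensors such as the output and the final state.

```python
import math


def err_ratio(ref, out):
    """Relative RMS error: rms(out - ref) / rms(ref) (hypothetical helper)."""
    if len(ref) != len(out) or not ref:
        raise ValueError("ref and out must be non-empty and the same length")
    diff = math.sqrt(sum((a - b) ** 2 for a, b in zip(ref, out)) / len(ref))
    base = math.sqrt(sum(a * a for a in ref) / len(ref))
    # Fall back to the absolute RMS error when the reference is all zeros.
    return diff / base if base > 0 else diff


def check_parity(name, ref, out, ratio_tol):
    """Fail loudly if the new path drifts from the reference beyond ratio_tol."""
    ratio = err_ratio(ref, out)
    assert ratio <= ratio_tol, f"{name}: err ratio {ratio:.3e} > {ratio_tol:.3e}"
    return ratio
```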
Suggested Validation
- Run the relevant KDA forward tests, including current fused/modular forward coverage where applicable.
- Run the Blackwell KDA benchmarks and compare against the current baseline.
- Verify parity for representative settings involving:
  - bf16 inputs
  - beta in fp32 and bf16
  - safe_gate=True
  - use_gate_in_kernel=True
  - variable-length sequences
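The parity settings above can be enumerated as an explicit test matrix. This is a hedged sketch: ParityCase, GATE_MODES, and the pairing of safe_gate=True with use_gate_in_kernel=True are illustrative choices, and the 64-token chunk size used to pick boundary-straddling lengths is an assumption. CU_SEQLENS follows the cumulative-offset convention common to varlen attention kernels; whether the CuteDSL path accepts exactly this convention is assumed, not confirmed.

```python
import itertools
from dataclasses import dataclass


@dataclass(frozen=True)
class ParityCase:
    """One forward-parity configuration (hypothetical test structure)."""
    beta_dtype: str            # "fp32" or "bf16"
    varlen: bool               # variable-length batch vs fixed-length
    safe_gate: bool
    use_gate_in_kernel: bool
    input_dtype: str = "bf16"  # inputs are always bf16 in these cases

    @property
    def case_id(self) -> str:
        layout = "varlen" if self.varlen else "fixed"
        return (f"{self.input_dtype}-beta_{self.beta_dtype}-{layout}"
                f"-safe_gate{int(self.safe_gate)}-gate_in_kernel{int(self.use_gate_in_kernel)}")


# Baseline with both gate flags off, plus the gated configuration called out above.
GATE_MODES = [(False, False), (True, True)]

CASES = [
    ParityCase(beta_dtype=beta, varlen=varlen, safe_gate=sg, use_gate_in_kernel=gk)
    for beta, varlen, (sg, gk) in itertools.product(("fp32", "bf16"), (False, True), GATE_MODES)
]

# Lengths chosen to straddle an assumed 64-token chunk boundary, plus one long sequence.
VARLEN_LENGTHS = (1, 63, 64, 65, 1000)
# Cumulative offsets: sequence i occupies tokens [CU_SEQLENS[i], CU_SEQLENS[i + 1]).
CU_SEQLENS = [0, *itertools.accumulate(VARLEN_LENGTHS)]
```

Each case could then be fed to pytest.mark.parametrize with ids=[c.case_id for c in CASES], so a failing combination is identifiable directly from the test name.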
Suggested Starting Points
- csrc/api/kda_sm100.cu
- csrc/kda/sm100/kda_fwd_sm100.cu
- csrc/kda/sm100/kda_fwd_intra_kernel_sm100.hpp
- csrc/kda/sm100/kda_fwd_intra_mainloop_sm100.hpp
- cula/kda/chunk.py
- cula/kda/chunk_fwd.py
- cula/kda/blackwell_fused_fwd.py
- cula/utils.py
- tests/test_kda.py
- tests/test_kda_compare_fla.py
- tests/test_kda_fused_fwd.py
- benchmarks/bench_kda.py
- benchmarks/bench_kda_fused_fwd.py
Notes
There is already a strong CuteDSL direction in this repository for performance-critical kernels. Converging the SM100 KDA forward path on CuteDSL should reduce implementation fragmentation and make future Blackwell-specific tuning substantially easier.