Rewrite SM100 KDA Forward in CuteDSL #92

@icavan

Title

Rewrite the SM100 KDA forward path in CuteDSL

Summary

The current SM100 (Blackwell) KDA forward path still depends on CUDA/CUTLASS C++ kernels under csrc/kda/sm100 and the SM100 API wrapper in csrc/api/kda_sm100.cu. We should rewrite this forward path in CuteDSL so the Blackwell implementation is easier to maintain, iterate on, and optimize alongside the rest of the CuteDSL kernel stack.

Motivation

  • Reduce the amount of architecture-specific CUDA/CUTLASS template code in the SM100 forward path.
  • Bring the SM100 KDA forward implementation closer to the rest of the CuteDSL-based operator stack in cula/ops.
  • Improve iteration speed for kernel changes, scheduling experiments, and debugging.
  • Reduce the number of prepackaged kernel variants that must be built and shipped: CuteDSL kernels can be compiled just in time (JIT) for the configurations actually used, instead of relying on a large ahead-of-time (AOT) matrix of template combinations.
  • Make it easier to keep behavior aligned across modular KDA forward components and any future fused Blackwell forward path.
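To make the JIT-versus-AOT point concrete, here is a minimal sketch. It assumes the SM100 kernels specialize on a handful of boolean options. The flag names below are hypothetical stand-ins borrowed loosely from the feature list in this issue. An AOT build must instantiate every combination, 2^N variants for N boolean flags. A JIT path compiles only the configurations a workload requests and caches them. Compilation is simulated here.

```python
import functools

# Hypothetical specialization flags; the real SM100 kernels may
# specialize on different axes (and additionally on dtypes, head sizes, etc.).
FLAGS = (
    "has_initial_state",
    "output_final_state",
    "use_qk_l2norm_in_kernel",
    "use_gate_in_kernel",
    "safe_gate",
    "is_varlen",
)


def aot_variant_count(flags=FLAGS) -> int:
    """An ahead-of-time build instantiates every boolean combination."""
    return 2 ** len(flags)


COMPILE_LOG = []  # configs that actually triggered a (simulated) compile


@functools.lru_cache(maxsize=None)
def get_kernel(config: tuple) -> str:
    """JIT path: 'compile' a variant the first time it is requested.

    `config` is a tuple of booleans in FLAGS order; the return value is a
    placeholder for a compiled kernel handle.
    """
    COMPILE_LOG.append(config)
    return f"kernel{config}"


# A workload that only ever exercises two configurations.
requests = [
    (False, True, True, False, False, False),
    (False, True, True, False, False, False),
    (True, True, True, True, True, True),
]
for cfg in requests:
    get_kernel(cfg)

print(aot_variant_count())  # 64
print(len(COMPILE_LOG))     # 2
```

Any additional specialization axis, such as a dtype choice, multiplies the AOT count further. The JIT count stays bounded by the configurations that are actually used.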

Current State

Relevant code paths today:

  • The public SM100 modular KDA forward entry point is documented as from cula.kda import chunk_kda.
  • The SM100 API wrapper lives in csrc/api/kda_sm100.cu.
  • The SM100 forward kernels currently live in csrc/kda/sm100, including:
    • kda_fwd_sm100.cu
    • kda_fwd_intra_kernel_sm100.hpp
    • kda_fwd_intra_mainloop_sm100.hpp
    • kda_fwd_common.cuh
  • There is also a Blackwell fused forward Python entry in cula/kda/blackwell_fused_fwd.py, but the device dispatch in cula/utils.py still marks Blackwell fused prefill as not yet available.

Proposal

Implement the SM100 KDA forward path in CuteDSL and make it the primary Blackwell implementation, while preserving the existing public Python API and numerics contract.

At minimum, this should cover the current SM100 forward compute path exposed to users, including the forward sub-kernels that are still implemented in CUDA/CUTLASS C++ today.

Scope

In scope

  • Rewrite the SM100 KDA forward kernels in csrc/kda/sm100 as CuteDSL-based implementations.
  • Keep the existing user-facing API unchanged wherever possible.
  • Preserve support for:
    • fixed-length and varlen inputs
    • initial_state
    • output_final_state
    • use_qk_l2norm_in_kernel
    • use_gate_in_kernel
    • safe_gate
    • lower_bound
  • Wire the new CuteDSL implementation into the Blackwell dispatch path.
  • Add or update tests and benchmarks so the CuteDSL path is directly validated against the existing reference behavior.
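For the fixed-length and varlen requirement, chunked kernels typically describe a packed batch with cumulative sequence offsets and enumerate chunks so that no chunk crosses a sequence boundary. The sketch below illustrates that bookkeeping. It rests on these assumptions:

- It uses a cu_seqlens-style convention with a leading 0, as in FLA-style chunked kernels.
- The helper names and the chunk size of 64 are hypothetical and are not taken from cula.
- The fixed-length case is the special case where every sequence has the same length.

```python
from typing import List, Tuple


def cu_seqlens_from_lengths(lengths: List[int]) -> List[int]:
    """Cumulative sequence offsets with a leading 0: [0, l0, l0+l1, ...]."""
    out = [0]
    for n in lengths:
        out.append(out[-1] + n)
    return out


def chunk_indices(cu_seqlens: List[int], chunk_size: int) -> List[Tuple[int, int]]:
    """(sequence id, chunk id within that sequence) for every chunk.

    Chunks never cross sequence boundaries, so a sequence of length n
    contributes ceil(n / chunk_size) chunks, the last one possibly partial.
    """
    pairs = []
    for seq_id in range(len(cu_seqlens) - 1):
        n = cu_seqlens[seq_id + 1] - cu_seqlens[seq_id]
        num_chunks = -(-n // chunk_size)  # ceiling division
        pairs.extend((seq_id, c) for c in range(num_chunks))
    return pairs


cu = cu_seqlens_from_lengths([100, 64, 1])  # [0, 100, 164, 165]
print(chunk_indices(cu, 64))                # [(0, 0), (0, 1), (1, 0), (2, 0)]
```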

Acceptance Criteria

  • Blackwell KDA forward no longer uses the current CUDA/CUTLASS SM100 implementation as its primary path.
  • Existing public entry points continue to work without API changes.
  • Numerical results match the current reference path within the project's existing tolerances.
  • Performance is at least on par with the current CUTLASS/C++ SM100 implementation for the benchmark configurations that represent the supported use cases.
  • Fixed-length and varlen coverage both pass.
  • Benchmark coverage exists for the new CuteDSL path on SM100.
  • Any temporary fallback path is clearly isolated and documented.
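The numerical and performance criteria above can be stated as two small predicates. This is a sketch only. It applies the same elementwise rule as numpy.allclose and torch.allclose, |out - ref| <= atol + rtol * |ref|, and compares median latencies. The atol, rtol, and slack values in the example are placeholders; the project's existing tolerances should be used instead.

```python
import statistics
from typing import Sequence


def within_tolerance(out: Sequence[float], ref: Sequence[float],
                     atol: float, rtol: float) -> bool:
    """Elementwise |out - ref| <= atol + rtol * |ref| (the allclose rule)."""
    if len(out) != len(ref):
        return False
    return all(abs(o - r) <= atol + rtol * abs(r) for o, r in zip(out, ref))


def at_least_on_par(new_ms: Sequence[float], baseline_ms: Sequence[float],
                    slack: float = 0.0) -> bool:
    """Median latency of the new path must not exceed the baseline's.

    `slack` (e.g. 0.02 for 2%) optionally absorbs run-to-run noise.
    """
    return statistics.median(new_ms) <= statistics.median(baseline_ms) * (1.0 + slack)


print(within_tolerance([1.0, 2.001], [1.0, 2.0], atol=1e-3, rtol=1e-3))  # True
print(at_least_on_par([1.00, 0.98, 1.01], [1.02, 1.00, 1.05]))          # True
```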

Suggested Validation

  • Run the relevant KDA forward tests, including current fused/modular forward coverage where applicable.
  • Run the Blackwell KDA benchmarks and compare against the current baseline.
  • Verify parity for representative settings involving:
    • bf16 inputs
    • beta in fp32 and bf16
    • safe_gate=True
    • use_gate_in_kernel=True
    • variable-length sequences
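The parity settings above form a small grid. The sketch below enumerates it as the Cartesian product of the listed axes. Each axis also includes the opposite boolean value, so the baseline behavior is covered too. The axis names and string dtype labels are illustrative and are not cula's actual argument names or dtype objects. The resulting list could feed something like pytest.mark.parametrize.

```python
import itertools

# Parity axes from the validation plan; values are illustrative labels.
GRID = {
    "dtype": ["bf16"],
    "beta_dtype": ["fp32", "bf16"],
    "safe_gate": [False, True],
    "use_gate_in_kernel": [False, True],
    "varlen": [False, True],
}


def parity_cases(grid=GRID):
    """Yield one dict per combination (the Cartesian product of the axes)."""
    keys = list(grid)
    for values in itertools.product(*(grid[k] for k in keys)):
        yield dict(zip(keys, values))


cases = list(parity_cases())
print(len(cases))  # 1 * 2 * 2 * 2 * 2 = 16
```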

Suggested Starting Points

  • csrc/api/kda_sm100.cu
  • csrc/kda/sm100/kda_fwd_sm100.cu
  • csrc/kda/sm100/kda_fwd_intra_kernel_sm100.hpp
  • csrc/kda/sm100/kda_fwd_intra_mainloop_sm100.hpp
  • cula/kda/chunk.py
  • cula/kda/chunk_fwd.py
  • cula/kda/blackwell_fused_fwd.py
  • cula/utils.py
  • tests/test_kda.py
  • tests/test_kda_compare_fla.py
  • tests/test_kda_fused_fwd.py
  • benchmarks/bench_kda.py
  • benchmarks/bench_kda_fused_fwd.py

Notes

This repository is already moving toward CuteDSL for performance-critical kernels. Converging the SM100 KDA forward path on CuteDSL should reduce implementation fragmentation and make future Blackwell-specific tuning substantially easier.

Metadata

Assignees

Labels

enhancement (New feature or request)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions