qwen3.5/3.6 styles linear attn CuTe implement (Tuning stage) #93

Open
XiaomingFun233 wants to merge 13 commits into inclusionAI:main from XiaomingFun233:main

Conversation

@XiaomingFun233

📌 Description

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to cuLA! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing.

⚡ Performance

Reviewer Notes

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request adds comprehensive support for Qwen3.5 linear attention in cuLA, introducing specialized CUDA kernels and Python wrappers for depthwise causal 1D convolution, layout transformations, and scalar-gated delta-rule prefill and decode operations, alongside SM90 chunk prefill kernels, benchmarks, and tests. The code review highlights several critical performance and safety issues: host-device synchronizations (such as .item() and torch.unique calls) are triggered in hot paths within the decode and prefill wrappers, which will severely degrade latency; the block_sum and L2 norm computations in the CUDA kernels contain serial bottlenecks and excessive barrier synchronizations that should be optimized using warp shuffles and parallel reductions; an unused cached stream utility forces an unnecessary dependency on cuda-python; and incomplete validation of the initial_state batch dimension poses a risk of out-of-bounds memory access.

Comment on lines +33 to +46
def _validate_cudac_state_indices(state_indices: torch.Tensor, *, rows: int, pool_size: int) -> None:
    if state_indices.ndim != 1 or state_indices.numel() != rows:
        raise ValueError(f"state_indices must be 1D with {rows} entries, got {tuple(state_indices.shape)}")
    if rows == 0:
        return
    min_idx = int(state_indices.min().item())
    max_idx = int(state_indices.max().item())
    if min_idx < 0 or max_idx >= pool_size:
        raise ValueError(f"state_indices must be in [0, {pool_size}), got min={min_idx} max={max_idx}")
    if torch.unique(state_indices).numel() != rows:
        raise ValueError(
            "backend='cudac' requires unique state_indices within one decode launch; "
            "duplicate rows need a sequential decode path."
        )

high

Calling .item() on state_indices.min() and state_indices.max(), as well as torch.unique(state_indices), triggers host-device synchronizations. In a decode/inference hot path, these CPU-GPU syncs will severely degrade performance and increase latency. Consider removing these checks or gating them behind a debug/validation flag so they do not run during standard production inference.
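
One way to act on this suggestion is to make the checks opt-in and, when they do run, copy the indices to the host only once. The sketch below is a hypothetical helper, not code from this PR: the `CULA_VALIDATE_INDICES` environment variable and all function names are assumptions. It works on a plain Python list, which in the real wrapper could come from a single `state_indices.tolist()` call.

```python
import os
from typing import Callable, Sequence

# Hypothetical opt-in switch; the name is an assumption, not part of cuLA.
VALIDATE_ENV_VAR = "CULA_VALIDATE_INDICES"


def validation_enabled() -> bool:
    """Return True only when the debug flag is explicitly set to "1"."""
    return os.environ.get(VALIDATE_ENV_VAR, "0") == "1"


def validate_state_indices_host(indices: Sequence[int], *, rows: int, pool_size: int) -> None:
    """Range and uniqueness checks on indices that already live on the host."""
    if len(indices) != rows:
        raise ValueError(f"state_indices must have {rows} entries, got {len(indices)}")
    if rows == 0:
        return
    lo, hi = min(indices), max(indices)
    if lo < 0 or hi >= pool_size:
        raise ValueError(f"state_indices must be in [0, {pool_size}), got min={lo} max={hi}")
    if len(set(indices)) != rows:
        raise ValueError("state_indices must be unique within one decode launch")


def maybe_validate(indices_to_host: Callable[[], Sequence[int]], *, rows: int, pool_size: int) -> bool:
    """Run the checks only when enabled; `indices_to_host` is called at most once.

    In a real wrapper the callable could be `state_indices.tolist`, so the
    default path performs no host-device copy at all.
    """
    if not validation_enabled():
        return False
    validate_state_indices_host(indices_to_host(), rows=rows, pool_size=pool_size)
    return True
```

With the flag unset, `maybe_validate` returns immediately without touching the tensor. With the flag set, the three separate synchronizations collapse into the one copy made by the callable.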

Comment on lines +73 to +86
CUTE_DEVICE static float block_sum(float value, SharedStorage& storage, int tid) {
    storage.scratch[tid] = value;
    __syncthreads();

    for (int stride = kThreads / 2; stride > 0; stride >>= 1) {
        if (tid < stride) {
            storage.scratch[tid] += storage.scratch[tid + stride];
        }
        __syncthreads();
    }
    const float result = storage.scratch[0];
    __syncthreads();
    return result;
}

high

The block_sum implementation uses a shared-memory reduction loop with __syncthreads() at each step. Since block_sum is called twice per token inside the main prefill loop, this results in a very high number of barrier synchronizations (e.g., 18 per token), leading to significant warp stalls and overhead. Consider rewriting block_sum using warp shuffle intrinsics (__shfl_down_sync) for intra-warp reduction, followed by a single shared-memory reduction step for the warp results. This will reduce the number of __syncthreads() from 9 per call to just 1 or 2, greatly improving prefill performance.
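
The barrier counts in this comment can be checked with a small host-side model. The Python sketch below is not the kernel and not code from this PR: it only emulates the data flow of a shuffle-down reduction followed by one pass over the per-warp partials, and it assumes a 128-thread block, which is what "9 per call" implies for the original loop. All function names are hypothetical.

```python
WARP_SIZE = 32


def warp_shuffle_sum(lanes):
    """Model of a shuffle-down tree reduction; lane 0 ends up with the warp sum."""
    assert len(lanes) == WARP_SIZE
    vals = list(lanes)
    offset = WARP_SIZE // 2
    while offset > 0:
        snapshot = list(vals)  # all lanes read before any lane writes
        for lane in range(WARP_SIZE - offset):
            vals[lane] = snapshot[lane] + snapshot[lane + offset]
        offset //= 2
    return vals[0]


def block_sum_two_stage(values):
    """Return (block sum, barriers used) for a block of up to 32 warps."""
    assert len(values) % WARP_SIZE == 0 and len(values) <= WARP_SIZE * WARP_SIZE
    barriers = 0
    # Stage 1: each warp reduces its own 32 values with shuffles, no barrier needed.
    partials = [
        warp_shuffle_sum(values[i:i + WARP_SIZE])
        for i in range(0, len(values), WARP_SIZE)
    ]
    barriers += 1  # one barrier after lane 0 of each warp stores its partial
    # Stage 2: a single warp reduces the per-warp partials, padded with zeros.
    padded = partials + [0.0] * (WARP_SIZE - len(partials))
    total = warp_shuffle_sum(padded)
    barriers += 1  # one more barrier so every thread can read the total
    return total, barriers


def shared_memory_tree_barriers(num_threads):
    """Barriers in the original loop: one after the store, one per stride, one after the read."""
    strides = num_threads.bit_length() - 1  # log2 for a power-of-two block size
    return 1 + strides + 1
```

For 128 threads, `shared_memory_tree_barriers(128)` gives 9, while the two-stage model uses 2 barriers per call. The second barrier can be dropped if every warp reduces the partials redundantly and the scratch buffer is not reused immediately.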

Comment on lines +68 to +70
for seq_idx in range(cu_seqlens.numel() - 1):
    start = int(cu_seqlens[seq_idx].item())
    end = int(cu_seqlens[seq_idx + 1].item())

high

Calling cu_seqlens[seq_idx].item() inside the loop triggers a host-device synchronization on every iteration. If there are many sequences, this will severely degrade performance. To avoid this, copy cu_seqlens to the CPU once before the loop (e.g., cu_seqlens_cpu = cu_seqlens.tolist()) and iterate over the CPU list.

            cu_seqlens_cpu = cu_seqlens.tolist()
            for seq_idx in range(len(cu_seqlens_cpu) - 1):
                start = cu_seqlens_cpu[seq_idx]
                end = cu_seqlens_cpu[seq_idx + 1]

Comment thread cula/qwen35/runtime.py
Comment on lines +48 to +55
def _get_cached_stream(device: torch.device) -> object:
    if cuda is None:
        raise RuntimeError("cuda.bindings.driver is not available in this environment.")
    stream_id = int(torch.cuda.current_stream(device=device).cuda_stream)
    cache_key = (str(device), stream_id)
    if cache_key not in _stream_cache:
        _stream_cache[cache_key] = cuda.CUstream(stream_id)
    return _stream_cache[cache_key]

high

The function _get_cached_stream is called unconditionally on the CUDA path in both qwen35_linear_attention_prefill and qwen35_linear_attention_decode, but its return value is discarded. Furthermore, if cuda-python is not installed (cuda is None), it raises a RuntimeError. This unnecessarily forces a dependency on cuda-python for users who only want to run the PyTorch/CUDA kernels. Since the returned stream is not used anyway, consider removing these calls entirely, or at least making them a no-op when cuda is None.
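
If the helper is kept rather than removed, the no-op variant could look roughly like the sketch below. It is an illustration under assumptions, not the PR's code: the function name `get_cached_stream_or_none` is hypothetical, the stream handle is passed in as a plain integer so the sketch does not need PyTorch, and the import path is taken from the error message in the quoted snippet.

```python
from typing import Dict, Optional, Tuple

try:
    # cuda-python is treated as an optional dependency here.
    from cuda.bindings import driver as cuda
except ImportError:
    cuda = None

_stream_cache: Dict[Tuple[str, int], object] = {}


def get_cached_stream_or_none(device_name: str, stream_id: int) -> Optional[object]:
    """Return a cached CUstream handle, or None when cuda-python is missing."""
    if cuda is None:
        # No-op path: callers that do not need the handle keep working.
        return None
    cache_key = (device_name, stream_id)
    if cache_key not in _stream_cache:
        _stream_cache[cache_key] = cuda.CUstream(stream_id)
    return _stream_cache[cache_key]
```

Callers that actually need the handle would still have to check for `None` and raise at that point, so the hard dependency moves to the code path that uses it.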

Comment on lines +411 to +424
if (tid == 0) {
    float q_norm_sq = 0.f;
    float k_norm_sq = 0.f;
    #pragma unroll
    for (int idx = 0; idx < kHeadDimQK; ++idx) {
        const float q_val = q_smem(idx);
        const float k_val = k_smem(idx);
        q_norm_sq += q_val * q_val;
        k_norm_sq += k_val * k_val;
    }
    norm_smem(0) = rsqrtf(q_norm_sq + 1e-6f) * rsqrtf(static_cast<float>(kHeadDimQK));
    norm_smem(1) = rsqrtf(k_norm_sq + 1e-6f);
}
__syncthreads();

medium

Thread 0 performs a sequential loop of 128 iterations to compute the L2 norms of q and k. This introduces a serial bottleneck in the decode kernel, where latency is critical. Since the block has 128 threads and kHeadDimQK is 128, each thread can square one element and a parallel block reduction (using warp shuffles or shared memory) can then compute q_norm_sq and k_norm_sq. This would reduce the number of sequential steps from $O(N)$ to $O(\log N)$.
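
The step count can be illustrated with a host-side model. The Python sketch below is only a model of the arithmetic, not kernel code: it assumes a power-of-two head dimension of 128, uses hypothetical function names, and reproduces the two scale factors from the quoted snippet (`rsqrt(q_norm_sq + 1e-6) * rsqrt(128)` and `rsqrt(k_norm_sq + 1e-6)`).

```python
import math

EPS = 1e-6  # same epsilon as in the quoted kernel snippet


def tree_sum(values):
    """Pairwise reduction over a power-of-two list; returns (sum, parallel steps)."""
    vals = list(values)
    steps = 0
    stride = len(vals) // 2
    while stride > 0:
        # On a GPU all `stride` additions of one step would run concurrently.
        for i in range(stride):
            vals[i] += vals[i + stride]
        stride //= 2
        steps += 1
    return vals[0], steps


def qk_scales(q, k):
    """Each 'thread' i contributes q[i]**2 and k[i]**2; the sums are tree-reduced."""
    q_norm_sq, steps = tree_sum([x * x for x in q])
    k_norm_sq, _ = tree_sum([x * x for x in k])
    q_scale = 1.0 / math.sqrt(q_norm_sq + EPS) / math.sqrt(len(q))
    k_scale = 1.0 / math.sqrt(k_norm_sq + EPS)
    return q_scale, k_scale, steps
```

For 128 elements the model needs 7 reduction steps instead of 128 serial iterations. The result can differ from the serial loop in the last bits because the floating-point additions happen in a different order.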

Comment on lines +62 to +63
if initial_state is not None and initial_state.shape[1:] != (HV, K, K):
    raise ValueError(f"initial_state must be [N,HV,128,128], got {tuple(initial_state.shape)}")

medium

The validation for initial_state only checks the trailing dimensions shape[1:]. It does not verify that the batch dimension initial_state.shape[0] matches state_count (which is B or cu_seqlens.numel() - 1). If a tensor with a mismatched batch dimension is passed, the CUDA kernel could access memory out of bounds. Consider adding a check for the batch dimension.

    if initial_state is not None:
        state_count = B if cu_seqlens is None else cu_seqlens.numel() - 1
        if initial_state.shape != (state_count, HV, K, K):
            raise ValueError(f"initial_state must be [{state_count},{HV},128,128], got {tuple(initial_state.shape)}")

@icavan icavan requested review from cherhh and icavan June 22, 2026 09:13
@cherhh

cherhh commented Jun 22, 2026

Collaborator

@XiaomingFun233 Thank you very much for your contribution. Could you please provide some benchmark results along with corresponding explanations?
