qwen3.5/3.6 styles linear attn CuTe implement (Tuning stage)#93
qwen3.5/3.6 styles linear attn CuTe implement (Tuning stage)#93XiaomingFun233 wants to merge 13 commits into
Conversation
There was a problem hiding this comment.
Code Review
This pull request adds comprehensive support for Qwen3.5 linear attention in cuLA, introducing specialized CUDA kernels and Python wrappers for depthwise causal 1D convolution, layout transformations, and scalar-gated delta-rule prefill and decode operations, alongside SM90 chunk prefill kernels, benchmarks, and tests. The code review highlights several critical performance and safety issues: host-device synchronizations (such as .item() and torch.unique calls) are triggered in hot paths within the decode and prefill wrappers, which will severely degrade latency; the block_sum and L2 norm computations in the CUDA kernels contain serial bottlenecks and excessive barrier synchronizations that should be optimized using warp shuffles and parallel reductions; an unused cached stream utility forces an unnecessary dependency on cuda-python; and incomplete validation of the initial_state batch dimension poses a risk of out-of-bounds memory access.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| def _validate_cudac_state_indices(state_indices: torch.Tensor, *, rows: int, pool_size: int) -> None: | ||
| if state_indices.ndim != 1 or state_indices.numel() != rows: | ||
| raise ValueError(f"state_indices must be 1D with {rows} entries, got {tuple(state_indices.shape)}") | ||
| if rows == 0: | ||
| return | ||
| min_idx = int(state_indices.min().item()) | ||
| max_idx = int(state_indices.max().item()) | ||
| if min_idx < 0 or max_idx >= pool_size: | ||
| raise ValueError(f"state_indices must be in [0, {pool_size}), got min={min_idx} max={max_idx}") | ||
| if torch.unique(state_indices).numel() != rows: | ||
| raise ValueError( | ||
| "backend='cudac' requires unique state_indices within one decode launch; " | ||
| "duplicate rows need a sequential decode path." | ||
| ) |
There was a problem hiding this comment.
Calling .item() on state_indices.min() and state_indices.max(), as well as torch.unique(state_indices), triggers host-device synchronizations. In a decode/inference hot path, these CPU-GPU syncs will severely degrade performance and increase latency. Consider removing these checks or gating them behind a debug/validation flag so they do not run during standard production inference.
| CUTE_DEVICE static float block_sum(float value, SharedStorage& storage, int tid) { | ||
| storage.scratch[tid] = value; | ||
| __syncthreads(); | ||
|
|
||
| for (int stride = kThreads / 2; stride > 0; stride >>= 1) { | ||
| if (tid < stride) { | ||
| storage.scratch[tid] += storage.scratch[tid + stride]; | ||
| } | ||
| __syncthreads(); | ||
| } | ||
| const float result = storage.scratch[0]; | ||
| __syncthreads(); | ||
| return result; | ||
| } |
There was a problem hiding this comment.
The block_sum implementation uses a shared-memory reduction loop with __syncthreads() at each step. Since block_sum is called twice per token inside the main prefill loop, this results in a very high number of barrier synchronizations (e.g., 18 per token), leading to significant warp stalls and overhead. Consider rewriting block_sum using warp shuffle intrinsics (__shfl_down_sync) for intra-warp reduction, followed by a single shared-memory reduction step for the warp results. This will reduce the number of __syncthreads() from 9 per call to just 1 or 2, greatly improving prefill performance.
| for seq_idx in range(cu_seqlens.numel() - 1): | ||
| start = int(cu_seqlens[seq_idx].item()) | ||
| end = int(cu_seqlens[seq_idx + 1].item()) |
There was a problem hiding this comment.
Calling cu_seqlens[seq_idx].item() inside the loop triggers a host-device synchronization on every iteration. If there are many sequences, this will severely degrade performance. To avoid this, copy cu_seqlens to the CPU once before the loop (e.g., cu_seqlens_cpu = cu_seqlens.tolist()) and iterate over the CPU list.
cu_seqlens_cpu = cu_seqlens.tolist()
for seq_idx in range(len(cu_seqlens_cpu) - 1):
start = cu_seqlens_cpu[seq_idx]
end = cu_seqlens_cpu[seq_idx + 1]| def _get_cached_stream(device: torch.device) -> object: | ||
| if cuda is None: | ||
| raise RuntimeError("cuda.bindings.driver is not available in this environment.") | ||
| stream_id = int(torch.cuda.current_stream(device=device).cuda_stream) | ||
| cache_key = (str(device), stream_id) | ||
| if cache_key not in _stream_cache: | ||
| _stream_cache[cache_key] = cuda.CUstream(stream_id) | ||
| return _stream_cache[cache_key] |
There was a problem hiding this comment.
The function _get_cached_stream is called unconditionally on the CUDA path in both qwen35_linear_attention_prefill and qwen35_linear_attention_decode, but its return value is discarded. Furthermore, if cuda-python is not installed (cuda is None), it raises a RuntimeError. This unnecessarily forces a dependency on cuda-python for users who only want to run the PyTorch/CUDA kernels. Since the returned stream is not used anyway, consider removing these calls entirely, or at least making them a no-op when cuda is None.
| if (tid == 0) { | ||
| float q_norm_sq = 0.f; | ||
| float k_norm_sq = 0.f; | ||
| #pragma unroll | ||
| for (int idx = 0; idx < kHeadDimQK; ++idx) { | ||
| const float q_val = q_smem(idx); | ||
| const float k_val = k_smem(idx); | ||
| q_norm_sq += q_val * q_val; | ||
| k_norm_sq += k_val * k_val; | ||
| } | ||
| norm_smem(0) = rsqrtf(q_norm_sq + 1e-6f) * rsqrtf(static_cast<float>(kHeadDimQK)); | ||
| norm_smem(1) = rsqrtf(k_norm_sq + 1e-6f); | ||
| } | ||
| __syncthreads(); |
There was a problem hiding this comment.
Thread 0 is performing a sequential loop of 128 iterations to compute the L2 norm of q and k. This introduces a serial bottleneck in the decode kernel where latency is critical. Since there are 128 threads in the block and kHeadDimQK is 128, you can perform a parallel block reduction (using warp shuffles or shared memory) to compute q_norm_sq and k_norm_sq in parallel. This would reduce the reduction complexity from
| if initial_state is not None and initial_state.shape[1:] != (HV, K, K): | ||
| raise ValueError(f"initial_state must be [N,HV,128,128], got {tuple(initial_state.shape)}") |
There was a problem hiding this comment.
The validation for initial_state only checks the trailing dimensions shape[1:]. It does not verify that the batch dimension initial_state.shape[0] matches state_count (which is B or cu_seqlens.numel() - 1). If a tensor with a mismatched batch dimension is passed, the CUDA kernel could access memory out of bounds. Consider adding a check for the batch dimension.
if initial_state is not None:
state_count = B if cu_seqlens is None else cu_seqlens.numel() - 1
if initial_state.shape != (state_count, HV, K, K):
raise ValueError(f"initial_state must be [{state_count},{HV},128,128], got {tuple(initial_state.shape)}")|
@XiaomingFun233 Thank you very much for your contribution. Could you please provide some benchmark results along with corresponding explanations? |
📌 Description
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to cuLA! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
⚡ Performance
Reviewer Notes