diff --git a/benchmarks/bench_qwen35_decode.py b/benchmarks/bench_qwen35_decode.py new file mode 100755 index 00000000..7cd9f858 --- /dev/null +++ b/benchmarks/bench_qwen35_decode.py @@ -0,0 +1,667 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Benchmark Qwen3.5 decode on the active CUDA device. + +Six timing scopes are reported: + - native_core: direct native scalar GDN decode op. + - triton_core: FLA/SGLang-style fused_sigmoid_gating_delta_rule_update + Triton decode op vendored in cuLA. + - sglang_core: fused_sigmoid_gating_delta_rule_update from SGLang, when + available from the installed package or --sglang-path. + - fused_layout_kda: direct cuLA fused Qwen3.5 layout + scalar KDA decode op. + - sglang_packed: SGLang packed Qwen3.5 layout + recurrent update op, when + available from the installed package or --sglang-path. + - full: cuLA Python Qwen3.5 decode chain, including conv + layout + core. + +State buffers are reset before each timed iteration and the reset copy is not +included in the event timing window. 
def accelerator_name(device: torch.device) -> str:
    """Return the name torch reports for the CUDA device ``device``.

    Args:
        device: Device to describe. Must have ``type == "cuda"``; its index
            may be ``None``.

    Returns:
        The string returned by ``torch.cuda.get_device_name``.

    Raises:
        ValueError: If ``device`` is not a CUDA device.
    """
    if device.type != "cuda":
        raise ValueError(f"Unsupported device={device}")
    # Pass the device object through instead of ``device.index or 0``: an
    # index-less ``torch.device("cuda")`` then resolves to the *current* CUDA
    # device rather than being silently reported as GPU 0.
    return torch.cuda.get_device_name(device)
def local_config_from_tp_size(tp_size: int) -> Qwen35LinearAttentionConfig:
    """Build the per-rank linear-attention config for a tensor-parallel degree.

    ``hidden_size``, ``num_k_heads`` and ``num_v_heads`` of the module-level
    ``CONFIG`` are floor-divided by ``tp_size``; every other field is copied
    from ``CONFIG`` unchanged.

    Args:
        tp_size: Tensor-parallel degree; one of 1, 2, 4, 8.

    Returns:
        A new ``Qwen35LinearAttentionConfig`` describing one rank's shard.

    Raises:
        ValueError: If ``tp_size`` is not one of the supported degrees.
    """
    supported_degrees = (1, 2, 4, 8)
    if tp_size in supported_degrees:
        base = CONFIG
        return Qwen35LinearAttentionConfig(
            hidden_size=base.hidden_size // tp_size,
            conv_kernel_size=base.conv_kernel_size,
            num_k_heads=base.num_k_heads // tp_size,
            num_v_heads=base.num_v_heads // tp_size,
            head_k_dim=base.head_k_dim,
            head_v_dim=base.head_v_dim,
            qkv_dtype=base.qkv_dtype,
            state_dtype=base.state_dtype,
        )
    raise ValueError(f"tp_size must be one of 1, 2, 4, 8, got {tp_size}")
def make_fused_layout_kda_inputs(tokens: int, device: torch.device, seed: int, config: Qwen35LinearAttentionConfig):
    """Create seeded random inputs for the fused layout + scalar KDA decode op.

    Args:
        tokens: Size of the leading dimension of every per-token tensor.
        device: Device on which all tensors are allocated.
        seed: Value passed to ``torch.manual_seed`` before any draw.
        config: Supplies ``conv_dim``, ``num_v_heads``, ``head_k_dim``,
            ``head_v_dim``, ``qkv_dtype`` and ``state_dtype``.

    Returns:
        ``(mixed_qkv_conv, a, b, A_log, dt_bias, state, state_work,
        state_indices, out)`` where ``state_work`` and ``out`` are
        uninitialised buffers and ``state_indices`` is ``arange(tokens)``
        as int32.
    """
    torch.manual_seed(seed)
    heads = config.num_v_heads
    activation_opts = {"device": device, "dtype": config.qkv_dtype}
    fp32_opts = {"device": device, "dtype": torch.float32}

    # The random draws below must stay in this order so that a given seed
    # keeps producing the same tensors.
    mixed_qkv_conv = torch.randn(tokens, config.conv_dim, **activation_opts)
    a = torch.randn(tokens, heads, **activation_opts)
    b = torch.randn(tokens, heads, **activation_opts)
    A_log = torch.rand(heads, **fp32_opts).neg()
    dt_bias = torch.randn(heads, **fp32_opts) * 0.1
    state_shape = (tokens, heads, config.head_k_dim, config.head_v_dim)
    state = torch.randn(*state_shape, device=device, dtype=config.state_dtype) * 0.01

    # Buffers that consume no random numbers.
    out = torch.empty(tokens, heads, config.head_v_dim, **activation_opts)
    state_indices = torch.arange(tokens, device=device, dtype=torch.int32)
    state_work = torch.empty_like(state)
    return mixed_qkv_conv, a, b, A_log, dt_bias, state, state_work, state_indices, out
def _add_sglang_import_roots(sglang_path: pathlib.Path | None) -> None:
    """Put a user-supplied SGLang checkout at the front of ``sys.path``.

    Both ``sglang_path`` itself and ``sglang_path / "python"`` are considered;
    only directories that exist are added. The ``python`` sub-directory, being
    processed last, ends up first on ``sys.path``.

    Args:
        sglang_path: SGLang repo root or its ``python/`` directory, or ``None``
            to leave ``sys.path`` untouched.
    """
    if sglang_path is None:
        return
    for import_root in (sglang_path, sglang_path / "python"):
        if not import_root.exists():
            continue
        entry = str(import_root)
        # This helper runs once per provider lookup. Drop any earlier copy
        # before prepending so repeated calls keep the root at the front
        # without growing sys.path with duplicates.
        while entry in sys.path:
            sys.path.remove(entry)
        sys.path.insert(0, entry)
def call_with_supported_kwargs(fn: Callable, **kwargs):
    """Call a provider while tolerating minor SGLang signature drift.

    Keyword arguments that ``fn`` cannot accept by name are dropped before the
    call. If ``fn`` declares ``**kwargs``, or its signature cannot be
    inspected, every argument is forwarded unchanged.

    Args:
        fn: Callable to invoke.
        **kwargs: Candidate keyword arguments.

    Returns:
        Whatever ``fn`` returns.
    """
    try:
        signature = inspect.signature(fn)
    except (TypeError, ValueError):
        # Some builtins / extension callables expose no signature.
        return fn(**kwargs)

    parameters = signature.parameters.values()
    if any(param.kind is inspect.Parameter.VAR_KEYWORD for param in parameters):
        return fn(**kwargs)

    # Only these kinds can be bound by keyword. Positional-only and ``*args``
    # parameters also appear in ``signature.parameters`` but would raise
    # TypeError if passed by name, so they must not survive the filter.
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    accepted = {param.name for param in parameters if param.kind in keyword_kinds}
    return fn(**{name: value for name, value in kwargs.items() if name in accepted})
def bench_triton_core(tokens: int, device: torch.device, warmup: int, rep: int, seed: int, config: Qwen35LinearAttentionConfig) -> float:
    """Time the vendored Triton ``fused_sigmoid_gating_delta_rule_update`` op.

    Inputs come from ``make_core_inputs``. ``q``, ``k``, ``v``, ``a`` and ``b``
    each get a size-1 dimension inserted at index 1 before being handed to the
    kernel. The working state is restored from the pristine copy before every
    iteration via the ``setup_fn`` hook of ``benchmark_accel_fn``.

    Returns:
        The value produced by ``benchmark_accel_fn`` for this workload.
    """
    generated = make_core_inputs(tokens, device, seed, config)
    q, k, v, a, b, A_log, dt_bias, pristine_state, scratch_state, slot_ids, _unused_out = generated

    def insert_unit_dim(tensor):
        return tensor.unsqueeze(1).contiguous()

    q_4d, k_4d, v_4d = insert_unit_dim(q), insert_unit_dim(k), insert_unit_dim(v)
    a_3d, b_3d = insert_unit_dim(a), insert_unit_dim(b)

    def restore_state() -> None:
        scratch_state.copy_(pristine_state)

    def launch() -> None:
        triton_fused_sigmoid_update(
            A_log=A_log,
            a=a_3d,
            dt_bias=dt_bias,
            softplus_beta=1.0,
            softplus_threshold=20.0,
            q=q_4d,
            k=k_4d,
            v=v_4d,
            b=b_3d,
            initial_state_source=scratch_state,
            initial_state_indices=slot_ids,
            scale=config.head_k_dim**-0.5,
            use_qk_l2norm_in_kernel=True,
            cu_seqlens=None,
            is_kda=False,
        )

    return benchmark_accel_fn(launch, device=device, setup_fn=restore_state, warmup=warmup, rep=rep)
def bench_sglang_packed_layout_kda(
    tokens: int,
    device: torch.device,
    warmup: int,
    rep: int,
    seed: int,
    sglang_packed_decode: Callable,
    config: Qwen35LinearAttentionConfig,
) -> float:
    """Time an SGLang packed-layout decode provider.

    Inputs come from ``make_fused_layout_kda_inputs``. The provider receives
    the recurrent state with its last two axes swapped, plus a freshly
    allocated output of shape ``(tokens, 1, num_v_heads, head_v_dim)``. The
    swapped working state is restored before every iteration, and the
    provider is invoked through ``call_with_supported_kwargs``.

    Returns:
        The value produced by ``benchmark_accel_fn`` for this workload.
    """
    generated = make_fused_layout_kda_inputs(tokens, device, seed, config)
    mixed_qkv_conv, a, b, A_log, dt_bias, state, _unused_work, slot_ids, _unused_out = generated

    pristine_vk = state.transpose(-1, -2).contiguous()
    scratch_vk = torch.empty_like(pristine_vk)
    out = torch.empty(tokens, 1, config.num_v_heads, config.head_v_dim, device=device, dtype=config.qkv_dtype)

    def restore_state() -> None:
        scratch_vk.copy_(pristine_vk)

    def launch() -> None:
        call_with_supported_kwargs(
            sglang_packed_decode,
            mixed_qkv=mixed_qkv_conv,
            a=a,
            b=b,
            A_log=A_log,
            dt_bias=dt_bias,
            scale=config.head_k_dim**-0.5,
            initial_state=scratch_vk,
            out=out,
            ssm_state_indices=slot_ids,
            use_qk_l2norm_in_kernel=True,
        )

    return benchmark_accel_fn(launch, device=device, setup_fn=restore_state, warmup=warmup, rep=rep)
def _positive_int(text: str) -> int:
    """argparse ``type`` callback that accepts only integers >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def parse_args() -> argparse.Namespace:
    """Parse the decode benchmark's command-line options from ``sys.argv``.

    Returns:
        Namespace with ``tokens``, ``warmup``, ``rep``, ``seed``, ``scope``,
        ``tp_size``, ``skip_triton``, ``skip_sglang``, ``require_sglang``,
        ``sglang_path`` and ``csv``.

    Raises:
        SystemExit: Raised by argparse on invalid input, including a
            ``--tokens`` value below 1.
    """
    parser = argparse.ArgumentParser(description="Benchmark cuLA Qwen3.5 decode.")
    # Token counts are used both as tensor sizes and as the divisor of the
    # per-token latency columns, so zero or negative values are rejected here
    # rather than failing later with a ZeroDivisionError or a torch size error.
    parser.add_argument("--tokens", nargs="+", type=_positive_int, default=[1, 2, 4, 8, 16, 32, 64, 128])
    parser.add_argument("--warmup", type=int, default=20)
    parser.add_argument("--rep", type=int, default=100)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--scope", choices=["core", "fused", "full", "both"], default="both")
    parser.add_argument("--tp-size", type=int, choices=[1, 2, 4, 8], default=1)
    parser.add_argument("--skip-triton", action="store_true", help="Skip the vendored Triton core timing.")
    parser.add_argument("--skip-sglang", action="store_true", help="Do not try the SGLang kernel provider.")
    parser.add_argument("--require-sglang", action="store_true", help="Fail if the SGLang kernel provider is unavailable.")
    parser.add_argument("--sglang-path", type=pathlib.Path, default=None, help="SGLang repo root or python/ directory.")
    parser.add_argument("--csv", type=pathlib.Path, default=None)
    return parser.parse_args()
rows: list[dict[str, object]] = [] + sglang_fused_update = None + sglang_core_source = None + sglang_packed_decode = None + sglang_packed_source = None + if not args.skip_sglang: + sglang_fused_update, sglang_core_source = resolve_sglang_core_update(args.sglang_path) + sglang_packed_decode, sglang_packed_source = resolve_sglang_packed_decode(args.sglang_path) + if args.require_sglang and (sglang_fused_update is None or sglang_packed_decode is None): + raise RuntimeError("SGLang core and packed decode providers must both be available.") + + print(f"device={device} name={accelerator_name(device)} torch={torch.__version__}") + print( + f"qwen35: tp={args.tp_size} local_HK={config.num_k_heads} local_HV={config.num_v_heads} " + f"K={config.head_k_dim} V={config.head_v_dim} conv_dim={config.conv_dim}" + ) + print(f"sglang_core_provider={sglang_core_source or 'unavailable'}") + print(f"sglang_packed_provider={sglang_packed_source or 'unavailable'}") + print("| tokens | native_core_ms | triton_core_ms | sglang_core_ms | fused_layout_kda_ms | sglang_packed_ms | full_ms | triton/native | sglang/native | packed/fused | native_us_per_token | triton_us_per_token | sglang_us_per_token | fused_us_per_token | packed_us_per_token | full_us_per_token |") + print("|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|") + + for tokens in args.tokens: + native_core_ms = None + triton_core_ms = None + sglang_core_ms = None + fused_layout_kda_ms = None + sglang_packed_ms = None + full_ms = None + if args.scope in ("core", "both"): + native_core_ms = bench_native_core(tokens, device, args.warmup, args.rep, args.seed, config) + if not args.skip_triton: + triton_core_ms = bench_triton_core(tokens, device, args.warmup, args.rep, args.seed, config) + if sglang_fused_update is not None: + sglang_core_ms = bench_sglang_core( + tokens, + device, + args.warmup, + args.rep, + args.seed, + sglang_fused_update, + config, + ) + if args.scope in ("fused", "both"): + 
fused_layout_kda_ms = bench_fused_layout_kda(tokens, device, args.warmup, args.rep, args.seed, config) + if sglang_packed_decode is not None: + sglang_packed_ms = bench_sglang_packed_layout_kda( + tokens, + device, + args.warmup, + args.rep, + args.seed, + sglang_packed_decode, + config, + ) + if args.scope in ("full", "both"): + full_ms = bench_full(tokens, device, args.warmup, args.rep, args.seed, config) + + native_core_us = None if native_core_ms is None else native_core_ms * 1000.0 / tokens + triton_core_us = None if triton_core_ms is None else triton_core_ms * 1000.0 / tokens + sglang_core_us = None if sglang_core_ms is None else sglang_core_ms * 1000.0 / tokens + fused_layout_kda_us = None if fused_layout_kda_ms is None else fused_layout_kda_ms * 1000.0 / tokens + sglang_packed_us = None if sglang_packed_ms is None else sglang_packed_ms * 1000.0 / tokens + full_us = None if full_ms is None else full_ms * 1000.0 / tokens + triton_ratio = None + if native_core_ms is not None and triton_core_ms is not None and native_core_ms > 0: + triton_ratio = triton_core_ms / native_core_ms + sglang_ratio = None + if native_core_ms is not None and sglang_core_ms is not None and native_core_ms > 0: + sglang_ratio = sglang_core_ms / native_core_ms + packed_ratio = None + if fused_layout_kda_ms is not None and sglang_packed_ms is not None and fused_layout_kda_ms > 0: + packed_ratio = sglang_packed_ms / fused_layout_kda_ms + print( + f"| {tokens} | " + f"{'n/a' if native_core_ms is None else f'{native_core_ms:.4f}'} | " + f"{'n/a' if triton_core_ms is None else f'{triton_core_ms:.4f}'} | " + f"{'n/a' if sglang_core_ms is None else f'{sglang_core_ms:.4f}'} | " + f"{'n/a' if fused_layout_kda_ms is None else f'{fused_layout_kda_ms:.4f}'} | " + f"{'n/a' if sglang_packed_ms is None else f'{sglang_packed_ms:.4f}'} | " + f"{'n/a' if full_ms is None else f'{full_ms:.4f}'} | " + f"{'n/a' if triton_ratio is None else f'{triton_ratio:.2f}x'} | " + f"{'n/a' if sglang_ratio is None else 
f'{sglang_ratio:.2f}x'} | " + f"{'n/a' if packed_ratio is None else f'{packed_ratio:.2f}x'} | " + f"{'n/a' if native_core_us is None else f'{native_core_us:.2f}'} | " + f"{'n/a' if triton_core_us is None else f'{triton_core_us:.2f}'} | " + f"{'n/a' if sglang_core_us is None else f'{sglang_core_us:.2f}'} | " + f"{'n/a' if fused_layout_kda_us is None else f'{fused_layout_kda_us:.2f}'} | " + f"{'n/a' if sglang_packed_us is None else f'{sglang_packed_us:.2f}'} | " + f"{'n/a' if full_us is None else f'{full_us:.2f}'} |" + ) + rows.append( + { + "tokens": tokens, + "tp_size": args.tp_size, + "local_k_heads": config.num_k_heads, + "local_v_heads": config.num_v_heads, + "conv_dim": config.conv_dim, + "native_core_ms": native_core_ms, + "triton_core_ms": triton_core_ms, + "sglang_core_ms": sglang_core_ms, + "fused_layout_kda_ms": fused_layout_kda_ms, + "sglang_packed_ms": sglang_packed_ms, + "full_ms": full_ms, + "triton_over_native": triton_ratio, + "sglang_over_native": sglang_ratio, + "sglang_packed_over_fused": packed_ratio, + "native_core_us_per_token": native_core_us, + "triton_core_us_per_token": triton_core_us, + "sglang_core_us_per_token": sglang_core_us, + "fused_layout_kda_us_per_token": fused_layout_kda_us, + "sglang_packed_us_per_token": sglang_packed_us, + "full_us_per_token": full_us, + } + ) + + if args.csv is not None: + args.csv.parent.mkdir(parents=True, exist_ok=True) + with args.csv.open("w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter( + f, + fieldnames=[ + "tokens", + "tp_size", + "local_k_heads", + "local_v_heads", + "conv_dim", + "native_core_ms", + "triton_core_ms", + "sglang_core_ms", + "fused_layout_kda_ms", + "sglang_packed_ms", + "full_ms", + "triton_over_native", + "sglang_over_native", + "sglang_packed_over_fused", + "native_core_us_per_token", + "triton_core_us_per_token", + "sglang_core_us_per_token", + "fused_layout_kda_us_per_token", + "sglang_packed_us_per_token", + "full_us_per_token", + ], + ) + writer.writeheader() + 
writer.writerows(rows) + print(f"wrote {args.csv}") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/benchmarks/bench_qwen35_prefill.py b/benchmarks/bench_qwen35_prefill.py new file mode 100644 index 00000000..bfcbe6d5 --- /dev/null +++ b/benchmarks/bench_qwen35_prefill.py @@ -0,0 +1,449 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Benchmark Qwen3.5 prefill kernels. + +Reports: + - layout: cuLA Qwen3.5 prefill layout split/repeat kernel + - cula_qk: cuLA Qwen3.5 TMA/WGMMA-or-UMMA QK chunk debug kernel + - cula_fused: cuLA generic fused KDA core through a Qwen3.5 scalar-gate adapter + - fla_gdr: optional FLA chunk_gated_delta_rule baseline + - sgl_gdr: optional SGLang vendored Triton chunk_gated_delta_rule baseline + +Baselines are optional. SGLang Qwen3.5 prefill uses the same chunked gated +delta rule family in its Triton GDN kernel; decode uses a recurrent packed +kernel instead. + +Note: cula_qk currently benchmarks the TMA tensor-core Q @ K^T subpath only, +not the full gated-delta prefill recurrence. Its output is [B,local_HV,T,T], +so long sequence lengths have quadratic memory cost. 
def error_stats(ref: torch.Tensor, out: torch.Tensor) -> tuple[float, float, float]:
    """Summarise how far ``out`` is from ``ref``.

    Both tensors are converted to float32 first.

    Returns:
        ``(rel_rms, rel_max, mean_abs)`` where ``rel_rms`` is the RMS error
        over the RMS of ``ref``, ``rel_max`` is the largest absolute error
        over the largest absolute value of ``ref``, and ``mean_abs`` is the
        mean absolute error. Both denominators have ``1e-8`` added.
    """
    eps = 1.0e-8
    expected = ref.float()
    abs_err = torch.abs(expected - out.float())

    def rms(values: torch.Tensor) -> float:
        return values.square().mean().sqrt().item()

    rel_rms = rms(abs_err) / (rms(expected) + eps)
    rel_max = abs_err.max().item() / (expected.abs().max().item() + eps)
    return rel_rms, rel_max, abs_err.mean().item()
def make_inputs(batch: int, seq_len: int, *, device: torch.device, seed: int, config: Qwen35LinearAttentionConfig):
    """Create seeded random inputs for the Qwen3.5 prefill benchmarks.

    Args:
        batch: Batch size ``B``.
        seq_len: Sequence length ``T``.
        device: Device on which all tensors are allocated.
        seed: Value passed to ``torch.manual_seed`` before any draw.
        config: Supplies ``num_v_heads`` (``HV``), ``head_k_dim`` (``K``),
            ``head_v_dim`` (``V``), ``conv_dim`` and ``qkv_dtype``.

    Returns:
        ``(q, k, v, a, b, beta, log_gate, A_log, dt_bias, initial_state,
        mixed_qkv_conv, a_flat, b_flat)`` with shapes:

        - ``q``, ``k``: ``[B, T, HV, K]``; ``v``: ``[B, T, HV, V]``
        - ``a``, ``b``, ``beta``, ``log_gate``: ``[B, T, HV]``
        - ``A_log``, ``dt_bias``: ``[HV]`` float32
        - ``initial_state``: ``[B, HV, K, V]`` float32
        - ``mixed_qkv_conv``: ``[B*T, conv_dim]``
        - ``a_flat``, ``b_flat``: ``[B*T, HV]``
    """
    torch.manual_seed(seed)
    heads = config.num_v_heads
    q = torch.randn(batch, seq_len, heads, config.head_k_dim, device=device, dtype=config.qkv_dtype)
    k = torch.randn_like(q)
    # v must end in head_v_dim so it agrees with initial_state's [.., K, V]
    # layout. Copying q's shape is only correct when head_k_dim == head_v_dim,
    # in which case this allocation has the same shape as before.
    v = torch.randn(batch, seq_len, heads, config.head_v_dim, device=device, dtype=config.qkv_dtype)
    a = torch.randn(batch, seq_len, heads, device=device, dtype=config.qkv_dtype)
    b = torch.randn(batch, seq_len, heads, device=device, dtype=config.qkv_dtype)
    beta = torch.sigmoid(b.float()).to(dtype=config.qkv_dtype)
    A_log = -torch.rand(heads, device=device, dtype=torch.float32)
    dt_bias = torch.randn(heads, device=device, dtype=torch.float32) * 0.1
    # Per-head scalar log-gate: -exp(A_log) * softplus(a + dt_bias), computed
    # in float32 and cast back to the activation dtype.
    log_gate = (-torch.exp(A_log).view(1, 1, -1) * torch.nn.functional.softplus(a.float() + dt_bias.view(1, 1, -1))).to(
        dtype=config.qkv_dtype
    )
    initial_state = torch.randn(
        batch,
        heads,
        config.head_k_dim,
        config.head_v_dim,
        device=device,
        dtype=torch.float32,
    ) * 0.01
    mixed_qkv_conv = torch.randn(batch * seq_len, config.conv_dim, device=device, dtype=config.qkv_dtype)
    a_flat = a.reshape(batch * seq_len, heads).contiguous()
    b_flat = b.reshape(batch * seq_len, heads).contiguous()
    return q, k, v, a, b, beta, log_gate, A_log, dt_bias, initial_state, mixed_qkv_conv, a_flat, b_flat
def prepare_cula_fused_core_inputs(q, k, a, b, A_log, dt_bias, initial_state):
    """Precompute the tensors handed to the fused KDA core in ``fused-core`` mode.

    Steps, all taken from the visible arithmetic:

    - ``q`` and ``k`` are L2-normalised along the last axis in float32 and
      cast back to their original dtype.
    - The per-head scalar log-gate ``-exp(A_log) * softplus(a + dt_bias)`` is
      broadcast to ``q``'s 4-D shape.
    - The gate is cumulatively summed along axis 1 within consecutive windows
      of 64 positions and multiplied by ``RCP_LN2``.
    - ``beta`` is ``sigmoid(b)`` in float32.
    - ``initial_state`` is cast to float32 with its last two axes swapped.

    Returns:
        ``(q_norm, k_norm, log_gate_cumsum, beta, initial_state_vk)``, all
        contiguous.
    """
    batch, seq_len, heads, key_dim = q.shape
    window = 64

    def unit_norm(tensor):
        return F.normalize(tensor.float(), dim=-1).to(tensor.dtype).contiguous()

    q_norm = unit_norm(q)
    k_norm = unit_norm(k)

    decay = -torch.exp(A_log.float()).view(1, 1, heads, 1)
    pre_activation = a.float().unsqueeze(-1) + dt_bias.float().view(1, 1, heads, 1)
    log_gate = (decay * F.softplus(pre_activation)).expand(batch, seq_len, heads, key_dim).contiguous()

    # The cumulative sum restarts at every window boundary.
    pieces = [
        log_gate[:, lo : lo + window].cumsum(dim=1) * RCP_LN2
        for lo in range(0, seq_len, window)
    ]
    log_gate_cumsum = torch.cat(pieces, dim=1).contiguous()

    beta = torch.sigmoid(b.float()).contiguous()
    initial_state_vk = initial_state.float().transpose(-1, -2).contiguous()
    return q_norm, k_norm, log_gate_cumsum, beta, initial_state_vk
use [N, H, V, K] state layout. cuLA's + # Qwen3.5 wrapper uses [N, H, K, V], so pass the transposed view here. + initial_state_vk = initial_state.transpose(-1, -2).contiguous() + kwargs = dict( + q=q, + k=k, + v=v, + g=log_gate, + beta=beta, + initial_state=initial_state_vk, + initial_state_indices=initial_state_indices, + output_final_state=True, + scale=config.head_k_dim**-0.5, + use_qk_l2norm_in_kernel=True, + head_first=False, + ) + try: + sig = inspect.signature(chunk_gdr) + kwargs = {key: value for key, value in kwargs.items() if key in sig.parameters} + except (TypeError, ValueError): + pass + return chunk_gdr(**kwargs) + + +def _normalize_chunk_result(result): + if isinstance(result, tuple): + out = result[0] + state = result[-1] if len(result) >= 2 else None + return out, state + return result, None + + +def _state_to_cula_layout(state: torch.Tensor | None) -> torch.Tensor | None: + if state is None: + return None + return state.transpose(-1, -2).contiguous() + + +def print_header(device: torch.device, args: argparse.Namespace, baseline_sources: dict[str, str]) -> None: + config = local_config_from_tp_size(args.tp_size) + print("Qwen3.5 prefill benchmark") + print(f" device: {torch.cuda.get_device_name(device)}") + print(f" dtype: {config.qkv_dtype}") + print(f" batch: {args.batch}") + print( + f" tp/local config: tp={args.tp_size} local_k_heads={config.num_k_heads} " + f"local_v_heads={config.num_v_heads} conv_dim={config.conv_dim}" + ) + print(f" seq lens: {args.seq_lens}") + print(f" warmup/rep: {args.warmup}/{args.rep}") + print(f" baselines: {baseline_sources or 'disabled/unavailable'}") + if args.cula_mode == "qk": + print(" cula: qwen35_chunk_qk_prefill_sm90 QK subpath only; baselines are full Triton GDR chunk kernels") + elif args.cula_mode == "scalar": + print(" cula: qwen35_scalar_kda_prefill full recurrence fallback") + elif args.cula_mode == "fused": + print(" cula: qwen35_fused_kda_prefill full recurrence via fused KDA CuTe core") + elif 
args.cula_mode == "fused-core": + print(" cula: fused KDA CuTe core only; Qwen gate/l2norm/cumsum/state prep is outside timing") + print() + cula_col = f"cula_{args.cula_mode}_ms" + print( + f"{'baseline':>8} {'B':>3} {'T':>7} {'layout_ms':>11} {cula_col:>13} {'cula_total':>11} " + f"{'base_ms':>11} {'base/cula':>10} {'rel_rms':>11} {'rel_max':>11}" + ) + print("-" * 113) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=1) + parser.add_argument("--seq-lens", type=int, nargs="+", default=[128, 256, 512, 1024]) + parser.add_argument("--warmup", type=int, default=10) + parser.add_argument("--rep", type=int, default=30) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--tp-size", type=int, choices=[1, 2, 4, 8], default=1) + parser.add_argument("--baseline", choices=["none", "fla", "sgl", "all"], default="sgl") + parser.add_argument("--sglang-path", type=pathlib.Path, default=None) + parser.add_argument( + "--cula-mode", + choices=["qk", "scalar", "fused", "fused-core"], + default="qk", + help="cuLA path to benchmark: qk is QK subpath, scalar is old full fallback, fused is wrapper, fused-core is kernel only.", + ) + parser.add_argument("--skip-accuracy", action="store_true") + parser.add_argument( + "--max-qk-elements", + type=int, + default=512 * 1024 * 1024, + help="Skip cuLA QK timings when B*local_HV*T*T exceeds this element count.", + ) + args = parser.parse_args() + config = local_config_from_tp_size(args.tp_size) + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for this benchmark.") + device = torch.device("cuda") + + baselines: dict[str, Callable] = {} + baseline_sources: dict[str, str] = {} + if args.baseline in ("fla", "all"): + fla_chunk_gdr, fla_source_or_error = resolve_fla_chunk_gdr() + if fla_chunk_gdr is None: + print(f"Skipping FLA baseline: {fla_source_or_error}") + else: + baselines["fla"] = fla_chunk_gdr + baseline_sources["fla"] = 
fla_source_or_error + if args.baseline in ("sgl", "all"): + sgl_chunk_gdr, sgl_source_or_error = resolve_sgl_chunk_gdr(args.sglang_path) + if sgl_chunk_gdr is None: + print(f"Skipping SGLang baseline: {sgl_source_or_error}") + else: + baselines["sgl"] = sgl_chunk_gdr + baseline_sources["sgl"] = sgl_source_or_error + + print_header(device, args, baseline_sources) + + for seq_len in args.seq_lens: + q, k, v, a, b, beta, log_gate, A_log, dt_bias, initial_state, mixed_qkv_conv, a_flat, b_flat = make_inputs( + args.batch, + seq_len, + device=device, + seed=args.seed, + config=config, + ) + initial_state_indices = torch.arange(args.batch, device=device, dtype=torch.int32) + + def layout_fn(): + return qwen35_layout_prefill(mixed_qkv_conv, a_flat, b_flat, backend="cudac") + + qk_elements = args.batch * config.num_v_heads * seq_len * seq_len + qk_out = None + if args.cula_mode == "qk" and qk_elements <= args.max_qk_elements: + qk_out = torch.empty( + args.batch, + config.num_v_heads, + seq_len, + seq_len, + device=device, + dtype=torch.float32, + ) + fused_core_inputs = None + if args.cula_mode == "fused-core": + fused_core_inputs = prepare_cula_fused_core_inputs(q, k, a, b, A_log, dt_bias, initial_state) + + def cula_fn(): + if args.cula_mode == "scalar": + return run_cula_scalar(q, k, v, a, b, A_log, dt_bias, initial_state) + if args.cula_mode == "fused": + return run_cula_fused(q, k, v, a, b, A_log, dt_bias, initial_state) + if args.cula_mode == "fused-core": + return run_cula_fused_core(*fused_core_inputs[:2], v, *fused_core_inputs[2:], config) + if qk_out is None: + raise RuntimeError( + f"Skipping cuLA QK: B*H*T*T={qk_elements} exceeds --max-qk-elements={args.max_qk_elements}" + ) + return run_cula_chunk_qk(q, k, qk_out) + + layout_ms = benchmark_cuda_fn(layout_fn, warmup=args.warmup, rep=args.rep) + cula_ms = ( + float("nan") + if args.cula_mode == "qk" and qk_out is None + else benchmark_cuda_fn(cula_fn, warmup=args.warmup, rep=args.rep) + ) + cula_total_ms = 
layout_ms + cula_ms if not torch.isnan(torch.tensor(cula_ms)) else float("nan") + + rel_rms = float("nan") + rel_max = float("nan") + state_cula = None + if not args.skip_accuracy: + out_cula = cula_fn() + if args.cula_mode == "qk": + qk_ref = torch.einsum("bthd,bshd->bhts", q.float(), k.float()) + torch.cuda.synchronize() + rel_rms, rel_max, _ = error_stats(qk_ref, out_cula) + del qk_ref + else: + out_cula, state_cula = out_cula + torch.cuda.synchronize() + + if not baselines: + print( + f"{'none':>8} {args.batch:3d} {seq_len:7d} {layout_ms:11.4f} {cula_ms:13.4f} {cula_total_ms:11.4f} " + f"{float('nan'):11.4f} {float('nan'):10.3f} {rel_rms:11.3e} {rel_max:11.3e}" + ) + + for baseline_name, chunk_gdr in baselines.items(): + def baseline_fn(): + return run_chunk_gdr(chunk_gdr, q, k, v, log_gate, beta, initial_state, initial_state_indices, config) + + row_rel_rms = rel_rms + row_rel_max = rel_max + if args.cula_mode in ("scalar", "fused", "fused-core") and not args.skip_accuracy: + if state_cula is None: + out_cula, state_cula = cula_fn() + out_base, state_base = _normalize_chunk_result(baseline_fn()) + state_base = _state_to_cula_layout(state_base) + torch.cuda.synchronize() + row_rel_rms, row_rel_max, _ = error_stats(out_base, out_cula) + if state_base is not None and tuple(state_base.shape) == tuple(state_cula.shape): + rel_rms_s, rel_max_s, _ = error_stats(state_base, state_cula) + row_rel_rms = max(row_rel_rms, rel_rms_s) + row_rel_max = max(row_rel_max, rel_max_s) + + base_ms = benchmark_cuda_fn(baseline_fn, warmup=args.warmup, rep=args.rep) + speedup = base_ms / cula_ms if cula_ms > 0 else float("nan") + print( + f"{baseline_name:>8} {args.batch:3d} {seq_len:7d} {layout_ms:11.4f} {cula_ms:13.4f} {cula_total_ms:11.4f} " + f"{base_ms:11.4f} {speedup:10.3f} {row_rel_rms:11.3e} {row_rel_max:11.3e}" + ) + + del q, k, v, a, b, beta, log_gate, A_log, dt_bias, initial_state, initial_state_indices, mixed_qkv_conv, a_flat, b_flat, qk_out, fused_core_inputs + 
torch.cuda.empty_cache() + + +if __name__ == "__main__": + main() diff --git a/benchmarks/tune_qwen35_tp_policy.py b/benchmarks/tune_qwen35_tp_policy.py new file mode 100644 index 00000000..28653e9c --- /dev/null +++ b/benchmarks/tune_qwen35_tp_policy.py @@ -0,0 +1,362 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tune Qwen3.5 TP-local kernel policies. + +This is a configuration-driven tuner. It benchmarks only policies that are +compiled into the current extension and records unsupported candidates in the +result file. The initial compiled policy is the decode traits currently used by +the CUDA/CuTe kernels: + + layout_vec=4, kda_threads=128, kda_tile_v=16, kda_tile_k=16, heads_per_cta=1 + +When more C++ policy specializations are added, extend `compiled_policy_key` +and the kernel dispatch path; this script can then sweep them without changing +the output format. 
+""" + +from __future__ import annotations + +import argparse +import csv +import itertools +import json +import pathlib +import sys +from dataclasses import asdict, dataclass +from typing import Any + +ROOT = pathlib.Path(__file__).resolve().parent.parent +sys.path.insert(0, str(ROOT)) + + +@dataclass(frozen=True) +class DecodePolicy: + name: str + layout_vec: int + kda_threads: int + kda_tile_v: int + kda_tile_k: int + heads_per_cta: int = 1 + + @property + def key(self) -> tuple[int, int, int, int, int]: + return (self.layout_vec, self.kda_threads, self.kda_tile_v, self.kda_tile_k, self.heads_per_cta) + + +CURRENT_DECODE_POLICY = DecodePolicy( + name="current", + layout_vec=4, + kda_threads=128, + kda_tile_v=16, + kda_tile_k=16, + heads_per_cta=1, +) + + +def decode_benchmarks(): + from benchmarks import bench_qwen35_decode + + return bench_qwen35_decode + + +def compiled_policy_key(policy: DecodePolicy) -> str | None: + """Return the compiled backend selector for a policy, or None if absent.""" + if policy.key == CURRENT_DECODE_POLICY.key: + return "current" + return None + + +def _list_from_json(data: dict[str, Any], key: str, default: list[int]) -> list[int]: + value = data.get(key, default) + if not isinstance(value, list) or not value: + raise ValueError(f"{key} must be a non-empty list") + return [int(item) for item in value] + + +def load_decode_policies(path: pathlib.Path | None) -> list[DecodePolicy]: + if path is None: + return [CURRENT_DECODE_POLICY] + with path.open("r", encoding="utf-8") as f: + data = json.load(f) + + if isinstance(data, list): + policies = [] + for idx, item in enumerate(data): + if not isinstance(item, dict): + raise ValueError(f"Policy entry {idx} must be an object") + policies.append( + DecodePolicy( + name=str(item.get("name", f"policy_{idx}")), + layout_vec=int(item["layout_vec"]), + kda_threads=int(item["kda_threads"]), + kda_tile_v=int(item["kda_tile_v"]), + kda_tile_k=int(item["kda_tile_k"]), + 
heads_per_cta=int(item.get("heads_per_cta", 1)), + ) + ) + return policies + + if not isinstance(data, dict): + raise ValueError("Policy grid must be a JSON object or list") + if data.get("mode", "decode") != "decode": + raise ValueError("Only decode policy grids are supported by this tuner") + + policies = [] + for idx, combo in enumerate( + itertools.product( + _list_from_json(data, "layout_vec", [CURRENT_DECODE_POLICY.layout_vec]), + _list_from_json(data, "kda_threads", [CURRENT_DECODE_POLICY.kda_threads]), + _list_from_json(data, "kda_tile_v", [CURRENT_DECODE_POLICY.kda_tile_v]), + _list_from_json(data, "kda_tile_k", [CURRENT_DECODE_POLICY.kda_tile_k]), + _list_from_json(data, "heads_per_cta", [CURRENT_DECODE_POLICY.heads_per_cta]), + ) + ): + layout_vec, kda_threads, kda_tile_v, kda_tile_k, heads_per_cta = combo + policies.append( + DecodePolicy( + name=f"p{idx}_lv{layout_vec}_th{kda_threads}_tv{kda_tile_v}_tk{kda_tile_k}_h{heads_per_cta}", + layout_vec=layout_vec, + kda_threads=kda_threads, + kda_tile_v=kda_tile_v, + kda_tile_k=kda_tile_k, + heads_per_cta=heads_per_cta, + ) + ) + return policies + + +def write_example_grid(path: pathlib.Path) -> None: + example = { + "mode": "decode", + "layout_vec": [4, 8], + "kda_threads": [64, 128, 256], + "kda_tile_v": [8, 16, 32], + "kda_tile_k": [8, 16, 32], + "heads_per_cta": [1, 2, 4], + } + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as f: + json.dump(example, f, indent=2, sort_keys=True) + + +def bucket_name(tokens: int) -> str: + if tokens <= 4: + return "tokens<=4" + if tokens <= 16: + return "tokens<=16" + if tokens <= 64: + return "tokens<=64" + return "tokens>64" + + +def run_decode_policy( + *, + scope: str, + tokens: int, + tp_size: int, + warmup: int, + rep: int, + seed: int, + policy: DecodePolicy, +) -> dict[str, Any]: + decode_bench = decode_benchmarks() + config = decode_bench.local_config_from_tp_size(tp_size) + compiled_key = compiled_policy_key(policy) + row: 
dict[str, Any] = { + "mode": "decode", + "scope": scope, + "tokens": tokens, + "token_bucket": bucket_name(tokens), + "tp_size": tp_size, + "local_k_heads": config.num_k_heads, + "local_v_heads": config.num_v_heads, + "conv_dim": config.conv_dim, + "policy": policy.name, + "compiled_policy": compiled_key, + **asdict(policy), + } + if compiled_key is None: + row.update({"status": "unsupported", "ms": None, "us_per_token": None}) + return row + + device = decode_bench.accelerator_device() + if scope == "core": + ms = decode_bench.bench_native_core(tokens, device, warmup, rep, seed, config) + elif scope == "fused": + ms = decode_bench.bench_fused_layout_kda(tokens, device, warmup, rep, seed, config) + elif scope == "full": + ms = decode_bench.bench_full(tokens, device, warmup, rep, seed, config) + else: + raise ValueError(f"Unsupported decode scope={scope}") + + row.update({"status": "ok", "ms": ms, "us_per_token": ms * 1000.0 / tokens}) + return row + + +def choose_best(rows: list[dict[str, Any]]) -> list[dict[str, Any]]: + groups: dict[tuple[Any, ...], list[dict[str, Any]]] = {} + for row in rows: + if row["status"] != "ok": + continue + key = (row["mode"], row["scope"], row["tp_size"], row["local_v_heads"], row["token_bucket"]) + groups.setdefault(key, []).append(row) + + best_rows = [] + for key, candidates in sorted(groups.items()): + best = min(candidates, key=lambda row: float(row["ms"])) + mode, scope, tp_size, local_v_heads, token_bucket = key + best_rows.append( + { + "mode": mode, + "scope": scope, + "tp_size": tp_size, + "local_v_heads": local_v_heads, + "token_bucket": token_bucket, + "policy": best["policy"], + "compiled_policy": best["compiled_policy"], + "ms": best["ms"], + "us_per_token": best["us_per_token"], + "layout_vec": best["layout_vec"], + "kda_threads": best["kda_threads"], + "kda_tile_v": best["kda_tile_v"], + "kda_tile_k": best["kda_tile_k"], + "heads_per_cta": best["heads_per_cta"], + } + ) + return best_rows + + +def write_csv(path: 
pathlib.Path, rows: list[dict[str, Any]]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + fieldnames = [ + "mode", + "scope", + "tokens", + "token_bucket", + "tp_size", + "local_k_heads", + "local_v_heads", + "conv_dim", + "policy", + "compiled_policy", + "status", + "ms", + "us_per_token", + "layout_vec", + "kda_threads", + "kda_tile_v", + "kda_tile_k", + "heads_per_cta", + ] + with path.open("w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore") + writer.writeheader() + writer.writerows(rows) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Tune Qwen3.5 TP-local kernel policies.") + parser.add_argument("--mode", choices=["decode"], default="decode") + parser.add_argument("--scope", choices=["core", "fused", "full", "all"], default="fused") + parser.add_argument("--tp-sizes", nargs="+", type=int, choices=[1, 2, 4, 8], default=[1, 2, 4, 8]) + parser.add_argument("--tokens", nargs="+", type=int, default=[1, 2, 4, 8, 16, 32, 64, 128]) + parser.add_argument("--warmup", type=int, default=20) + parser.add_argument("--rep", type=int, default=100) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--policy-grid", type=pathlib.Path, default=None) + parser.add_argument("--write-example-grid", type=pathlib.Path, default=None) + parser.add_argument("--output-json", type=pathlib.Path, default=pathlib.Path("tmp/qwen35_tp_policy_tune.json")) + parser.add_argument("--csv", type=pathlib.Path, default=None) + parser.add_argument("--fail-on-unsupported", action="store_true") + return parser.parse_args() + + +def main() -> int: + args = parse_args() + if args.write_example_grid is not None: + write_example_grid(args.write_example_grid) + print(f"wrote example policy grid: {args.write_example_grid}") + return 0 + + policies = load_decode_policies(args.policy_grid) + scopes = ["core", "fused", "full"] if args.scope == "all" else [args.scope] + 
decode_bench = decode_benchmarks() + device = decode_bench.accelerator_device() + device_name = decode_bench.accelerator_name(device) + + print(f"Qwen3.5 TP policy tuner: mode={args.mode} device={device_name}") + print(f"tp_sizes={args.tp_sizes} tokens={args.tokens} scopes={scopes}") + print(f"policies={len(policies)} compiled={sum(compiled_policy_key(p) is not None for p in policies)}") + + rows: list[dict[str, Any]] = [] + for policy in policies: + compiled_key = compiled_policy_key(policy) + if compiled_key is None: + print(f"skip unsupported policy={policy.name} {asdict(policy)}") + for scope in scopes: + for tp_size in args.tp_sizes: + for tokens in args.tokens: + row = run_decode_policy( + scope=scope, + tokens=tokens, + tp_size=tp_size, + warmup=args.warmup, + rep=args.rep, + seed=args.seed, + policy=policy, + ) + rows.append(row) + if row["status"] == "ok": + print( + f"{scope:>5} tp={tp_size} hv={row['local_v_heads']:>2} tokens={tokens:>4} " + f"policy={policy.name} ms={row['ms']:.4f} us/tok={row['us_per_token']:.2f}" + ) + + unsupported = [row for row in rows if row["status"] == "unsupported"] + if unsupported and args.fail_on_unsupported: + raise RuntimeError(f"{len(unsupported)} policy/shape rows are unsupported by the compiled extension") + + best_rows = choose_best(rows) + result = { + "device": device_name, + "mode": args.mode, + "warmup": args.warmup, + "rep": args.rep, + "seed": args.seed, + "rows": rows, + "best": best_rows, + } + args.output_json.parent.mkdir(parents=True, exist_ok=True) + with args.output_json.open("w", encoding="utf-8") as f: + json.dump(result, f, indent=2, sort_keys=True) + print(f"wrote {args.output_json}") + + if args.csv is not None: + write_csv(args.csv, rows) + print(f"wrote {args.csv}") + + if best_rows: + print("best policies:") + for row in best_rows: + print( + f" {row['scope']:>5} tp={row['tp_size']} hv={row['local_v_heads']:>2} " + f"{row['token_bucket']}: {row['policy']} {row['ms']:.4f} ms" + ) + return 0 + + +if 
__name__ == "__main__": + raise SystemExit(main()) diff --git a/csrc/api/pybind.cu b/csrc/api/pybind.cu index d14a41c5..3f741945 100644 --- a/csrc/api/pybind.cu +++ b/csrc/api/pybind.cu @@ -17,6 +17,9 @@ #include #include +#include "qwen35/decode/qwen35_decode_common.cuh" +#include "qwen35/prefill/qwen35_prefill_common.cuh" + #if defined(CULA_SM100_ENABLED) || defined(CULA_SM103_ENABLED) void ChunkKDAFwdIntra( @@ -68,6 +71,151 @@ kda_fwd_prefill( bool safe_gate); #endif +void +qwen35_conv1d_decode( + at::Tensor mixed_qkv, + at::Tensor conv_state, + at::Tensor conv_weight, + at::Tensor out) { + cula::qwen35::decode::ConvDecodeParams params{ + mixed_qkv, + conv_state, + conv_weight, + out, + }; + cula::qwen35::decode::run_qwen35_conv1d_decode(params); +} + +void +qwen35_layout_decode( + at::Tensor mixed_qkv_conv, + at::Tensor a, + at::Tensor b, + at::Tensor q_rep, + at::Tensor k_rep, + at::Tensor v, + at::Tensor a_kernel, + at::Tensor b_kernel) { + cula::qwen35::decode::LayoutDecodeParams params{ + mixed_qkv_conv, + a, + b, + q_rep, + k_rep, + v, + a_kernel, + b_kernel, + }; + cula::qwen35::decode::run_qwen35_layout_decode(params); +} + +void +qwen35_scalar_kda_decode( + at::Tensor q_rep, + at::Tensor k_rep, + at::Tensor v, + at::Tensor a_kernel, + at::Tensor b_kernel, + at::Tensor A_log, + at::Tensor dt_bias, + at::Tensor recurrent_state, + at::Tensor pool_idx, + at::Tensor out) { + cula::qwen35::decode::ScalarKdaDecodeParams params{ + q_rep, + k_rep, + v, + a_kernel, + b_kernel, + A_log, + dt_bias, + recurrent_state, + pool_idx, + out, + }; + cula::qwen35::decode::run_qwen35_scalar_kda_decode(params); +} + +void +qwen35_layout_scalar_kda_decode( + at::Tensor mixed_qkv_conv, + at::Tensor a, + at::Tensor b, + at::Tensor A_log, + at::Tensor dt_bias, + at::Tensor recurrent_state, + at::Tensor pool_idx, + at::Tensor out) { + cula::qwen35::decode::LayoutScalarKdaDecodeParams params{ + mixed_qkv_conv, + a, + b, + A_log, + dt_bias, + recurrent_state, + pool_idx, + out, + 
}; + cula::qwen35::decode::run_qwen35_layout_scalar_kda_decode(params); +} + +void +qwen35_scalar_kda_prefill( + at::Tensor q, + at::Tensor k, + at::Tensor v, + at::Tensor a, + at::Tensor b, + at::Tensor A_log, + at::Tensor dt_bias, + at::Tensor initial_state, + at::Tensor cu_seqlens, + at::Tensor out, + at::Tensor final_state) { + cula::qwen35::prefill::ScalarKdaPrefillParams params{ + q, + k, + v, + a, + b, + A_log, + dt_bias, + initial_state, + cu_seqlens, + out, + final_state, + }; + cula::qwen35::prefill::run_qwen35_scalar_kda_prefill(params); +} + +void +qwen35_layout_prefill( + at::Tensor mixed_qkv_conv, + at::Tensor a, + at::Tensor b, + at::Tensor q_rep, + at::Tensor k_rep, + at::Tensor v, + at::Tensor a_kernel, + at::Tensor b_kernel) { + cula::qwen35::prefill::LayoutPrefillParams params{ + mixed_qkv_conv, + a, + b, + q_rep, + k_rep, + v, + a_kernel, + b_kernel, + }; + cula::qwen35::prefill::run_qwen35_layout_prefill(params); +} + +void +qwen35_chunk_qk_prefill_sm90(at::Tensor q, at::Tensor k, at::Tensor out) { + cula::qwen35::prefill::sm90::qwen35_chunk_qk_prefill_sm90(q, k, out); +} + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "cuLA"; #if defined(CULA_SM100_ENABLED) || defined(CULA_SM103_ENABLED) @@ -77,4 +225,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { #if defined(CULA_SM90A_ENABLED) m.def("kda_fwd_prefill", &kda_fwd_prefill); #endif + m.def("qwen35_conv1d_decode", &qwen35_conv1d_decode); + m.def("qwen35_layout_decode", &qwen35_layout_decode); + m.def("qwen35_scalar_kda_decode", &qwen35_scalar_kda_decode); + m.def("qwen35_layout_scalar_kda_decode", &qwen35_layout_scalar_kda_decode); + m.def("qwen35_layout_prefill", &qwen35_layout_prefill); + m.def("qwen35_scalar_kda_prefill", &qwen35_scalar_kda_prefill); + m.def("qwen35_chunk_qk_prefill_sm90", &qwen35_chunk_qk_prefill_sm90); } diff --git a/csrc/qwen35/decode/qwen35_conv1d_decode.cu b/csrc/qwen35/decode/qwen35_conv1d_decode.cu new file mode 100644 index 00000000..5a3c3470 --- /dev/null +++ 
b/csrc/qwen35/decode/qwen35_conv1d_decode.cu @@ -0,0 +1,192 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "qwen35_decode_common.cuh" + +#include +#include +#include +#include +#include +#include + +namespace { + +template +__device__ inline float to_float(T x) { + return static_cast(x); +} + +template <> +__device__ inline float to_float(c10::Half x) { + return __half2float(static_cast<__half>(x)); +} + +template <> +__device__ inline float to_float(c10::BFloat16 x) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + return __bfloat162float(static_cast<__nv_bfloat16>(x)); +#else + return static_cast(x); +#endif +} + +template +__device__ inline T from_float(float x) { + return static_cast(x); +} + +template <> +__device__ inline c10::Half from_float(float x) { + return c10::Half(__float2half_rn(x)); +} + +template <> +__device__ inline c10::BFloat16 from_float(float x) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + return c10::BFloat16(__float2bfloat16(x)); +#else + return c10::BFloat16(x); +#endif +} + +template +__global__ void qwen35_conv1d_decode_kernel( + const scalar_t* __restrict__ mixed_qkv, + scalar_t* __restrict__ conv_state, + const scalar_t* __restrict__ conv_weight, + scalar_t* __restrict__ out, + int batch_size, + int conv_dim) { + constexpr int kThreads = 256; + const int64_t linear_idx = static_cast(blockIdx.x) * kThreads + threadIdx.x; + const int64_t total = 
static_cast(batch_size) * conv_dim; + if (linear_idx >= total) { + return; + } + + const int64_t b = linear_idx / conv_dim; + const int64_t c = linear_idx % conv_dim; + + const int64_t x_idx = b * conv_dim + c; + const int64_t state_base = + (b * conv_dim + c) * cula::qwen35::decode::kConvKernelSize; + const int64_t weight_base = c * cula::qwen35::decode::kConvKernelSize; + + const float s0 = to_float(conv_state[state_base + 1]); + const float s1 = to_float(conv_state[state_base + 2]); + const float s2 = to_float(conv_state[state_base + 3]); + const float s3 = to_float(mixed_qkv[x_idx]); + + const float w0 = to_float(conv_weight[weight_base + 0]); + const float w1 = to_float(conv_weight[weight_base + 1]); + const float w2 = to_float(conv_weight[weight_base + 2]); + const float w3 = to_float(conv_weight[weight_base + 3]); + + const float conv = s0 * w0 + s1 * w1 + s2 * w2 + s3 * w3; + const float silu = conv / (1.f + expf(-conv)); + + conv_state[state_base + 0] = from_float(s0); + conv_state[state_base + 1] = from_float(s1); + conv_state[state_base + 2] = from_float(s2); + conv_state[state_base + 3] = from_float(s3); + out[x_idx] = from_float(silu); +} + +void check_tensor_device(const at::Tensor& tensor, const char* name, const at::Device& device) { + TORCH_CHECK(tensor.device() == device, name, " must be on device ", device, "."); +} + +} // namespace + +namespace cula::qwen35::decode { + +void run_qwen35_conv1d_decode(ConvDecodeParams& params) { + const at::Tensor& mixed_qkv = params.mixed_qkv; + const at::Tensor& conv_state = params.conv_state; + const at::Tensor& conv_weight = params.conv_weight; + const at::Tensor& out = params.out; + + TORCH_CHECK(mixed_qkv.is_cuda(), "mixed_qkv must be a CUDA tensor."); + const at::Device device = mixed_qkv.device(); + + check_tensor_device(conv_state, "conv_state", device); + check_tensor_device(conv_weight, "conv_weight", device); + check_tensor_device(out, "out", device); + + TORCH_CHECK(mixed_qkv.is_contiguous(), 
"mixed_qkv must be contiguous."); + TORCH_CHECK(conv_state.is_contiguous(), "conv_state must be contiguous."); + TORCH_CHECK(conv_weight.is_contiguous(), "conv_weight must be contiguous."); + TORCH_CHECK(out.is_contiguous(), "out must be contiguous."); + + TORCH_CHECK( + mixed_qkv.scalar_type() == conv_state.scalar_type() && + mixed_qkv.scalar_type() == conv_weight.scalar_type() && + mixed_qkv.scalar_type() == out.scalar_type(), + "mixed_qkv/conv_state/conv_weight/out must share the same dtype."); + + TORCH_CHECK( + mixed_qkv.scalar_type() == at::kHalf || mixed_qkv.scalar_type() == at::kBFloat16, + "conv decode only supports half/bfloat16."); + + const int64_t batch_size = mixed_qkv.size(0); + const int64_t conv_dim = mixed_qkv.size(2); + TORCH_CHECK(conv_dim > 0, "conv_dim must be positive."); + TORCH_CHECK( + mixed_qkv.dim() == 3 && mixed_qkv.sizes() == at::IntArrayRef({batch_size, 1, conv_dim}), + "mixed_qkv must have shape [B, 1, local_conv_dim]."); + TORCH_CHECK( + conv_state.dim() == 3 && + conv_state.sizes() == at::IntArrayRef({batch_size, conv_dim, kConvKernelSize}), + "conv_state must have shape [B, local_conv_dim, 4]."); + TORCH_CHECK( + (conv_weight.dim() == 2 && conv_weight.sizes() == at::IntArrayRef({conv_dim, kConvKernelSize})) || + (conv_weight.dim() == 3 && + conv_weight.sizes() == at::IntArrayRef({conv_dim, 1, kConvKernelSize})), + "conv_weight must have shape [local_conv_dim, 4] or [local_conv_dim, 1, 4]."); + TORCH_CHECK( + out.dim() == 3 && out.sizes() == at::IntArrayRef({batch_size, 1, conv_dim}), + "out must have shape [B, 1, local_conv_dim]."); + + const at::cuda::OptionalCUDAGuard device_guard(device); + cudaStream_t stream = at::cuda::getDefaultCUDAStream(device.index()); + + const at::Tensor mixed_qkv_2d = mixed_qkv.view({batch_size, conv_dim}); + const at::Tensor out_2d = out.view({batch_size, conv_dim}); + const at::Tensor weight_2d = + conv_weight.dim() == 3 ? 
conv_weight.view({conv_dim, kConvKernelSize}) : conv_weight; + + constexpr int kThreads = 256; + const int64_t total = batch_size * conv_dim; + const dim3 block(kThreads, 1, 1); + const dim3 grid(static_cast((total + kThreads - 1) / kThreads), 1, 1); + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + mixed_qkv.scalar_type(), + "qwen35_conv1d_decode_kernel", + [&] { + qwen35_conv1d_decode_kernel<<>>( + mixed_qkv_2d.data_ptr(), + conv_state.data_ptr(), + weight_2d.data_ptr(), + out_2d.data_ptr(), + static_cast(batch_size), + static_cast(conv_dim)); + }); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +} // namespace cula::qwen35::decode diff --git a/csrc/qwen35/decode/qwen35_decode_common.cuh b/csrc/qwen35/decode/qwen35_decode_common.cuh new file mode 100644 index 00000000..2f78d22e --- /dev/null +++ b/csrc/qwen35/decode/qwen35_decode_common.cuh @@ -0,0 +1,132 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace cula::qwen35::decode {
+
+// Compile-time head geometry shared by the Qwen3.5 decode kernels: head
+// counts, per-head widths, the conv kernel size, and the flattened Q / K / V
+// widths together with their sum (the width of one mixed QKV row).
+inline constexpr int kNumQKHeads = 16;
+inline constexpr int kNumVHeads = 48;
+inline constexpr int kHeadDimQK = 128;
+inline constexpr int kHeadDimV = 128;
+inline constexpr int kConvKernelSize = 4;
+inline constexpr int kQDim = kNumQKHeads * kHeadDimQK;
+inline constexpr int kKDim = kNumQKHeads * kHeadDimQK;
+inline constexpr int kVDim = kNumVHeads * kHeadDimV;
+inline constexpr int kMixedQKVDim = kQDim + kKDim + kVDim;
+
+// Returns the Q/K head count that pairs with `local_v_heads` V heads, i.e.
+// `local_v_heads` divided by the fixed V:QK head ratio (48 / 16 = 3).
+inline constexpr int local_qk_heads_from_v_heads(int local_v_heads) {
+    return local_v_heads / (kNumVHeads / kNumQKHeads);
+}
+
+// Returns the flattened width of `local_qk_heads` Q (or K) heads.
+inline constexpr int local_q_dim(int local_qk_heads) {
+    return local_qk_heads * kHeadDimQK;
+}
+
+// Returns the flattened width of `local_v_heads` V heads.
+inline constexpr int local_v_dim(int local_v_heads) {
+    return local_v_heads * kHeadDimV;
+}
+
+// Returns the width of one mixed QKV row: Q and K (same width each) plus V.
+inline constexpr int local_mixed_qkv_dim(int local_qk_heads, int local_v_heads) {
+    return 2 * local_q_dim(local_qk_heads) + local_v_dim(local_v_heads);
+}
+
+// Returns true for the local V-head counts the host dispatchers instantiate.
+// NOTE(review): these look like 48 / {1, 2, 4, 8} tensor-parallel shards --
+// confirm against the callers.
+inline constexpr bool is_supported_local_v_heads(int local_v_heads) {
+    return local_v_heads == 48 || local_v_heads == 24 || local_v_heads == 12 || local_v_heads == 6;
+}
+
+// Compile-time shape traits derived from a local V-head count: the matching
+// Q/K head count, the number of V heads sharing each Q/K head, the flattened
+// Q / K / V / mixed widths, and the kernel tiling knobs.
+template
+struct Qwen35DecodeLocalShape {
+    static_assert(is_supported_local_v_heads(kLocalVHeads_), "Unsupported Qwen3.5 local V-head count.");
+    static constexpr int kLocalVHeads = kLocalVHeads_;
+    static constexpr int kLocalQKHeads = local_qk_heads_from_v_heads(kLocalVHeads);
+    static constexpr int kRepeatFactor = kLocalVHeads / kLocalQKHeads;
+    static constexpr int kLocalQDim = local_q_dim(kLocalQKHeads);
+    static constexpr int kLocalKDim = local_q_dim(kLocalQKHeads);
+    static constexpr int kLocalVDim = local_v_dim(kLocalVHeads);
+    static constexpr int kLocalMixedQKVDim = local_mixed_qkv_dim(kLocalQKHeads, kLocalVHeads);
+
+    // Decode shape policy. Head dimension is fixed at 128 for Qwen3.5, but keep
+    // these knobs with the local-head traits so future TP-shape tuning has one
+    // place to specialize.
+    static constexpr int kLayoutVec = 4;  // elements moved per vector copy in the layout kernel
+    static constexpr int kLayoutThreads = kHeadDimQK / kLayoutVec;  // 128 / 4 = 32
+    static constexpr int kKdaThreads = 128;
+    static constexpr int kKdaTileV = 32;
+    static constexpr int kKdaTileK = kHeadDimQK;
+
+    // The kernels rely on these divisibility / equality relations.
+    static_assert(kLocalVHeads % kLocalQKHeads == 0);
+    static_assert(kHeadDimQK == kHeadDimV);
+    static_assert(kHeadDimQK % kLayoutVec == 0);
+    static_assert(kHeadDimV % kKdaTileV == 0);
+    static_assert(kHeadDimQK % kKdaTileK == 0);
+};
+
+// Tensor bundle for run_qwen35_conv1d_decode.
+struct ConvDecodeParams {
+    at::Tensor mixed_qkv; // [B, 1, local_conv_dim]
+    at::Tensor conv_state; // [B, local_conv_dim, 4]
+    at::Tensor conv_weight; // [local_conv_dim, 4]
+    at::Tensor out; // [B, 1, local_conv_dim]
+};
+
+// Tensor bundle for run_qwen35_layout_decode: the mixed QKV rows and the a/b
+// values as inputs, and the per-V-head q/k/v and a/b buffers it fills.
+struct LayoutDecodeParams {
+    at::Tensor mixed_qkv_conv; // [N, local_conv_dim]
+    at::Tensor a; // [N, local_v_heads]
+    at::Tensor b; // [N, local_v_heads]
+    at::Tensor q_rep; // [N, local_v_heads, 128]
+    at::Tensor k_rep; // [N, local_v_heads, 128]
+    at::Tensor v; // [N, local_v_heads, 128]
+    at::Tensor a_kernel; // [N, local_v_heads]
+    at::Tensor b_kernel; // [N, local_v_heads]
+};
+
+// Tensor bundle for run_qwen35_scalar_kda_decode.
+struct ScalarKdaDecodeParams {
+    // Dtype contract for the first implementation:
+    // - activations / outputs: half or bf16
+    // q_rep, k_rep, v, a_kernel, b_kernel, out
+    // - recurrent parameters / state: float32
+    // A_log, dt_bias, recurrent_state
+    at::Tensor q_rep; // [N, local_v_heads, 128]
+    at::Tensor k_rep; // [N, local_v_heads, 128]
+    at::Tensor v; // [N, local_v_heads, 128]
+    at::Tensor a_kernel; // [N, local_v_heads]
+    at::Tensor b_kernel; // [N, local_v_heads]
+    at::Tensor A_log; // [local_v_heads], float32
+    at::Tensor dt_bias; // [local_v_heads], float32
+    at::Tensor recurrent_state; // [pool, local_v_heads, 128, 128], float32
+    at::Tensor pool_idx; // [N], int32
+    at::Tensor out; // [N, local_v_heads, 128]
+};
+
+// Tensor bundle for run_qwen35_layout_scalar_kda_decode, which takes the mixed
+// QKV rows directly instead of separate q_rep / k_rep / v buffers.
+struct LayoutScalarKdaDecodeParams {
+    at::Tensor mixed_qkv_conv; // [N, local_conv_dim]
+    at::Tensor a; // [N, local_v_heads]
+    at::Tensor b; // [N, local_v_heads]
+    at::Tensor A_log; // [local_v_heads], float32
+    at::Tensor dt_bias; // [local_v_heads], float32
+    at::Tensor recurrent_state; // [pool, local_v_heads, 128, 128], float32
+    at::Tensor pool_idx; // [N], int32
+    at::Tensor out; // [N, local_v_heads, 128]
+};
+
+// Host entry points for the decode ops; each takes its parameter bundle by
+// reference.
+void run_qwen35_conv1d_decode(ConvDecodeParams& params);
+void run_qwen35_layout_decode(LayoutDecodeParams& params);
+void run_qwen35_scalar_kda_decode(ScalarKdaDecodeParams& params);
+void run_qwen35_layout_scalar_kda_decode(LayoutScalarKdaDecodeParams& params);
+
+} // namespace cula::qwen35::decode
diff --git a/csrc/qwen35/decode/qwen35_layout_decode.cu b/csrc/qwen35/decode/qwen35_layout_decode.cu
new file mode 100644
index 00000000..811e98b9
--- /dev/null
+++ b/csrc/qwen35/decode/qwen35_layout_decode.cu
@@ -0,0 +1,167 @@
+// Copyright 2025-2026 Ant Group Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "qwen35_decode_common.cuh" +#include "qwen35_layout_kernel.hpp" + +#include +#include +#include +#include +#include +#include + +namespace { + +void check_tensor_device(const at::Tensor& tensor, const char* name, const at::Device& device) { + TORCH_CHECK(tensor.device() == device, name, " must be on device ", device, "."); +} + +void check_tensor_shape_2d(const at::Tensor& tensor, const char* name) { + TORCH_CHECK( + tensor.dim() == 2, + name, + " must have rank 2, but got rank ", + tensor.dim(), + "."); +} + +template +void launch_layout_decode_for_heads( + cudaStream_t stream, + const scalar_t* mixed_qkv_conv, + const scalar_t* a, + const scalar_t* b, + scalar_t* q_rep, + scalar_t* k_rep, + scalar_t* v, + scalar_t* a_kernel, + scalar_t* b_kernel, + int64_t batch_size) { + using Shape = cula::qwen35::decode::Qwen35DecodeLocalShape; + dim3 grid(Shape::kLocalVHeads, static_cast(batch_size), 1); + cula::qwen35::decode::qwen35_layout_decode_kernel_cute + <<>>( + mixed_qkv_conv, + a, + b, + q_rep, + k_rep, + v, + a_kernel, + b_kernel, + batch_size); +} + +} // namespace + +namespace cula::qwen35::decode { + +void run_qwen35_layout_decode(LayoutDecodeParams& params) { + const at::Tensor& mixed_qkv_conv = params.mixed_qkv_conv; + const at::Tensor& a = params.a; + const at::Tensor& b = params.b; + const at::Tensor& q_rep = params.q_rep; + const at::Tensor& k_rep = params.k_rep; + const at::Tensor& v = params.v; + const at::Tensor& a_kernel = params.a_kernel; + const at::Tensor& b_kernel = params.b_kernel; + + TORCH_CHECK(mixed_qkv_conv.is_cuda(), "mixed_qkv_conv must be a CUDA tensor."); + TORCH_CHECK(mixed_qkv_conv.is_contiguous(), "mixed_qkv_conv must be contiguous."); + TORCH_CHECK( + mixed_qkv_conv.scalar_type() == a.scalar_type() && + mixed_qkv_conv.scalar_type() == b.scalar_type() && + mixed_qkv_conv.scalar_type() == q_rep.scalar_type() && + mixed_qkv_conv.scalar_type() == k_rep.scalar_type() && + mixed_qkv_conv.scalar_type() == v.scalar_type() && + 
mixed_qkv_conv.scalar_type() == a_kernel.scalar_type() && + mixed_qkv_conv.scalar_type() == b_kernel.scalar_type(), + "All layout decode tensors must share the same dtype."); + + check_tensor_shape_2d(a, "a"); + check_tensor_shape_2d(b, "b"); + + const int64_t batch_size = mixed_qkv_conv.size(0); + const int64_t local_v_heads = a.size(1); + TORCH_CHECK(is_supported_local_v_heads(static_cast(local_v_heads)), "local V heads must be one of {48, 24, 12, 6}, got ", local_v_heads, "."); + const int local_qk_heads = local_qk_heads_from_v_heads(static_cast(local_v_heads)); + const int local_mixed_dim = local_mixed_qkv_dim(local_qk_heads, static_cast(local_v_heads)); + TORCH_CHECK( + mixed_qkv_conv.dim() == 2 && mixed_qkv_conv.size(1) == local_mixed_dim, + "mixed_qkv_conv must have shape [N, local_conv_dim=", local_mixed_dim, "], got ", + mixed_qkv_conv.sizes(), "."); + const at::Device device = mixed_qkv_conv.device(); + + check_tensor_device(a, "a", device); + check_tensor_device(b, "b", device); + check_tensor_device(q_rep, "q_rep", device); + check_tensor_device(k_rep, "k_rep", device); + check_tensor_device(v, "v", device); + check_tensor_device(a_kernel, "a_kernel", device); + check_tensor_device(b_kernel, "b_kernel", device); + + TORCH_CHECK(q_rep.is_contiguous(), "q_rep must be contiguous."); + TORCH_CHECK(k_rep.is_contiguous(), "k_rep must be contiguous."); + TORCH_CHECK(v.is_contiguous(), "v must be contiguous."); + TORCH_CHECK(a_kernel.is_contiguous(), "a_kernel must be contiguous."); + TORCH_CHECK(b_kernel.is_contiguous(), "b_kernel must be contiguous."); + + TORCH_CHECK( + q_rep.dim() == 3 && q_rep.sizes() == at::IntArrayRef({batch_size, local_v_heads, kHeadDimQK}), + "q_rep must have shape [N, local_v_heads, 128]."); + TORCH_CHECK( + k_rep.dim() == 3 && k_rep.sizes() == at::IntArrayRef({batch_size, local_v_heads, kHeadDimQK}), + "k_rep must have shape [N, local_v_heads, 128]."); + TORCH_CHECK( + v.dim() == 3 && v.sizes() == at::IntArrayRef({batch_size, 
local_v_heads, kHeadDimV}), + "v must have shape [N, local_v_heads, 128]."); + TORCH_CHECK( + a_kernel.dim() == 2 && a_kernel.sizes() == at::IntArrayRef({batch_size, local_v_heads}), + "a_kernel must have shape [N, local_v_heads]."); + TORCH_CHECK( + b_kernel.dim() == 2 && b_kernel.sizes() == at::IntArrayRef({batch_size, local_v_heads}), + "b_kernel must have shape [N, local_v_heads]."); + + TORCH_CHECK(a.sizes() == at::IntArrayRef({batch_size, local_v_heads}), "a must have shape [N, local_v_heads]."); + TORCH_CHECK(b.sizes() == at::IntArrayRef({batch_size, local_v_heads}), "b must have shape [N, local_v_heads]."); + + const at::cuda::OptionalCUDAGuard device_guard(device); + cudaStream_t stream = at::cuda::getDefaultCUDAStream(device.index()); + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + mixed_qkv_conv.scalar_type(), + "qwen35_layout_decode_kernel_cute", + [&] { + switch (local_v_heads) { + case 48: + launch_layout_decode_for_heads(stream, mixed_qkv_conv.data_ptr(), a.data_ptr(), b.data_ptr(), q_rep.data_ptr(), k_rep.data_ptr(), v.data_ptr(), a_kernel.data_ptr(), b_kernel.data_ptr(), batch_size); + break; + case 24: + launch_layout_decode_for_heads(stream, mixed_qkv_conv.data_ptr(), a.data_ptr(), b.data_ptr(), q_rep.data_ptr(), k_rep.data_ptr(), v.data_ptr(), a_kernel.data_ptr(), b_kernel.data_ptr(), batch_size); + break; + case 12: + launch_layout_decode_for_heads(stream, mixed_qkv_conv.data_ptr(), a.data_ptr(), b.data_ptr(), q_rep.data_ptr(), k_rep.data_ptr(), v.data_ptr(), a_kernel.data_ptr(), b_kernel.data_ptr(), batch_size); + break; + case 6: + launch_layout_decode_for_heads(stream, mixed_qkv_conv.data_ptr(), a.data_ptr(), b.data_ptr(), q_rep.data_ptr(), k_rep.data_ptr(), v.data_ptr(), a_kernel.data_ptr(), b_kernel.data_ptr(), batch_size); + break; + } + }); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +} // namespace cula::qwen35::decode diff --git a/csrc/qwen35/decode/qwen35_layout_kernel.hpp 
b/csrc/qwen35/decode/qwen35_layout_kernel.hpp
new file mode 100644
index 00000000..e43ffc00
--- /dev/null
+++ b/csrc/qwen35/decode/qwen35_layout_kernel.hpp
@@ -0,0 +1,131 @@
+// Copyright 2025-2026 Ant Group Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "qwen35_decode_common.cuh"
+
+#include
+#include
+#include
+
+namespace cula::qwen35::decode {
+
+using namespace cute;
+
+// Copies kVec consecutive elements from `src` to `dst`.
+//
+// When kVec elements occupy 8 or 16 bytes and both pointers satisfy the
+// alignment of the vector type, the copy is a single vector assignment.
+// Otherwise it falls back to an unrolled element-by-element loop.
+template
+CUTE_DEVICE void copy_vec_contiguous(
+    scalar_t* __restrict__ dst,
+    const scalar_t* __restrict__ src) {
+    constexpr int kBytes = sizeof(scalar_t) * kVec;
+    if constexpr (kBytes == 16 || kBytes == 8) {
+        using VecType = cutlass::AlignedArray;
+        auto dst_addr = reinterpret_cast(dst);
+        auto src_addr = reinterpret_cast(src);
+        // Runtime alignment test: misaligned pointers take the scalar loop below.
+        if ((dst_addr % alignof(VecType) == 0) && (src_addr % alignof(VecType) == 0)) {
+            *reinterpret_cast(dst) = *reinterpret_cast(src);
+            return;
+        }
+    }
+
+#pragma unroll
+    for (int i = 0; i < kVec; ++i) {
+        dst[i] = src[i];
+    }
+}
+
+// Splits one mixed QKV row per token into per-V-head q / k / v rows and copies
+// the matching a / b scalars.
+//
+// Block mapping: blockIdx.x selects the V head (hv) and blockIdx.y the token.
+// Q and K are read from head hv / kRepeatFactor, so each Q/K head is written
+// to kRepeatFactor consecutive V-head slots; V is read from head hv directly.
+// All pointers are indexed as dense row-major buffers.
+template
+__global__ void qwen35_layout_decode_kernel_cute(
+    const scalar_t* __restrict__ mixed_qkv_conv,
+    const scalar_t* __restrict__ a,
+    const scalar_t* __restrict__ b,
+    scalar_t* __restrict__ q_rep,
+    scalar_t* __restrict__ k_rep,
+    scalar_t* __restrict__ v_out,
+    scalar_t* __restrict__ a_kernel,
+    scalar_t* __restrict__ b_kernel,
+    int64_t token_count) {
+    using Shape = Qwen35DecodeLocalShape;
+    static_assert(kLocalQKHeads == Shape::kLocalQKHeads);
+    constexpr int kRepeatFactor = Shape::kRepeatFactor;
+    constexpr int kLocalQDim = Shape::kLocalQDim;
+    constexpr int kLocalKDim = Shape::kLocalKDim;
+    constexpr int kLocalMixedQKVDim = Shape::kLocalMixedQKVDim;
+    // TODO(qwen35-layout-opt):
+    // - Re-evaluate whether Vec=8 is profitable for bf16/fp16 on the target GPUs.
+    // - Push more of the q/k repeat mapping into compile-time CuTe layout transforms.
+    // - Revisit whether a shared-memory staging path is worthwhile after profiling.
+    // - Consider widening the a/b writeback path if it shows up in profiling.
+    constexpr int kVec = Shape::kLayoutVec;
+    static_assert(kHeadDimV % kVec == 0);
+    static_assert(kHeadDimQK == kHeadDimV);
+    static_assert(kHeadDimQK % kVec == 0);
+
+    const int token_idx = static_cast(blockIdx.y);
+    const int hv = static_cast(blockIdx.x);
+    const int tid = static_cast(threadIdx.x);
+
+    // Guard against a grid larger than the problem.
+    if (token_idx >= token_count || hv >= kLocalVHeads) {
+        return;
+    }
+
+    // Q/K head that feeds this V head.
+    const int mapped_h = hv / kRepeatFactor;
+
+    // Row-major (head, dim) layouts used to turn coordinates into offsets.
+    auto qk_src_layout = make_layout(
+        make_shape(Int{}, Int{}),
+        make_stride(Int{}, Int<1>{}));
+    auto v_src_layout = make_layout(
+        make_shape(Int{}, Int{}),
+        make_stride(Int{}, Int<1>{}));
+    auto out_layout = make_layout(
+        make_shape(Int{}, Int{}),
+        make_stride(Int{}, Int<1>{}));
+    auto head_layout = make_layout(make_shape(Int{}), make_stride(Int<1>{}));
+
+    // One mixed row is laid out as [Q | K | V].
+    const scalar_t* token_ptr = mixed_qkv_conv + static_cast(token_idx) * kLocalMixedQKVDim;
+    const scalar_t* q_src_ptr = token_ptr;
+    const scalar_t* k_src_ptr = token_ptr + kLocalQDim;
+    const scalar_t* v_src_ptr = token_ptr + kLocalQDim + kLocalKDim;
+
+    // Start of this token's [local V heads, head dim] slab in each output.
+    scalar_t* q_dst_ptr = q_rep + static_cast(token_idx) * kLocalVHeads * kHeadDimQK;
+    scalar_t* k_dst_ptr = k_rep + static_cast(token_idx) * kLocalVHeads * kHeadDimQK;
+    scalar_t* v_dst_ptr = v_out + static_cast(token_idx) * kLocalVHeads * kHeadDimV;
+
+    // Current version uses a direct GMEM->GMEM vector copy path. This keeps the
+    // kernel simple while already removing the scalar-copy bottleneck from the
+    // first draft. More aggressive staging/copy strategies should be driven by
+    // profiling rather than added pre-emptively.
+    // Each thread handles kVec-wide chunks, striding by the block size.
+    for (int vec_idx = tid; vec_idx < kHeadDimV / kVec; vec_idx += blockDim.x) {
+        const int d = vec_idx * kVec;
+        const int q_src_idx = crd2idx(make_coord(mapped_h, d), qk_src_layout);
+        const int k_src_idx = crd2idx(make_coord(mapped_h, d), qk_src_layout);
+        const int v_src_idx = crd2idx(make_coord(hv, d), v_src_layout);
+        const int dst_idx = crd2idx(make_coord(hv, d), out_layout);
+
+        copy_vec_contiguous(q_dst_ptr + dst_idx, q_src_ptr + q_src_idx);
+        copy_vec_contiguous(k_dst_ptr + dst_idx, k_src_ptr + k_src_idx);
+        copy_vec_contiguous(v_dst_ptr + dst_idx, v_src_ptr + v_src_idx);
+    }
+
+    // A single thread per block forwards the a/b scalars for this (token, head).
+    if (tid == 0) {
+        // TODO(qwen35-layout-opt): If a/b copy becomes measurable, fuse a wider
+        // per-head copy path here instead of scalar head writes.
+        const int head_idx = crd2idx(make_coord(hv), head_layout);
+        const int64_t token_head_offset = static_cast(token_idx) * kLocalVHeads + head_idx;
+        a_kernel[token_head_offset] = a[token_head_offset];
+        b_kernel[token_head_offset] = b[token_head_offset];
+    }
+}
+
+} // namespace cula::qwen35::decode
diff --git a/csrc/qwen35/decode/qwen35_scalar_kda_decode.cu b/csrc/qwen35/decode/qwen35_scalar_kda_decode.cu
new file mode 100644
index 00000000..ab822752
--- /dev/null
+++ b/csrc/qwen35/decode/qwen35_scalar_kda_decode.cu
@@ -0,0 +1,293 @@
+// Copyright 2025-2026 Ant Group Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "qwen35_decode_common.cuh" +#include "qwen35_scalar_kda_kernel.hpp" + +#include +#include +#include +#include + +namespace cula::qwen35::decode { + +namespace { + +void check_tensor_device(const at::Tensor& tensor, const char* name, const at::Device& device) { + TORCH_CHECK(tensor.device() == device, name, " must be on device ", device, "."); +} + +template +void dispatch_scalar_decode_for_heads( + cudaStream_t stream, + const scalar_t* q_rep, + const scalar_t* k_rep, + const scalar_t* v, + const scalar_t* a_kernel, + const scalar_t* b_kernel, + const float* A_log, + const float* dt_bias, + float* recurrent_state, + const int32_t* pool_idx, + scalar_t* out, + int token_count) { + using Shape = Qwen35DecodeLocalShape; + kernel::launch_qwen35_scalar_kda_decode_kernel( + stream, + q_rep, + k_rep, + v, + a_kernel, + b_kernel, + A_log, + dt_bias, + recurrent_state, + pool_idx, + out, + token_count); +} + +template +void dispatch_layout_scalar_decode_for_heads( + cudaStream_t stream, + const scalar_t* mixed_qkv_conv, + const scalar_t* a, + const scalar_t* b, + const float* A_log, + const float* dt_bias, + float* recurrent_state, + const int32_t* pool_idx, + scalar_t* out, + int token_count) { + using Shape = Qwen35DecodeLocalShape; + kernel::launch_qwen35_layout_scalar_kda_decode_kernel( + stream, + mixed_qkv_conv, + a, + b, + A_log, + dt_bias, + recurrent_state, + pool_idx, + out, + token_count); +} + +} // namespace + +void run_qwen35_scalar_kda_decode(ScalarKdaDecodeParams& params) { + const at::Tensor& q_rep = params.q_rep; + const 
at::Tensor& k_rep = params.k_rep; + const at::Tensor& v = params.v; + const at::Tensor& a_kernel = params.a_kernel; + const at::Tensor& b_kernel = params.b_kernel; + const at::Tensor& A_log = params.A_log; + const at::Tensor& dt_bias = params.dt_bias; + const at::Tensor& recurrent_state = params.recurrent_state; + const at::Tensor& pool_idx = params.pool_idx; + const at::Tensor& out = params.out; + + TORCH_CHECK(q_rep.is_cuda(), "q_rep must be a CUDA tensor."); + const at::Device device = q_rep.device(); + + check_tensor_device(k_rep, "k_rep", device); + check_tensor_device(v, "v", device); + check_tensor_device(a_kernel, "a_kernel", device); + check_tensor_device(b_kernel, "b_kernel", device); + check_tensor_device(A_log, "A_log", device); + check_tensor_device(dt_bias, "dt_bias", device); + check_tensor_device(recurrent_state, "recurrent_state", device); + check_tensor_device(pool_idx, "pool_idx", device); + check_tensor_device(out, "out", device); + + TORCH_CHECK(q_rep.is_contiguous(), "q_rep must be contiguous."); + TORCH_CHECK(k_rep.is_contiguous(), "k_rep must be contiguous."); + TORCH_CHECK(v.is_contiguous(), "v must be contiguous."); + TORCH_CHECK(a_kernel.is_contiguous(), "a_kernel must be contiguous."); + TORCH_CHECK(b_kernel.is_contiguous(), "b_kernel must be contiguous."); + TORCH_CHECK(A_log.is_contiguous(), "A_log must be contiguous."); + TORCH_CHECK(dt_bias.is_contiguous(), "dt_bias must be contiguous."); + TORCH_CHECK(recurrent_state.is_contiguous(), "recurrent_state must be contiguous."); + TORCH_CHECK(pool_idx.is_contiguous(), "pool_idx must be contiguous."); + TORCH_CHECK(out.is_contiguous(), "out must be contiguous."); + + TORCH_CHECK( + q_rep.scalar_type() == k_rep.scalar_type() && q_rep.scalar_type() == v.scalar_type() && + q_rep.scalar_type() == a_kernel.scalar_type() && q_rep.scalar_type() == b_kernel.scalar_type() && + q_rep.scalar_type() == out.scalar_type(), + "q_rep/k_rep/v/a_kernel/b_kernel/out must share the same dtype."); + 
TORCH_CHECK(A_log.scalar_type() == at::kFloat, "A_log must be float32."); + TORCH_CHECK(dt_bias.scalar_type() == at::kFloat, "dt_bias must be float32."); + TORCH_CHECK(recurrent_state.scalar_type() == at::kFloat, "recurrent_state must be float32."); + TORCH_CHECK(pool_idx.scalar_type() == at::kInt, "pool_idx must be int32."); + + TORCH_CHECK(q_rep.dim() == 3, "q_rep must have shape [N, local_v_heads, 128]."); + const int64_t token_count = q_rep.size(0); + const int64_t local_v_heads = q_rep.size(1); + TORCH_CHECK(is_supported_local_v_heads(static_cast(local_v_heads)), "local V heads must be one of {48, 24, 12, 6}, got ", local_v_heads, "."); + TORCH_CHECK( + q_rep.sizes() == at::IntArrayRef({token_count, local_v_heads, kHeadDimQK}), + "q_rep must have shape [N, local_v_heads, 128]."); + TORCH_CHECK( + k_rep.dim() == 3 && k_rep.sizes() == at::IntArrayRef({token_count, local_v_heads, kHeadDimQK}), + "k_rep must have shape [N, local_v_heads, 128]."); + TORCH_CHECK( + v.dim() == 3 && v.sizes() == at::IntArrayRef({token_count, local_v_heads, kHeadDimV}), + "v must have shape [N, local_v_heads, 128]."); + TORCH_CHECK( + a_kernel.dim() == 2 && a_kernel.sizes() == at::IntArrayRef({token_count, local_v_heads}), + "a_kernel must have shape [N, local_v_heads]."); + TORCH_CHECK( + b_kernel.dim() == 2 && b_kernel.sizes() == at::IntArrayRef({token_count, local_v_heads}), + "b_kernel must have shape [N, local_v_heads]."); + TORCH_CHECK(A_log.dim() == 1 && A_log.size(0) == local_v_heads, "A_log must have shape [local_v_heads]."); + TORCH_CHECK(dt_bias.dim() == 1 && dt_bias.size(0) == local_v_heads, "dt_bias must have shape [local_v_heads]."); + TORCH_CHECK( + recurrent_state.dim() == 4 && + recurrent_state.size(1) == local_v_heads && + recurrent_state.size(2) == kHeadDimQK && + recurrent_state.size(3) == kHeadDimV, + "recurrent_state must have shape [pool, local_v_heads, 128, 128]."); + TORCH_CHECK(pool_idx.dim() == 1 && pool_idx.size(0) == token_count, "pool_idx must have shape 
[N]."); + TORCH_CHECK( + out.dim() == 3 && out.sizes() == at::IntArrayRef({token_count, local_v_heads, kHeadDimV}), + "out must have shape [N, local_v_heads, 128]."); + + const at::cuda::OptionalCUDAGuard device_guard(device); + cudaStream_t stream = at::cuda::getDefaultCUDAStream(device.index()); + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + q_rep.scalar_type(), + "launch_qwen35_scalar_kda_decode_kernel", + [&] { + switch (local_v_heads) { + case 48: + dispatch_scalar_decode_for_heads(stream, q_rep.data_ptr(), k_rep.data_ptr(), v.data_ptr(), a_kernel.data_ptr(), b_kernel.data_ptr(), A_log.data_ptr(), dt_bias.data_ptr(), recurrent_state.data_ptr(), pool_idx.data_ptr(), out.data_ptr(), static_cast(token_count)); + break; + case 24: + dispatch_scalar_decode_for_heads(stream, q_rep.data_ptr(), k_rep.data_ptr(), v.data_ptr(), a_kernel.data_ptr(), b_kernel.data_ptr(), A_log.data_ptr(), dt_bias.data_ptr(), recurrent_state.data_ptr(), pool_idx.data_ptr(), out.data_ptr(), static_cast(token_count)); + break; + case 12: + dispatch_scalar_decode_for_heads(stream, q_rep.data_ptr(), k_rep.data_ptr(), v.data_ptr(), a_kernel.data_ptr(), b_kernel.data_ptr(), A_log.data_ptr(), dt_bias.data_ptr(), recurrent_state.data_ptr(), pool_idx.data_ptr(), out.data_ptr(), static_cast(token_count)); + break; + case 6: + dispatch_scalar_decode_for_heads(stream, q_rep.data_ptr(), k_rep.data_ptr(), v.data_ptr(), a_kernel.data_ptr(), b_kernel.data_ptr(), A_log.data_ptr(), dt_bias.data_ptr(), recurrent_state.data_ptr(), pool_idx.data_ptr(), out.data_ptr(), static_cast(token_count)); + break; + } + }); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +void run_qwen35_layout_scalar_kda_decode(LayoutScalarKdaDecodeParams& params) { + const at::Tensor& mixed_qkv_conv = params.mixed_qkv_conv; + const at::Tensor& a = params.a; + const at::Tensor& b = params.b; + const at::Tensor& A_log = params.A_log; + const at::Tensor& dt_bias = params.dt_bias; + const at::Tensor& 
recurrent_state = params.recurrent_state; + const at::Tensor& pool_idx = params.pool_idx; + const at::Tensor& out = params.out; + + TORCH_CHECK(mixed_qkv_conv.is_cuda(), "mixed_qkv_conv must be a CUDA tensor."); + const at::Device device = mixed_qkv_conv.device(); + + check_tensor_device(a, "a", device); + check_tensor_device(b, "b", device); + check_tensor_device(A_log, "A_log", device); + check_tensor_device(dt_bias, "dt_bias", device); + check_tensor_device(recurrent_state, "recurrent_state", device); + check_tensor_device(pool_idx, "pool_idx", device); + check_tensor_device(out, "out", device); + + TORCH_CHECK(mixed_qkv_conv.is_contiguous(), "mixed_qkv_conv must be contiguous."); + TORCH_CHECK(a.is_contiguous(), "a must be contiguous."); + TORCH_CHECK(b.is_contiguous(), "b must be contiguous."); + TORCH_CHECK(A_log.is_contiguous(), "A_log must be contiguous."); + TORCH_CHECK(dt_bias.is_contiguous(), "dt_bias must be contiguous."); + TORCH_CHECK(recurrent_state.is_contiguous(), "recurrent_state must be contiguous."); + TORCH_CHECK(pool_idx.is_contiguous(), "pool_idx must be contiguous."); + TORCH_CHECK(out.is_contiguous(), "out must be contiguous."); + + TORCH_CHECK( + mixed_qkv_conv.scalar_type() == a.scalar_type() && + mixed_qkv_conv.scalar_type() == b.scalar_type() && + mixed_qkv_conv.scalar_type() == out.scalar_type(), + "mixed_qkv_conv/a/b/out must share the same dtype."); + TORCH_CHECK( + mixed_qkv_conv.scalar_type() == at::kHalf || mixed_qkv_conv.scalar_type() == at::kBFloat16, + "mixed_qkv_conv must be float16 or bfloat16."); + TORCH_CHECK(A_log.scalar_type() == at::kFloat, "A_log must be float32."); + TORCH_CHECK(dt_bias.scalar_type() == at::kFloat, "dt_bias must be float32."); + TORCH_CHECK(recurrent_state.scalar_type() == at::kFloat, "recurrent_state must be float32."); + TORCH_CHECK(pool_idx.scalar_type() == at::kInt, "pool_idx must be int32."); + + TORCH_CHECK(mixed_qkv_conv.dim() == 2, "mixed_qkv_conv must have shape [N, local_conv_dim]."); + 
TORCH_CHECK(a.dim() == 2, "a must have shape [N, local_v_heads]."); + const int64_t token_count = mixed_qkv_conv.size(0); + const int64_t local_v_heads = a.size(1); + TORCH_CHECK(is_supported_local_v_heads(static_cast(local_v_heads)), "local V heads must be one of {48, 24, 12, 6}, got ", local_v_heads, "."); + const int local_qk_heads = local_qk_heads_from_v_heads(static_cast(local_v_heads)); + const int local_mixed_dim = local_mixed_qkv_dim(local_qk_heads, static_cast(local_v_heads)); + TORCH_CHECK( + mixed_qkv_conv.sizes() == at::IntArrayRef({token_count, local_mixed_dim}), + "mixed_qkv_conv must have shape [N, local_conv_dim]."); + TORCH_CHECK( + a.sizes() == at::IntArrayRef({token_count, local_v_heads}), + "a must have shape [N, local_v_heads]."); + TORCH_CHECK( + b.dim() == 2 && b.sizes() == at::IntArrayRef({token_count, local_v_heads}), + "b must have shape [N, local_v_heads]."); + TORCH_CHECK(A_log.dim() == 1 && A_log.size(0) == local_v_heads, "A_log must have shape [local_v_heads]."); + TORCH_CHECK(dt_bias.dim() == 1 && dt_bias.size(0) == local_v_heads, "dt_bias must have shape [local_v_heads]."); + TORCH_CHECK( + recurrent_state.dim() == 4 && + recurrent_state.size(1) == local_v_heads && + recurrent_state.size(2) == kHeadDimQK && + recurrent_state.size(3) == kHeadDimV, + "recurrent_state must have shape [pool, local_v_heads, 128, 128]."); + TORCH_CHECK(pool_idx.dim() == 1 && pool_idx.size(0) == token_count, "pool_idx must have shape [N]."); + TORCH_CHECK( + out.dim() == 3 && out.sizes() == at::IntArrayRef({token_count, local_v_heads, kHeadDimV}), + "out must have shape [N, local_v_heads, 128]."); + + const at::cuda::OptionalCUDAGuard device_guard(device); + cudaStream_t stream = at::cuda::getDefaultCUDAStream(device.index()); + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + mixed_qkv_conv.scalar_type(), + "launch_qwen35_layout_scalar_kda_decode_kernel", + [&] { + switch (local_v_heads) { + case 48: + 
dispatch_layout_scalar_decode_for_heads(stream, mixed_qkv_conv.data_ptr(), a.data_ptr(), b.data_ptr(), A_log.data_ptr(), dt_bias.data_ptr(), recurrent_state.data_ptr(), pool_idx.data_ptr(), out.data_ptr(), static_cast(token_count)); + break; + case 24: + dispatch_layout_scalar_decode_for_heads(stream, mixed_qkv_conv.data_ptr(), a.data_ptr(), b.data_ptr(), A_log.data_ptr(), dt_bias.data_ptr(), recurrent_state.data_ptr(), pool_idx.data_ptr(), out.data_ptr(), static_cast(token_count)); + break; + case 12: + dispatch_layout_scalar_decode_for_heads(stream, mixed_qkv_conv.data_ptr(), a.data_ptr(), b.data_ptr(), A_log.data_ptr(), dt_bias.data_ptr(), recurrent_state.data_ptr(), pool_idx.data_ptr(), out.data_ptr(), static_cast(token_count)); + break; + case 6: + dispatch_layout_scalar_decode_for_heads(stream, mixed_qkv_conv.data_ptr(), a.data_ptr(), b.data_ptr(), A_log.data_ptr(), dt_bias.data_ptr(), recurrent_state.data_ptr(), pool_idx.data_ptr(), out.data_ptr(), static_cast(token_count)); + break; + } + }); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +} // namespace cula::qwen35::decode diff --git a/csrc/qwen35/decode/qwen35_scalar_kda_kernel.hpp b/csrc/qwen35/decode/qwen35_scalar_kda_kernel.hpp new file mode 100644 index 00000000..ed67ccc6 --- /dev/null +++ b/csrc/qwen35/decode/qwen35_scalar_kda_kernel.hpp @@ -0,0 +1,402 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "qwen35_decode_common.cuh" +#include "qwen35_scalar_kda_mainloop.hpp" + +#include + +namespace cula::qwen35::decode::kernel { + +using namespace cute; + +template +struct Qwen35ScalarKdaDecodeKernel { + using Shape = cula::qwen35::decode::Qwen35DecodeLocalShape; + static_assert(kLocalQKHeads == Shape::kLocalQKHeads); + // Decode-first design: + // - 1 CTA owns 1 (token_idx, hv) + // - 1 warpgroup (128 threads) per CTA + // - recurrent state stays fp32 and is traversed as 16x16 tiles over the + // internal [V, K] view + // - the intended optimized path is fp32 FFMA on CUDA cores, not a forced + // Tensor Core lowering + static constexpr int kThreads = Shape::kKdaThreads; + static constexpr int kWarpGroupThreads = Shape::kKdaThreads; + static constexpr int kTileV = Shape::kKdaTileV; + static constexpr int kTileK = Shape::kKdaTileK; + static constexpr int kTilesPerV = kHeadDimV / kTileV; + static constexpr int kTilesPerK = kHeadDimQK / kTileK; + + static_assert(kLocalQKHeads < kLocalVHeads); + static_assert(kHeadDimQK == 128); + static_assert(kHeadDimV == 128); + static_assert(kThreads == kWarpGroupThreads); + static_assert(kHeadDimV % kTileV == 0); + static_assert(kHeadDimQK % kTileK == 0); + + struct SharedStorage { + // Shared staging plan for the fp32 decode path: + // - q/k/v are staged once per CTA + // - proj/out intermediates remain in fp32 + // - recurrent state itself remains in fp32 global storage + alignas(16) float q_smem[kHeadDimQK]; + alignas(16) float k_smem[kHeadDimQK]; + alignas(16) scalar_t v_smem[kHeadDimV]; + alignas(16) float norm_smem[2]; + alignas(16) float state_smem[kHeadDimQK * kTileV]; + alignas(16) float proj_smem[kHeadDimV]; + alignas(16) float out_smem[kHeadDimV]; + }; + + static dim3 block_shape() { + return dim3(kThreads, 1, 1); + } + + static dim3 grid_shape(int token_count) { + // One block owns one V tile for one (token_idx, hv) pair. 
+ constexpr int kNumVTiles = (kHeadDimV + kTileV - 1) / kTileV; + return dim3( + static_cast(kNumVTiles), + static_cast(Shape::kLocalVHeads), + static_cast(token_count)); + } + + template + CUTE_DEVICE static void run_device( + const scalar_t* __restrict__ q_rep, + const scalar_t* __restrict__ k_rep, + const scalar_t* __restrict__ v, + const scalar_t* __restrict__ a_kernel, + const scalar_t* __restrict__ b_kernel, + const float* __restrict__ A_log, + const float* __restrict__ dt_bias, + float* __restrict__ recurrent_state, + const int32_t* __restrict__ pool_idx, + scalar_t* __restrict__ out, + int token_count, + SharedStorage& storage) { + const int v_tile = static_cast(blockIdx.x); + const int hv = static_cast(blockIdx.y); + const int token_idx = static_cast(blockIdx.z); + const int tid = static_cast(threadIdx.x); + if (token_idx >= token_count || hv >= kLocalVHeads) { + return; + } + + // Internal tensor-view contract fixed for the first implementation pass: + // + // 1. q_rep / k_rep / v / out stay in their external contiguous layouts: + // - q_rep : [N, HV, K] with stride (HV*K, K, 1) + // - k_rep : [N, HV, K] with stride (HV*K, K, 1) + // - v : [N, HV, V] with stride (HV*V, V, 1) + // - out : [N, HV, V] with stride (HV*V, V, 1) + // + // 2. a_kernel / b_kernel are treated as: + // - [N, HV] with stride (HV, 1) + // + // 3. A_log / dt_bias are treated as: + // - [HV] with stride (1) + // + // 4. recurrent_state keeps the external physical storage contract: + // - [pool, HV, K, V] + // but the kernel's main computation will use an internal VK view: + // - [pool, HV, V, K] + // + // This lets the recurrent update consume one V-row of state against q/k + // more naturally in the first mainloop design, while preserving the + // existing external state ABI. + // + // The current block owns exactly one (token_idx, hv) pair. That means one + // warpgroup-sized CTA updates one 128x128 recurrent-state tile for one + // v-head. 
+ // + // TODO(qwen35-scalar-kda-opt): + // - Likely next optimization path: keep one CTA per (token_idx, hv), but + // tile the 128x128 state more aggressively inside the block (for example + // along V tiles or KxV subtiles assigned per warp). + // - More complex alternative: split one (token_idx, hv) tile across + // multiple CTAs and coordinate updates. Not a first-pass target. + // - After the fp32 decode path is stable, evaluate warp specialization: + // dedicated producer/load warp(s) vs consumer/compute warp(s), instead + // of introducing that complexity before the math path itself is stable. + + auto q_layout = make_layout( + make_shape(token_count, Int{}, Int{}), + make_stride(kLocalVHeads * kHeadDimQK, kHeadDimQK, Int<1>{})); + auto v_layout = make_layout( + make_shape(token_count, Int{}, Int{}), + make_stride(kLocalVHeads * kHeadDimV, kHeadDimV, Int<1>{})); + auto head_layout = make_layout( + make_shape(token_count, Int{}), + make_stride(kLocalVHeads, Int<1>{})); + auto hv_layout = make_layout(make_shape(Int{}), make_stride(Int<1>{})); + auto state_layout_kv = make_layout( + make_shape(_, Int{}, Int{}, Int{}), + make_stride(Int{} * kHeadDimQK * kHeadDimV, kHeadDimQK * kHeadDimV, kHeadDimV, Int<1>{})); + auto state_layout_vk = make_layout( + make_shape(_, Int{}, Int{}, Int{}), + make_stride(Int{} * kHeadDimQK * kHeadDimV, kHeadDimQK * kHeadDimV, Int<1>{}, kHeadDimV)); + + auto gQ = make_tensor(make_gmem_ptr(q_rep), q_layout); + auto gK = make_tensor(make_gmem_ptr(k_rep), q_layout); + auto gV = make_tensor(make_gmem_ptr(v), v_layout); + auto gO = make_tensor(make_gmem_ptr(out), v_layout); + auto gA = make_tensor(make_gmem_ptr(a_kernel), head_layout); + auto gB = make_tensor(make_gmem_ptr(b_kernel), head_layout); + auto gAlog = make_tensor(make_gmem_ptr(A_log), hv_layout); + auto gDt = make_tensor(make_gmem_ptr(dt_bias), hv_layout); + auto gH_kv = make_tensor(make_gmem_ptr(recurrent_state), state_layout_kv); + auto gH_vk = 
make_tensor(make_gmem_ptr(recurrent_state), state_layout_vk); + (void)gH_kv; // Keep the physical KV view documented and available. + + const int state_row = pool_idx[token_idx]; + if (state_row < 0) { + return; + } + + auto q_vec = gQ(token_idx, hv, _); + auto k_vec = gK(token_idx, hv, _); + auto v_vec = gV(token_idx, hv, _); + auto out_vec = gO(token_idx, hv, _); + auto a_scalar = gA(token_idx, hv); + auto b_scalar = gB(token_idx, hv); + auto A_log_scalar = gAlog(hv); + auto dt_bias_scalar = gDt(hv); + auto state_vk = gH_vk(state_row, hv, _, _); + + Mainloop::run( + q_vec, + k_vec, + v_vec, + a_scalar, + b_scalar, + A_log_scalar, + dt_bias_scalar, + state_vk, + out_vec, + storage, + v_tile * kTileV, + tid, + kThreads); + } + + template + CUTE_DEVICE static void run_layout_device( + const scalar_t* __restrict__ mixed_qkv_conv, + const scalar_t* __restrict__ a, + const scalar_t* __restrict__ b, + const float* __restrict__ A_log, + const float* __restrict__ dt_bias, + float* __restrict__ recurrent_state, + const int32_t* __restrict__ pool_idx, + scalar_t* __restrict__ out, + int token_count, + SharedStorage& storage) { + constexpr int kRepeatFactor = Shape::kRepeatFactor; + constexpr int kLocalQDim = Shape::kLocalQDim; + constexpr int kLocalKDim = Shape::kLocalKDim; + constexpr int kLocalMixedQKVDim = Shape::kLocalMixedQKVDim; + + const int v_tile = static_cast(blockIdx.x); + const int hv = static_cast(blockIdx.y); + const int token_idx = static_cast(blockIdx.z); + const int tid = static_cast(threadIdx.x); + if (token_idx >= token_count || hv >= kLocalVHeads) { + return; + } + + const int mapped_h = hv / kRepeatFactor; + + auto qk_src_layout = make_layout( + make_shape(token_count, Int{}, Int{}), + make_stride(kLocalMixedQKVDim, kHeadDimQK, Int<1>{})); + auto v_src_layout = make_layout( + make_shape(token_count, Int{}, Int{}), + make_stride(kLocalMixedQKVDim, kHeadDimV, Int<1>{})); + auto out_layout = make_layout( + make_shape(token_count, Int{}, Int{}), + 
make_stride(kLocalVHeads * kHeadDimV, kHeadDimV, Int<1>{})); + auto head_layout = make_layout( + make_shape(token_count, Int{}), + make_stride(kLocalVHeads, Int<1>{})); + auto hv_layout = make_layout(make_shape(Int{}), make_stride(Int<1>{})); + auto state_layout_kv = make_layout( + make_shape(_, Int{}, Int{}, Int{}), + make_stride(Int{} * kHeadDimQK * kHeadDimV, kHeadDimQK * kHeadDimV, kHeadDimV, Int<1>{})); + auto state_layout_vk = make_layout( + make_shape(_, Int{}, Int{}, Int{}), + make_stride(Int{} * kHeadDimQK * kHeadDimV, kHeadDimQK * kHeadDimV, Int<1>{}, kHeadDimV)); + + const scalar_t* q_src = mixed_qkv_conv; + const scalar_t* k_src = mixed_qkv_conv + kLocalQDim; + const scalar_t* v_src = mixed_qkv_conv + kLocalQDim + kLocalKDim; + + auto gQ = make_tensor(make_gmem_ptr(q_src), qk_src_layout); + auto gK = make_tensor(make_gmem_ptr(k_src), qk_src_layout); + auto gV = make_tensor(make_gmem_ptr(v_src), v_src_layout); + auto gO = make_tensor(make_gmem_ptr(out), out_layout); + auto gA = make_tensor(make_gmem_ptr(a), head_layout); + auto gB = make_tensor(make_gmem_ptr(b), head_layout); + auto gAlog = make_tensor(make_gmem_ptr(A_log), hv_layout); + auto gDt = make_tensor(make_gmem_ptr(dt_bias), hv_layout); + auto gH_kv = make_tensor(make_gmem_ptr(recurrent_state), state_layout_kv); + auto gH_vk = make_tensor(make_gmem_ptr(recurrent_state), state_layout_vk); + (void)gH_kv; + + const int state_row = pool_idx[token_idx]; + if (state_row < 0) { + return; + } + + auto q_vec = gQ(token_idx, mapped_h, _); + auto k_vec = gK(token_idx, mapped_h, _); + auto v_vec = gV(token_idx, hv, _); + auto out_vec = gO(token_idx, hv, _); + auto a_scalar = gA(token_idx, hv); + auto b_scalar = gB(token_idx, hv); + auto A_log_scalar = gAlog(hv); + auto dt_bias_scalar = gDt(hv); + auto state_vk = gH_vk(state_row, hv, _, _); + + Mainloop::run( + q_vec, + k_vec, + v_vec, + a_scalar, + b_scalar, + A_log_scalar, + dt_bias_scalar, + state_vk, + out_vec, + storage, + v_tile * kTileV, + tid, + 
kThreads); + } +}; + +template > +__global__ void qwen35_scalar_kda_decode_kernel( + const scalar_t* __restrict__ q_rep, + const scalar_t* __restrict__ k_rep, + const scalar_t* __restrict__ v, + const scalar_t* __restrict__ a_kernel, + const scalar_t* __restrict__ b_kernel, + const float* __restrict__ A_log, + const float* __restrict__ dt_bias, + float* __restrict__ recurrent_state, + const int32_t* __restrict__ pool_idx, + scalar_t* __restrict__ out, + int token_count) { + __shared__ typename Qwen35ScalarKdaDecodeKernel::SharedStorage storage; + Qwen35ScalarKdaDecodeKernel::template run_device( + q_rep, + k_rep, + v, + a_kernel, + b_kernel, + A_log, + dt_bias, + recurrent_state, + pool_idx, + out, + token_count, + storage); +} + +template > +void launch_qwen35_scalar_kda_decode_kernel( + cudaStream_t stream, + const scalar_t* q_rep, + const scalar_t* k_rep, + const scalar_t* v, + const scalar_t* a_kernel, + const scalar_t* b_kernel, + const float* A_log, + const float* dt_bias, + float* recurrent_state, + const int32_t* pool_idx, + scalar_t* out, + int token_count) { + auto grid = Qwen35ScalarKdaDecodeKernel::grid_shape(token_count); + auto block = Qwen35ScalarKdaDecodeKernel::block_shape(); + qwen35_scalar_kda_decode_kernel<<>>( + q_rep, + k_rep, + v, + a_kernel, + b_kernel, + A_log, + dt_bias, + recurrent_state, + pool_idx, + out, + token_count); +} + +template > +__global__ void qwen35_layout_scalar_kda_decode_kernel( + const scalar_t* __restrict__ mixed_qkv_conv, + const scalar_t* __restrict__ a, + const scalar_t* __restrict__ b, + const float* __restrict__ A_log, + const float* __restrict__ dt_bias, + float* __restrict__ recurrent_state, + const int32_t* __restrict__ pool_idx, + scalar_t* __restrict__ out, + int token_count) { + __shared__ typename Qwen35ScalarKdaDecodeKernel::SharedStorage storage; + Qwen35ScalarKdaDecodeKernel::template run_layout_device( + mixed_qkv_conv, + a, + b, + A_log, + dt_bias, + recurrent_state, + pool_idx, + out, + token_count, + 
storage); +} + +template > +void launch_qwen35_layout_scalar_kda_decode_kernel( + cudaStream_t stream, + const scalar_t* mixed_qkv_conv, + const scalar_t* a, + const scalar_t* b, + const float* A_log, + const float* dt_bias, + float* recurrent_state, + const int32_t* pool_idx, + scalar_t* out, + int token_count) { + auto grid = Qwen35ScalarKdaDecodeKernel::grid_shape(token_count); + auto block = Qwen35ScalarKdaDecodeKernel::block_shape(); + qwen35_layout_scalar_kda_decode_kernel<<>>( + mixed_qkv_conv, + a, + b, + A_log, + dt_bias, + recurrent_state, + pool_idx, + out, + token_count); +} + +} // namespace cula::qwen35::decode::kernel diff --git a/csrc/qwen35/decode/qwen35_scalar_kda_mainloop.hpp b/csrc/qwen35/decode/qwen35_scalar_kda_mainloop.hpp new file mode 100644 index 00000000..206cbf91 --- /dev/null +++ b/csrc/qwen35/decode/qwen35_scalar_kda_mainloop.hpp @@ -0,0 +1,504 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "qwen35_decode_common.cuh" + +#include +#include +#include +#include + +namespace cula::qwen35::decode::kernel { + +using namespace cute; + +template +struct Qwen35ScalarKdaDecodeMainloop { + // Decode design decision: + // - recurrent_state remains fp32 both physically and mathematically + // - decode is treated as a register-level recurrent GEMV/rank-1-update + // problem, not as a Tensor Core GEMM problem + // - you can think of the target implementation style as a + // flash_linear_decode_kernel: pure CUDA Core math, warp-shuffle vector + // sharing, and fp32 state kept live as long as possible during one token + // update + // + // Reason: + // - state participates in a long recurrent chain; lowering master-state + // precision is risky and usually not worth it + // - decode operates on a single-token q/k vector, so the dominant kernels + // are GEMV-like: + // proj = state @ k + // out = state' @ q + // This is typically register / memory bound rather than Tensor-Core bound + // + // Practical consequence: + // - the first production-worthy decode path should be built around fp32 FFMA + // on CUDA cores + // - tile structure is still useful, but it should serve register ownership, + // reduction, and cache behavior instead of forcing a GMMA lowering + // + // The remaining work for this kernel is therefore: + // 1. tighten thread ownership of the fp32 state tile + // 2. optimize proj / update / out reductions + // 3. 
evaluate warp-specialized load/compute roles only after the fp32 path + // is stable and measured + static constexpr int kTileV = 32; + static constexpr int kTileK = kHeadDimQK; + static constexpr int kTilesPerV = kHeadDimV / kTileV; + static constexpr int kTilesPerK = kHeadDimQK / kTileK; + static constexpr int kWarpSize = 32; + static constexpr int kRowsPerTile = kTileV; + static constexpr int kWarpsPerCta = 4; + static constexpr int kRowsPerWarp = kWarpSize; + static constexpr int kRowsPerThread = 1; + static constexpr int kThreadsPerVRow = 4; + static constexpr int kKPerThread = kHeadDimQK / kThreadsPerVRow; + + static_assert(kHeadDimV == 128); + static_assert(kHeadDimQK == 128); + static_assert(kWarpsPerCta * kRowsPerWarp == kHeadDimV); + static_assert(kTileV * kThreadsPerVRow == kWarpsPerCta * kWarpSize); + static_assert(kHeadDimQK % kThreadsPerVRow == 0); + + // First concrete decode threading plan: + // + // - 1 CTA = 1 (token, hv) + // - 128 threads = 4 warps + // - 1 thread owns exactly 1 V-row of the 128x128 recurrent state + // - Therefore one CTA covers all 128 V-rows exactly once + // + // For the owned row, the thread streams over K in 16-wide tiles: + // + // state_row[0:15] -> registers + // state_row[16:31] -> registers + // ... + // state_row[112:127]-> registers + // + // This means the first concrete fp32 path does NOT attempt to keep the + // whole 128-float row resident in registers at once. Instead it keeps the + // current K tile resident: + // + // - state_regs[16] : current fp32 state tile + // - k_regs[16] : current key tile + // - q_regs[16] : current query tile + // + // plus a handful of scalar accumulators: + // + // - proj_row + // - out_row + // - v_new_row + // - gate scalars + // + // This is a practical first step toward the user's desired "state stays in + // registers for the current token" behavior while keeping register pressure + // manageable. 
+ // + // Reduction policy for this first concrete plan: + // + // - proj/out are row-local, so no warp reduction is required + // - each row is fully owned by one thread across all K tiles + // - warp shuffle is reserved for future vector-broadcast refinements if we + // decide to move q/k staging from shared memory into warp-register paths + struct ThreadRowPlan { + int warp_id; + int lane_id; + int v_row; + bool owns_row; + }; + + struct TileCoords { + int v_base; + int k_base; + }; + + CUTE_DEVICE static ThreadRowPlan make_thread_row_plan(int tid) { + const int warp_id = tid / kWarpSize; + const int lane_id = tid % kWarpSize; + const int v_row = warp_id * kRowsPerWarp + lane_id; + const bool owns_row = v_row < kHeadDimV; + return ThreadRowPlan{warp_id, lane_id, v_row, owns_row}; + } + + template + CUTE_DEVICE static void load_vec_tile_to_regs( + TensorVec const& vec, + TileCoords coords, + float (®s)[kTileK]) { +#pragma unroll + for (int kk = 0; kk < kTileK; ++kk) { + regs[kk] = static_cast(vec(coords.k_base + kk)); + } + } + + template + CUTE_DEVICE static void load_state_row_tile_to_regs( + TensorState const& state_vk, + int v_row, + TileCoords coords, + float (&state_regs)[kTileK]) { +#pragma unroll + for (int kk = 0; kk < kTileK; ++kk) { + state_regs[kk] = static_cast(state_vk(v_row, coords.k_base + kk)); + } + } + + template + CUTE_DEVICE static void store_state_row_tile_from_regs( + TensorState& state_vk, + int v_row, + TileCoords coords, + float const (&state_regs)[kTileK]) { +#pragma unroll + for (int kk = 0; kk < kTileK; ++kk) { + state_vk(v_row, coords.k_base + kk) = state_regs[kk]; + } + } + + struct RowTileProjPlan { + int v_base; + int k_base; + int warp_id; + int lane_id; + bool owns_row; + int row_in_tile; + int v_row; + }; + + CUTE_DEVICE static TileCoords make_tile_coords(int tile_v, int tile_k) { + return TileCoords{tile_v * kTileV, tile_k * kTileK}; + } + + CUTE_DEVICE static RowTileProjPlan make_row_tile_proj_plan( + TileCoords coords, + int 
warp_id, + int lane_id) { + const bool owns_row = lane_id < kTileV; + const int row_in_tile = lane_id; + const int v_row = coords.v_base + row_in_tile; + return RowTileProjPlan{ + coords.v_base, + coords.k_base, + warp_id, + lane_id, + owns_row, + row_in_tile, + v_row, + }; + } + + struct RowTileUpdatePlan { + TileCoords coords; + int warp_id; + int lane_id; + bool owns_row; + int row_in_tile; + int v_row; + }; + + CUTE_DEVICE static RowTileUpdatePlan make_row_tile_update_plan( + TileCoords coords, + int warp_id, + int lane_id) { + const bool owns_row = lane_id < kTileV; + const int row_in_tile = lane_id; + const int v_row = coords.v_base + row_in_tile; + return RowTileUpdatePlan{ + coords, + warp_id, + lane_id, + owns_row, + row_in_tile, + v_row, + }; + } + + template + CUTE_DEVICE static float accumulate_proj_row_tile( + TensorState const& state_vk, + TensorKTile const& k_smem, + int v_row, + TileCoords coords) { + float state_regs[kTileK]; + float k_regs[kTileK]; + load_state_row_tile_to_regs(state_vk, v_row, coords, state_regs); + load_vec_tile_to_regs(k_smem, coords, k_regs); + + float accum = 0.f; +#pragma unroll + for (int kk = 0; kk < kTileK; ++kk) { + accum += state_regs[kk] * k_regs[kk]; + } + return accum; + } + + template + CUTE_DEVICE static float update_state_row_tile_and_accumulate_out( + TensorState& state_vk, + TensorKTile const& k_smem, + TensorQTile const& q_smem, + int v_row, + TileCoords coords, + float decay, + float v_new) { + float state_regs[kTileK]; + float k_regs[kTileK]; + float q_regs[kTileK]; + load_state_row_tile_to_regs(state_vk, v_row, coords, state_regs); + load_vec_tile_to_regs(k_smem, coords, k_regs); + load_vec_tile_to_regs(q_smem, coords, q_regs); + + float out_acc = 0.f; +#pragma unroll + for (int kk = 0; kk < kTileK; ++kk) { + const float state_new = decay * state_regs[kk] + v_new * k_regs[kk]; + state_regs[kk] = state_new; + out_acc += state_new * q_regs[kk]; + } + store_state_row_tile_from_regs(state_vk, v_row, coords, 
state_regs); + return out_acc; + } + + template + CUTE_DEVICE static float project_row_tile( + TensorState const& state_vk, + TensorKTile const& k_smem, + int v_row, + TileCoords coords) { + // Current decode path: + // - one thread owns one full V-row + // - this helper computes the row-local proj contribution for one K tile + // - no cross-thread reduction is needed + return accumulate_proj_row_tile(state_vk, k_smem, v_row, coords); + } + + template + CUTE_DEVICE static float project_row_tile( + TensorState const& state_vk, + TensorKTile const& k_smem, + RowTileProjPlan const& plan) { + if (!plan.owns_row) { + return 0.f; + } + return project_row_tile( + state_vk, k_smem, plan.v_row, TileCoords{plan.v_base, plan.k_base}); + } + + template + CUTE_DEVICE static float update_and_output_row_tile( + TensorState& state_vk, + TensorKTile const& k_smem, + TensorQTile const& q_smem, + int v_row, + TileCoords coords, + float decay, + float v_new) { + // Current decode path: + // - read one 16-wide state tile for the owned row into registers + // - apply decay and rank-1 update in fp32 + // - accumulate the matching out contribution against q + // - write the updated state tile back + return update_state_row_tile_and_accumulate_out( + state_vk, k_smem, q_smem, v_row, coords, decay, v_new); + } + + template + CUTE_DEVICE static float update_and_output_row_tile( + TensorState& state_vk, + TensorKTile const& k_smem, + TensorQTile const& q_smem, + RowTileUpdatePlan const& plan, + float decay, + float v_new) { + if (!plan.owns_row) { + return 0.f; + } + return update_and_output_row_tile( + state_vk, + k_smem, + q_smem, + plan.v_row, + plan.coords, + decay, + v_new); + } + + CUTE_DEVICE static float softplusf_approx(float x) { + return x > 20.f ? 
x : log1pf(expf(x)); + } + + template < + typename TensorQ, + typename TensorK, + typename TensorV, + typename TensorA, + typename TensorB, + typename TensorAlog, + typename TensorDt, + typename TensorHvk, + typename TensorOut, + typename SharedStorage> + CUTE_DEVICE static void run( + TensorQ const& q_vec, + TensorK const& k_vec, + TensorV const& v_vec, + TensorA const& a_scalar, + TensorB const& b_scalar, + TensorAlog const& A_log_scalar, + TensorDt const& dt_bias_scalar, + TensorHvk& state_vk, + TensorOut& out_vec, + SharedStorage& storage, + int v_tile_base, + int tid, + int num_threads) { + // Decode organization: + // - 1 warpgroup owns the full [128, 128] state tile for one (token, hv) + // - state is traversed as 16x16 tiles over the internal VK view + // - q/k/v are staged once into shared memory + // - proj/out are accumulated over K tiles + // - rank-1 update is applied tile-by-tile in the same traversal order + // + // This pass establishes the tile-first organization for the final fp32 + // decode kernel. The next implementation step should optimize the scalar + // inner loops with better register ownership / reductions rather than + // forcing Tensor Core math. 
+ // + // TODO(qwen35-decode-fp32): + // - evaluate whether q/k should move from shared-memory staging to + // warp-shuffle broadcast + // - evaluate whether one thread should own more than one V-row + // - evaluate whether some parts of the state row can remain resident in + // registers across both proj and update/out passes with acceptable + // register pressure + + const float a_val = static_cast(a_scalar); + const float b_val = static_cast(b_scalar); + const float A_log_val = static_cast(A_log_scalar); + const float dt_bias_val = static_cast(dt_bias_scalar); + + const float g = -expf(A_log_val) * softplusf_approx(a_val + dt_bias_val); + const float decay = expf(g); + const float beta = 1.f / (1.f + expf(-b_val)); + + auto q_smem = make_tensor(make_smem_ptr(storage.q_smem), make_layout(make_shape(Int{}))); + auto k_smem = make_tensor(make_smem_ptr(storage.k_smem), make_layout(make_shape(Int{}))); + auto v_smem = make_tensor(make_smem_ptr(storage.v_smem), make_layout(make_shape(Int{}))); + auto norm_smem = make_tensor(make_smem_ptr(storage.norm_smem), make_layout(make_shape(Int<2>{}))); + auto state_smem = make_tensor( + make_smem_ptr(storage.state_smem), + make_layout(make_shape(Int{}, Int{}), make_stride(Int{}, Int<1>{}))); + auto proj_smem = make_tensor(make_smem_ptr(storage.proj_smem), make_layout(make_shape(Int{}))); + auto out_smem = make_tensor(make_smem_ptr(storage.out_smem), make_layout(make_shape(Int{}))); + + // Stage q/k/v once per CTA for the current decode token. 
+ for (int idx = tid; idx < kHeadDimQK; idx += num_threads) { + q_smem(idx) = static_cast(q_vec(idx)); + k_smem(idx) = static_cast(k_vec(idx)); + } + for (int local_v = tid; local_v < kTileV; local_v += num_threads) { + const int v_global = v_tile_base + local_v; + if (v_global < kHeadDimV) { + v_smem(local_v) = v_vec(v_global); + } + } + for (int idx = tid; idx < kHeadDimV; idx += num_threads) { + proj_smem(idx) = 0.f; + out_smem(idx) = 0.f; + } + for (int idx = tid; idx < kHeadDimQK * kTileV; idx += num_threads) { + const int k_idx = idx / kTileV; + const int local_v = idx - k_idx * kTileV; + const int v_global = v_tile_base + local_v; + state_smem(k_idx, local_v) = v_global < kHeadDimV ? static_cast(state_vk(v_global, k_idx)) : 0.f; + } + __syncthreads(); + + if (tid == 0) { + float q_norm_sq = 0.f; + float k_norm_sq = 0.f; +#pragma unroll + for (int idx = 0; idx < kHeadDimQK; ++idx) { + const float q_val = q_smem(idx); + const float k_val = k_smem(idx); + q_norm_sq += q_val * q_val; + k_norm_sq += k_val * k_val; + } + norm_smem(0) = rsqrtf(q_norm_sq + 1e-6f) * rsqrtf(static_cast(kHeadDimQK)); + norm_smem(1) = rsqrtf(k_norm_sq + 1e-6f); + } + __syncthreads(); + + for (int idx = tid; idx < kHeadDimQK; idx += num_threads) { + q_smem(idx) = q_smem(idx) * norm_smem(0); + k_smem(idx) = k_smem(idx) * norm_smem(1); + } + __syncthreads(); + + const int local_v = tid / kThreadsPerVRow; + const int row_lane = tid - local_v * kThreadsPerVRow; + const int v_global = v_tile_base + local_v; + if (local_v < kTileV && v_global < kHeadDimV) { + const int k_begin = row_lane * kKPerThread; + float proj_part = 0.f; +#pragma unroll + for (int kk = 0; kk < kKPerThread; ++kk) { + const int k_idx = k_begin + kk; + proj_part += state_smem(k_idx, local_v) * k_smem(k_idx); + } + proj_smem(tid) = proj_part; + } + __syncthreads(); + + if (local_v < kTileV && row_lane == 0 && v_global < kHeadDimV) { + float proj_row = 0.f; +#pragma unroll + for (int lane = 0; lane < kThreadsPerVRow; ++lane) 
{ + proj_row += proj_smem(local_v * kThreadsPerVRow + lane); + } + const float v_val = static_cast(v_smem(local_v)); + proj_smem(local_v) = beta * (v_val - decay * proj_row); + } + __syncthreads(); + + if (local_v < kTileV && v_global < kHeadDimV) { + const int k_begin = row_lane * kKPerThread; + const float v_new_row = proj_smem(local_v); + float out_part = 0.f; +#pragma unroll + for (int kk = 0; kk < kKPerThread; ++kk) { + const int k_idx = k_begin + kk; + const float state_new = decay * state_smem(k_idx, local_v) + v_new_row * k_smem(k_idx); + state_smem(k_idx, local_v) = state_new; + out_part += state_new * q_smem(k_idx); + state_vk(v_global, k_idx) = state_new; + } + out_smem(local_v * kThreadsPerVRow + row_lane) = out_part; + } + __syncthreads(); + + if (local_v < kTileV && row_lane == 0 && v_global < kHeadDimV) { + float out_row = 0.f; +#pragma unroll + for (int lane = 0; lane < kThreadsPerVRow; ++lane) { + out_row += out_smem(local_v * kThreadsPerVRow + lane); + } + out_vec(v_global) = static_cast(out_row); + } + } +}; + +} // namespace cula::qwen35::decode::kernel diff --git a/csrc/qwen35/prefill/qwen35_layout_prefill.cu b/csrc/qwen35/prefill/qwen35_layout_prefill.cu new file mode 100644 index 00000000..5cdf1218 --- /dev/null +++ b/csrc/qwen35/prefill/qwen35_layout_prefill.cu @@ -0,0 +1,188 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "qwen35_layout_prefill_kernel.hpp" +#include "qwen35_prefill_common.cuh" + +#include +#include +#include +#include + +namespace cula::qwen35::prefill { + +namespace { + +void check_tensor_device(const at::Tensor& tensor, const char* name, const at::Device& device) { + TORCH_CHECK(tensor.device() == device, name, " must be on device ", device, "."); +} + +void check_rank_2(const at::Tensor& tensor, const char* name) { + TORCH_CHECK(tensor.dim() == 2, name, " must be rank 2, got rank ", tensor.dim(), "."); +} + +template +void dispatch_layout_prefill_for_heads( + cudaStream_t stream, + const scalar_t* mixed_qkv_conv, + const scalar_t* a, + const scalar_t* b, + scalar_t* q_rep, + scalar_t* k_rep, + scalar_t* v, + scalar_t* a_kernel, + scalar_t* b_kernel, + int64_t token_count) { + constexpr int kLocalQKHeads = decode::local_qk_heads_from_v_heads(kLocalVHeads); + kernel::launch_qwen35_layout_prefill_kernel( + stream, + mixed_qkv_conv, + a, + b, + q_rep, + k_rep, + v, + a_kernel, + b_kernel, + token_count); +} + +template +void dispatch_layout_prefill( + int64_t local_v_heads, + cudaStream_t stream, + const scalar_t* mixed_qkv_conv, + const scalar_t* a, + const scalar_t* b, + scalar_t* q_rep, + scalar_t* k_rep, + scalar_t* v, + scalar_t* a_kernel, + scalar_t* b_kernel, + int64_t token_count) { + switch (local_v_heads) { + case 48: + dispatch_layout_prefill_for_heads(stream, mixed_qkv_conv, a, b, q_rep, k_rep, v, a_kernel, b_kernel, token_count); + break; + case 24: + dispatch_layout_prefill_for_heads(stream, mixed_qkv_conv, a, b, q_rep, k_rep, v, a_kernel, b_kernel, token_count); + break; + case 12: + dispatch_layout_prefill_for_heads(stream, mixed_qkv_conv, a, b, q_rep, k_rep, v, a_kernel, b_kernel, token_count); + break; + case 6: + dispatch_layout_prefill_for_heads(stream, mixed_qkv_conv, a, b, q_rep, k_rep, v, a_kernel, b_kernel, token_count); + break; + } +} + +} // namespace + +void run_qwen35_layout_prefill(LayoutPrefillParams& params) { + const 
at::Tensor& mixed_qkv_conv = params.mixed_qkv_conv; + const at::Tensor& a = params.a; + const at::Tensor& b = params.b; + const at::Tensor& q_rep = params.q_rep; + const at::Tensor& k_rep = params.k_rep; + const at::Tensor& v = params.v; + const at::Tensor& a_kernel = params.a_kernel; + const at::Tensor& b_kernel = params.b_kernel; + + TORCH_CHECK(mixed_qkv_conv.is_cuda(), "mixed_qkv_conv must be a CUDA tensor."); + const at::Device device = mixed_qkv_conv.device(); + + check_tensor_device(a, "a", device); + check_tensor_device(b, "b", device); + check_tensor_device(q_rep, "q_rep", device); + check_tensor_device(k_rep, "k_rep", device); + check_tensor_device(v, "v", device); + check_tensor_device(a_kernel, "a_kernel", device); + check_tensor_device(b_kernel, "b_kernel", device); + + TORCH_CHECK(mixed_qkv_conv.is_contiguous(), "mixed_qkv_conv must be contiguous."); + TORCH_CHECK(a.is_contiguous(), "a must be contiguous."); + TORCH_CHECK(b.is_contiguous(), "b must be contiguous."); + TORCH_CHECK(q_rep.is_contiguous(), "q_rep must be contiguous."); + TORCH_CHECK(k_rep.is_contiguous(), "k_rep must be contiguous."); + TORCH_CHECK(v.is_contiguous(), "v must be contiguous."); + TORCH_CHECK(a_kernel.is_contiguous(), "a_kernel must be contiguous."); + TORCH_CHECK(b_kernel.is_contiguous(), "b_kernel must be contiguous."); + + TORCH_CHECK( + mixed_qkv_conv.scalar_type() == a.scalar_type() && + mixed_qkv_conv.scalar_type() == b.scalar_type() && + mixed_qkv_conv.scalar_type() == q_rep.scalar_type() && + mixed_qkv_conv.scalar_type() == k_rep.scalar_type() && + mixed_qkv_conv.scalar_type() == v.scalar_type() && + mixed_qkv_conv.scalar_type() == a_kernel.scalar_type() && + mixed_qkv_conv.scalar_type() == b_kernel.scalar_type(), + "All layout prefill tensors must share the same dtype."); + TORCH_CHECK( + mixed_qkv_conv.scalar_type() == at::kHalf || mixed_qkv_conv.scalar_type() == at::kBFloat16, + "mixed_qkv_conv must be float16 or bfloat16."); + + check_rank_2(mixed_qkv_conv, 
"mixed_qkv_conv"); + check_rank_2(a, "a"); + check_rank_2(b, "b"); + + const int64_t token_count = mixed_qkv_conv.size(0); + const int64_t local_v_heads = a.size(1); + TORCH_CHECK(decode::is_supported_local_v_heads(static_cast(local_v_heads)), "local V heads must be one of {48, 24, 12, 6}, got ", local_v_heads, "."); + const int local_qk_heads = decode::local_qk_heads_from_v_heads(static_cast(local_v_heads)); + const int local_mixed_dim = decode::local_mixed_qkv_dim(local_qk_heads, static_cast(local_v_heads)); + TORCH_CHECK(mixed_qkv_conv.size(1) == local_mixed_dim, "mixed_qkv_conv must be [N, local_conv_dim=", local_mixed_dim, "]."); + TORCH_CHECK(a.sizes() == at::IntArrayRef({token_count, local_v_heads}), "a must be [N, local_v_heads]."); + TORCH_CHECK(b.sizes() == at::IntArrayRef({token_count, local_v_heads}), "b must be [N, local_v_heads]."); + TORCH_CHECK( + q_rep.dim() == 3 && q_rep.sizes() == at::IntArrayRef({token_count, local_v_heads, kHeadDimQK}), + "q_rep must be [N, local_v_heads, 128]."); + TORCH_CHECK(k_rep.sizes() == q_rep.sizes(), "k_rep must match q_rep shape."); + TORCH_CHECK(v.sizes() == q_rep.sizes(), "v must match q_rep shape."); + TORCH_CHECK(a_kernel.sizes() == a.sizes(), "a_kernel must match a shape."); + TORCH_CHECK(b_kernel.sizes() == b.sizes(), "b_kernel must match b shape."); + + const at::cuda::OptionalCUDAGuard device_guard(device); + cudaStream_t stream = at::cuda::getDefaultCUDAStream(device.index()); + + if (mixed_qkv_conv.scalar_type() == at::kHalf) { + dispatch_layout_prefill( + local_v_heads, + stream, + mixed_qkv_conv.data_ptr(), + a.data_ptr(), + b.data_ptr(), + q_rep.data_ptr(), + k_rep.data_ptr(), + v.data_ptr(), + a_kernel.data_ptr(), + b_kernel.data_ptr(), + token_count); + } else { + dispatch_layout_prefill( + local_v_heads, + stream, + mixed_qkv_conv.data_ptr(), + a.data_ptr(), + b.data_ptr(), + q_rep.data_ptr(), + k_rep.data_ptr(), + v.data_ptr(), + a_kernel.data_ptr(), + b_kernel.data_ptr(), + token_count); + } + 
C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +} // namespace cula::qwen35::prefill diff --git a/csrc/qwen35/prefill/qwen35_layout_prefill_kernel.hpp b/csrc/qwen35/prefill/qwen35_layout_prefill_kernel.hpp new file mode 100644 index 00000000..7dbcae3d --- /dev/null +++ b/csrc/qwen35/prefill/qwen35_layout_prefill_kernel.hpp @@ -0,0 +1,144 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "qwen35_prefill_common.cuh" + +#include +#include +#include +#include + +namespace cula::qwen35::prefill::kernel { + +using namespace cute; + +template +CUTE_DEVICE void copy_prefill_vec_contiguous( + scalar_t* __restrict__ dst, + const scalar_t* __restrict__ src) { + constexpr int kBytes = sizeof(scalar_t) * kVec; + if constexpr (kBytes == 16 || kBytes == 8) { + using VecType = cutlass::AlignedArray; + const auto dst_addr = reinterpret_cast(dst); + const auto src_addr = reinterpret_cast(src); + if ((dst_addr % alignof(VecType) == 0) && (src_addr % alignof(VecType) == 0)) { + *reinterpret_cast(dst) = *reinterpret_cast(src); + return; + } + } + +#pragma unroll + for (int i = 0; i < kVec; ++i) { + dst[i] = src[i]; + } +} + +template +__global__ void qwen35_layout_prefill_kernel( + const scalar_t* __restrict__ mixed_qkv_conv, + const scalar_t* __restrict__ a, + const scalar_t* __restrict__ b, + scalar_t* __restrict__ q_rep, + scalar_t* __restrict__ k_rep, + scalar_t* __restrict__ v_out, + scalar_t* __restrict__ 
a_kernel, + scalar_t* __restrict__ b_kernel, + int64_t token_count) { + static_assert(kLocalVHeads % kLocalQKHeads == 0); + static_assert(kHeadDimQK == kHeadDimV); + constexpr int kRepeatFactor = kLocalVHeads / kLocalQKHeads; + constexpr int kLocalQDim = kLocalQKHeads * kHeadDimQK; + constexpr int kLocalKDim = kLocalQKHeads * kHeadDimQK; + constexpr int kLocalMixedQKVDim = 2 * kLocalQDim + kLocalVHeads * kHeadDimV; + constexpr int kVec = 4; + static_assert(kHeadDimQK % kVec == 0); + + const int hv = static_cast(blockIdx.x); + const int token_idx = static_cast(blockIdx.y); + const int tid = static_cast(threadIdx.x); + if (token_idx >= token_count || hv >= kLocalVHeads) { + return; + } + + const int mapped_h = hv / kRepeatFactor; + + auto qk_src_layout = make_layout( + make_shape(Int{}, Int{}), + make_stride(Int{}, Int<1>{})); + auto v_src_layout = make_layout( + make_shape(Int{}, Int{}), + make_stride(Int{}, Int<1>{})); + auto hv_layout = make_layout( + make_shape(Int{}, Int{}), + make_stride(Int{}, Int<1>{})); + auto head_layout = make_layout(make_shape(Int{}), make_stride(Int<1>{})); + + const scalar_t* token_ptr = mixed_qkv_conv + static_cast(token_idx) * kLocalMixedQKVDim; + const scalar_t* q_src_ptr = token_ptr; + const scalar_t* k_src_ptr = token_ptr + kLocalQDim; + const scalar_t* v_src_ptr = token_ptr + kLocalQDim + kLocalKDim; + + scalar_t* q_dst_ptr = q_rep + static_cast(token_idx) * kLocalVHeads * kHeadDimQK; + scalar_t* k_dst_ptr = k_rep + static_cast(token_idx) * kLocalVHeads * kHeadDimQK; + scalar_t* v_dst_ptr = v_out + static_cast(token_idx) * kLocalVHeads * kHeadDimV; + + for (int vec_idx = tid; vec_idx < kHeadDimQK / kVec; vec_idx += blockDim.x) { + const int d = vec_idx * kVec; + const int q_src_idx = crd2idx(make_coord(mapped_h, d), qk_src_layout); + const int k_src_idx = crd2idx(make_coord(mapped_h, d), qk_src_layout); + const int v_src_idx = crd2idx(make_coord(hv, d), v_src_layout); + const int dst_idx = crd2idx(make_coord(hv, d), hv_layout); + 
+ copy_prefill_vec_contiguous(q_dst_ptr + dst_idx, q_src_ptr + q_src_idx); + copy_prefill_vec_contiguous(k_dst_ptr + dst_idx, k_src_ptr + k_src_idx); + copy_prefill_vec_contiguous(v_dst_ptr + dst_idx, v_src_ptr + v_src_idx); + } + + if (tid == 0) { + const int head_idx = crd2idx(make_coord(hv), head_layout); + const int64_t token_head_offset = static_cast(token_idx) * kLocalVHeads + head_idx; + a_kernel[token_head_offset] = a[token_head_offset]; + b_kernel[token_head_offset] = b[token_head_offset]; + } +} + +template +void launch_qwen35_layout_prefill_kernel( + cudaStream_t stream, + const scalar_t* mixed_qkv_conv, + const scalar_t* a, + const scalar_t* b, + scalar_t* q_rep, + scalar_t* k_rep, + scalar_t* v, + scalar_t* a_kernel, + scalar_t* b_kernel, + int64_t token_count) { + constexpr int kThreads = 32; + dim3 grid(kLocalVHeads, static_cast(token_count), 1); + qwen35_layout_prefill_kernel<<>>( + mixed_qkv_conv, + a, + b, + q_rep, + k_rep, + v, + a_kernel, + b_kernel, + token_count); +} + +} // namespace cula::qwen35::prefill::kernel diff --git a/csrc/qwen35/prefill/qwen35_prefill_common.cuh b/csrc/qwen35/prefill/qwen35_prefill_common.cuh new file mode 100644 index 00000000..ba3f3b70 --- /dev/null +++ b/csrc/qwen35/prefill/qwen35_prefill_common.cuh @@ -0,0 +1,66 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "qwen35/decode/qwen35_decode_common.cuh" + +#include + +namespace cula::qwen35::prefill { + +using decode::kHeadDimQK; +using decode::kHeadDimV; +using decode::kKDim; +using decode::kMixedQKVDim; +using decode::kNumQKHeads; +using decode::kNumVHeads; +using decode::kQDim; +using decode::kVDim; + +struct LayoutPrefillParams { + at::Tensor mixed_qkv_conv; // [N, local_conv_dim] + at::Tensor a; // [N, local_v_heads] + at::Tensor b; // [N, local_v_heads] + at::Tensor q_rep; // [N, local_v_heads, 128] + at::Tensor k_rep; // [N, local_v_heads, 128] + at::Tensor v; // [N, local_v_heads, 128] + at::Tensor a_kernel; // [N, local_v_heads] + at::Tensor b_kernel; // [N, local_v_heads] +}; + +struct ScalarKdaPrefillParams { + at::Tensor q; // [B, T, local_v_heads, 128] + at::Tensor k; // [B, T, local_v_heads, 128] + at::Tensor v; // [B, T, local_v_heads, 128] + at::Tensor a; // [B, T, local_v_heads] + at::Tensor b; // [B, T, local_v_heads] + at::Tensor A_log; // [local_v_heads], float32 + at::Tensor dt_bias; // [local_v_heads], float32 + at::Tensor initial_state; // [N, local_v_heads, 128, 128], float32, may be empty + at::Tensor cu_seqlens; // [N + 1], int32, may be empty + at::Tensor out; // [B, T, local_v_heads, 128] + at::Tensor final_state; // [N, local_v_heads, 128, 128], float32 +}; + +void run_qwen35_scalar_kda_prefill(ScalarKdaPrefillParams& params); +void run_qwen35_layout_prefill(LayoutPrefillParams& params); + +} // namespace cula::qwen35::prefill + +namespace cula::qwen35::prefill::sm90 { + +void qwen35_chunk_qk_prefill_sm90(at::Tensor q, at::Tensor k, at::Tensor out); + +} // namespace cula::qwen35::prefill::sm90 diff --git a/csrc/qwen35/prefill/qwen35_scalar_kda_prefill.cu b/csrc/qwen35/prefill/qwen35_scalar_kda_prefill.cu new file mode 100644 index 00000000..a92481a9 --- /dev/null +++ b/csrc/qwen35/prefill/qwen35_scalar_kda_prefill.cu @@ -0,0 +1,250 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "qwen35_prefill_common.cuh" +#include "qwen35_scalar_kda_prefill_kernel.hpp" + +#include +#include +#include +#include + +namespace cula::qwen35::prefill { + +namespace { + +void check_tensor_device(const at::Tensor& tensor, const char* name, const at::Device& device) { + if (tensor.defined() && tensor.numel() > 0) { + TORCH_CHECK(tensor.device() == device, name, " must be on device ", device, "."); + } +} + +void check_contiguous(const at::Tensor& tensor, const char* name) { + if (tensor.defined() && tensor.numel() > 0) { + TORCH_CHECK(tensor.is_contiguous(), name, " must be contiguous."); + } +} + +template +void dispatch_scalar_prefill_for_heads( + cudaStream_t stream, + const scalar_t* q, + const scalar_t* k, + const scalar_t* v, + const scalar_t* a, + const scalar_t* b, + const float* A_log, + const float* dt_bias, + const float* initial_state, + const int32_t* cu_seqlens, + scalar_t* out, + float* final_state, + int batch_size, + int seq_len, + int sequence_count, + bool is_varlen, + bool has_initial_state) { + kernel::launch_qwen35_scalar_kda_prefill_kernel( + stream, + q, + k, + v, + a, + b, + A_log, + dt_bias, + initial_state, + cu_seqlens, + out, + final_state, + batch_size, + seq_len, + sequence_count, + is_varlen, + has_initial_state); +} + +template +void dispatch_scalar_prefill( + int64_t local_v_heads, + cudaStream_t stream, + const scalar_t* q, + const scalar_t* k, + const scalar_t* 
v, + const scalar_t* a, + const scalar_t* b, + const float* A_log, + const float* dt_bias, + const float* initial_state, + const int32_t* cu_seqlens, + scalar_t* out, + float* final_state, + int batch_size, + int seq_len, + int sequence_count, + bool is_varlen, + bool has_initial_state) { + switch (local_v_heads) { + case 48: + dispatch_scalar_prefill_for_heads(stream, q, k, v, a, b, A_log, dt_bias, initial_state, cu_seqlens, out, final_state, batch_size, seq_len, sequence_count, is_varlen, has_initial_state); + break; + case 24: + dispatch_scalar_prefill_for_heads(stream, q, k, v, a, b, A_log, dt_bias, initial_state, cu_seqlens, out, final_state, batch_size, seq_len, sequence_count, is_varlen, has_initial_state); + break; + case 12: + dispatch_scalar_prefill_for_heads(stream, q, k, v, a, b, A_log, dt_bias, initial_state, cu_seqlens, out, final_state, batch_size, seq_len, sequence_count, is_varlen, has_initial_state); + break; + case 6: + dispatch_scalar_prefill_for_heads(stream, q, k, v, a, b, A_log, dt_bias, initial_state, cu_seqlens, out, final_state, batch_size, seq_len, sequence_count, is_varlen, has_initial_state); + break; + } +} + +} // namespace + +void run_qwen35_scalar_kda_prefill(ScalarKdaPrefillParams& params) { + const at::Tensor& q = params.q; + const at::Tensor& k = params.k; + const at::Tensor& v = params.v; + const at::Tensor& a = params.a; + const at::Tensor& b = params.b; + const at::Tensor& A_log = params.A_log; + const at::Tensor& dt_bias = params.dt_bias; + const at::Tensor& initial_state = params.initial_state; + const at::Tensor& cu_seqlens = params.cu_seqlens; + const at::Tensor& out = params.out; + const at::Tensor& final_state = params.final_state; + + TORCH_CHECK(q.is_cuda(), "q must be a CUDA tensor."); + const at::Device device = q.device(); + + check_tensor_device(k, "k", device); + check_tensor_device(v, "v", device); + check_tensor_device(a, "a", device); + check_tensor_device(b, "b", device); + check_tensor_device(A_log, "A_log", 
device); + check_tensor_device(dt_bias, "dt_bias", device); + check_tensor_device(initial_state, "initial_state", device); + check_tensor_device(cu_seqlens, "cu_seqlens", device); + check_tensor_device(out, "out", device); + check_tensor_device(final_state, "final_state", device); + + check_contiguous(q, "q"); + check_contiguous(k, "k"); + check_contiguous(v, "v"); + check_contiguous(a, "a"); + check_contiguous(b, "b"); + check_contiguous(A_log, "A_log"); + check_contiguous(dt_bias, "dt_bias"); + check_contiguous(initial_state, "initial_state"); + check_contiguous(cu_seqlens, "cu_seqlens"); + check_contiguous(out, "out"); + check_contiguous(final_state, "final_state"); + + TORCH_CHECK( + q.scalar_type() == k.scalar_type() && q.scalar_type() == v.scalar_type() && + q.scalar_type() == a.scalar_type() && q.scalar_type() == b.scalar_type() && + q.scalar_type() == out.scalar_type(), + "q/k/v/a/b/out must share the same dtype."); + TORCH_CHECK(q.scalar_type() == at::kHalf || q.scalar_type() == at::kBFloat16, "q must be float16 or bfloat16."); + TORCH_CHECK(A_log.scalar_type() == at::kFloat, "A_log must be float32."); + TORCH_CHECK(dt_bias.scalar_type() == at::kFloat, "dt_bias must be float32."); + TORCH_CHECK(final_state.scalar_type() == at::kFloat, "final_state must be float32."); + TORCH_CHECK( + !initial_state.defined() || initial_state.numel() == 0 || initial_state.scalar_type() == at::kFloat, + "initial_state must be float32 when provided."); + TORCH_CHECK( + !cu_seqlens.defined() || cu_seqlens.numel() == 0 || cu_seqlens.scalar_type() == at::kInt, + "cu_seqlens must be int32 when provided."); + + TORCH_CHECK(q.dim() == 4, "q must be [B, T, 48, 128]."); + const int64_t B = q.size(0); + const int64_t T = q.size(1); + const int64_t local_v_heads = q.size(2); + TORCH_CHECK(decode::is_supported_local_v_heads(static_cast(local_v_heads)), "local V heads must be one of {48, 24, 12, 6}, got ", local_v_heads, "."); + TORCH_CHECK( + q.sizes() == at::IntArrayRef({B, T, 
local_v_heads, kHeadDimQK}), + "q must have shape [B, T, local_v_heads, 128]."); + TORCH_CHECK(k.sizes() == q.sizes(), "k must match q shape."); + TORCH_CHECK(v.sizes() == q.sizes(), "v must match q shape."); + TORCH_CHECK(a.dim() == 3 && a.sizes() == at::IntArrayRef({B, T, local_v_heads}), "a must be [B, T, local_v_heads]."); + TORCH_CHECK(b.sizes() == a.sizes(), "b must match a shape."); + TORCH_CHECK(A_log.dim() == 1 && A_log.size(0) == local_v_heads, "A_log must be [local_v_heads]."); + TORCH_CHECK(dt_bias.dim() == 1 && dt_bias.size(0) == local_v_heads, "dt_bias must be [local_v_heads]."); + TORCH_CHECK(out.sizes() == q.sizes(), "out must match q shape."); + + const bool is_varlen = cu_seqlens.defined() && cu_seqlens.numel() > 0; + const int64_t sequence_count = is_varlen ? cu_seqlens.numel() - 1 : B; + TORCH_CHECK(sequence_count > 0, "sequence_count must be positive."); + if (is_varlen) { + TORCH_CHECK(B == 1, "cu_seqlens mode expects flattened q/k/v with batch size 1."); + } + + TORCH_CHECK( + final_state.dim() == 4 && + final_state.sizes() == at::IntArrayRef({sequence_count, local_v_heads, kHeadDimQK, kHeadDimV}), + "final_state must be [N, local_v_heads, 128, 128]."); + const bool has_initial_state = initial_state.defined() && initial_state.numel() > 0; + if (has_initial_state) { + TORCH_CHECK(initial_state.sizes() == final_state.sizes(), "initial_state must match final_state shape."); + } + + const at::cuda::OptionalCUDAGuard device_guard(device); + cudaStream_t stream = at::cuda::getDefaultCUDAStream(device.index()); + + if (q.scalar_type() == at::kHalf) { + dispatch_scalar_prefill( + local_v_heads, + stream, + q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + a.data_ptr(), + b.data_ptr(), + A_log.data_ptr(), + dt_bias.data_ptr(), + has_initial_state ? initial_state.data_ptr() : nullptr, + is_varlen ? 
cu_seqlens.data_ptr() : nullptr, + out.data_ptr(), + final_state.data_ptr(), + static_cast(B), + static_cast(T), + static_cast(sequence_count), + is_varlen, + has_initial_state); + } else { + dispatch_scalar_prefill( + local_v_heads, + stream, + q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + a.data_ptr(), + b.data_ptr(), + A_log.data_ptr(), + dt_bias.data_ptr(), + has_initial_state ? initial_state.data_ptr() : nullptr, + is_varlen ? cu_seqlens.data_ptr() : nullptr, + out.data_ptr(), + final_state.data_ptr(), + static_cast(B), + static_cast(T), + static_cast(sequence_count), + is_varlen, + has_initial_state); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +} // namespace cula::qwen35::prefill diff --git a/csrc/qwen35/prefill/qwen35_scalar_kda_prefill_kernel.hpp b/csrc/qwen35/prefill/qwen35_scalar_kda_prefill_kernel.hpp new file mode 100644 index 00000000..c8b2e396 --- /dev/null +++ b/csrc/qwen35/prefill/qwen35_scalar_kda_prefill_kernel.hpp @@ -0,0 +1,279 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "qwen35_prefill_common.cuh" + +#include +#include + +namespace cula::qwen35::prefill::kernel { + +using namespace cute; + +template +struct Qwen35ScalarKdaPrefillKernel { + static constexpr int kThreads = 128; + static constexpr int kHeadDim = kHeadDimQK; + // Keep the scalar CUDA fallback at one V row per CTA for correctness while + // the SM90 chunk/TMA path is being wired in. 
The previous multi-row V tile + // version exposed a correctness bug with non-zero initial_state; the chunk + // path should own the next parallelization step. + static constexpr int kVTile = 1; + static constexpr int kNumVTiles = kHeadDimV / kVTile; + + static_assert(kHeadDimQK == 128); + static_assert(kHeadDimV == 128); + static_assert(kHeadDimV % kVTile == 0); + + struct SharedStorage { + float scratch[kThreads]; + }; + + static dim3 block_shape() { + return dim3(kThreads, 1, 1); + } + + CUTE_HOST_DEVICE static auto make_v_work_tiles(int sequence_count) { + auto problem_layout = make_layout( + make_shape(Int{}, Int{}, sequence_count), + make_stride(Int<1>{}, Int{}, Int{})); + return zipped_divide(problem_layout, make_shape(Int{}, Int<1>{}, Int<1>{})); + } + + static dim3 grid_shape(int sequence_count) { + auto v_work_tiles = make_v_work_tiles(sequence_count); + return dim3(static_cast(size<1>(v_work_tiles)), 1, 1); + } + + CUTE_DEVICE static float load_as_float(scalar_t value) { + return static_cast(value); + } + + CUTE_DEVICE static scalar_t cast_output(float value) { + return static_cast(value); + } + + CUTE_DEVICE static float softplus(float x) { + return x > 20.0f ? 
x : log1pf(expf(x)); + } + + CUTE_DEVICE static float block_sum(float value, SharedStorage& storage, int tid) { + storage.scratch[tid] = value; + __syncthreads(); + + for (int stride = kThreads / 2; stride > 0; stride >>= 1) { + if (tid < stride) { + storage.scratch[tid] += storage.scratch[tid + stride]; + } + __syncthreads(); + } + const float result = storage.scratch[0]; + __syncthreads(); + return result; + } + + CUTE_DEVICE static void run_device( + const scalar_t* __restrict__ q, + const scalar_t* __restrict__ k, + const scalar_t* __restrict__ v, + const scalar_t* __restrict__ a, + const scalar_t* __restrict__ b, + const float* __restrict__ A_log, + const float* __restrict__ dt_bias, + const float* __restrict__ initial_state, + const int32_t* __restrict__ cu_seqlens, + scalar_t* __restrict__ out, + float* __restrict__ final_state, + int batch_size, + int seq_len, + int sequence_count, + bool is_varlen, + bool has_initial_state, + SharedStorage& storage) { + auto v_work_tiles = make_v_work_tiles(sequence_count); + auto work_layout = make_layout(get<1>(v_work_tiles.shape()), LayoutLeft{}); + auto work_coord = work_layout.get_hier_coord(static_cast(blockIdx.x)); + const int v_tile_idx = static_cast(get<0>(work_coord)); + const int hv = static_cast(get<1>(work_coord)); + const int seq_idx = static_cast(get<2>(work_coord)); + const int v_base = v_tile_idx * kVTile; + const int tid = static_cast(threadIdx.x); + + if (hv >= kLocalVHeads || seq_idx >= sequence_count) { + return; + } + + const int token_begin = is_varlen ? static_cast(cu_seqlens[seq_idx]) : seq_idx * seq_len; + const int token_end = is_varlen ? 
static_cast(cu_seqlens[seq_idx + 1]) : token_begin + seq_len; + const int state_base = ((seq_idx * kLocalVHeads + hv) * kHeadDimQK) * kHeadDimV; + + const int kk = tid; + float state_vals[kVTile]; + +#pragma unroll + for (int lane = 0; lane < kVTile; ++lane) { + const int v_row = v_base + lane; + const int state_off = state_base + kk * kHeadDimV + v_row; + state_vals[lane] = 0.0f; + if (kk < kHeadDimQK && v_row < kHeadDimV) { + state_vals[lane] = has_initial_state ? initial_state[state_off] : 0.0f; + } + } + __syncthreads(); + + const float scale = rsqrtf(static_cast(kHeadDimQK)); + const float exp_A = expf(A_log[hv]); + const float dt = dt_bias[hv]; + + for (int token = token_begin; token < token_end; ++token) { + const int local_t = is_varlen ? token : token - token_begin; + const int qkv_base = ((token * kLocalVHeads + hv) * kHeadDimQK); + const int gate_base = token * kLocalVHeads + hv; + + const float q_val = kk < kHeadDimQK ? load_as_float(q[qkv_base + kk]) : 0.0f; + const float k_val = kk < kHeadDimQK ? load_as_float(k[qkv_base + kk]) : 0.0f; + const float q_norm_sq = block_sum(q_val * q_val, storage, tid); + const float k_norm_sq = block_sum(k_val * k_val, storage, tid); + const float q_rnorm = rsqrtf(fmaxf(q_norm_sq, 1.0e-20f)) * scale; + const float k_rnorm = rsqrtf(fmaxf(k_norm_sq, 1.0e-20f)); + + const float decay = expf(-exp_A * softplus(load_as_float(a[gate_base]) + dt)); + const float beta = 1.0f / (1.0f + expf(-load_as_float(b[gate_base]))); + + const float k_norm = k_val * k_rnorm; + const float q_norm = q_val * q_rnorm; + +#pragma unroll + for (int lane = 0; lane < kVTile; ++lane) { + const int v_row = v_base + lane; + if (v_row < kHeadDimV) { + const float proj_partial = kk < kHeadDimQK ? 
state_vals[lane] * k_norm : 0.0f; + const float proj = block_sum(proj_partial, storage, tid); + + const float v_val = load_as_float(v[qkv_base + v_row]); + const float v_new = beta * (v_val - decay * proj); + + float out_partial = 0.0f; + if (kk < kHeadDimQK) { + const float state_new = decay * state_vals[lane] + k_norm * v_new; + state_vals[lane] = state_new; + out_partial = state_new * q_norm; + } + const float out_acc = block_sum(out_partial, storage, tid); + + if (tid == 0) { + const int out_off = + (((is_varlen ? 0 : seq_idx) * seq_len + local_t) * kLocalVHeads + hv) * kHeadDimV + v_row; + out[out_off] = cast_output(out_acc); + } + } + } + __syncthreads(); + } + +#pragma unroll + for (int lane = 0; lane < kVTile; ++lane) { + const int v_row = v_base + lane; + if (kk < kHeadDimQK && v_row < kHeadDimV) { + const int state_off = state_base + kk * kHeadDimV + v_row; + final_state[state_off] = state_vals[lane]; + } + } + + (void)batch_size; + } +}; + +template +__global__ void qwen35_scalar_kda_prefill_kernel( + const scalar_t* __restrict__ q, + const scalar_t* __restrict__ k, + const scalar_t* __restrict__ v, + const scalar_t* __restrict__ a, + const scalar_t* __restrict__ b, + const float* __restrict__ A_log, + const float* __restrict__ dt_bias, + const float* __restrict__ initial_state, + const int32_t* __restrict__ cu_seqlens, + scalar_t* __restrict__ out, + float* __restrict__ final_state, + int batch_size, + int seq_len, + int sequence_count, + bool is_varlen, + bool has_initial_state) { + __shared__ typename Qwen35ScalarKdaPrefillKernel::SharedStorage storage; + Qwen35ScalarKdaPrefillKernel::run_device( + q, + k, + v, + a, + b, + A_log, + dt_bias, + initial_state, + cu_seqlens, + out, + final_state, + batch_size, + seq_len, + sequence_count, + is_varlen, + has_initial_state, + storage); +} + +template +void launch_qwen35_scalar_kda_prefill_kernel( + cudaStream_t stream, + const scalar_t* q, + const scalar_t* k, + const scalar_t* v, + const scalar_t* a, + 
const scalar_t* b, + const float* A_log, + const float* dt_bias, + const float* initial_state, + const int32_t* cu_seqlens, + scalar_t* out, + float* final_state, + int batch_size, + int seq_len, + int sequence_count, + bool is_varlen, + bool has_initial_state) { + const auto grid = Qwen35ScalarKdaPrefillKernel::grid_shape(sequence_count); + const auto block = Qwen35ScalarKdaPrefillKernel::block_shape(); + qwen35_scalar_kda_prefill_kernel<<>>( + q, + k, + v, + a, + b, + A_log, + dt_bias, + initial_state, + cu_seqlens, + out, + final_state, + batch_size, + seq_len, + sequence_count, + is_varlen, + has_initial_state); +} + +} // namespace cula::qwen35::prefill::kernel diff --git a/csrc/qwen35/prefill/sm90/qwen35_chunk_prefill_sm90.cu b/csrc/qwen35/prefill/sm90/qwen35_chunk_prefill_sm90.cu new file mode 100644 index 00000000..88b157b4 --- /dev/null +++ b/csrc/qwen35/prefill/sm90/qwen35_chunk_prefill_sm90.cu @@ -0,0 +1,180 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "qwen35_chunk_prefill_traits_sm90.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace cula::qwen35::prefill::sm90 { + +namespace { + +using DefaultTraits = Qwen35ChunkPrefillSm90DefaultTraits; + +static_assert(DefaultTraits::kBlockT == 64); +static_assert(DefaultTraits::kBlockV == 64); +static_assert(DefaultTraits::kStages == 2); +static_assert(size(typename DefaultTraits::TiledMmaQK{}) == 128); +static_assert(size(typename DefaultTraits::TiledMmaOV{}) == 128); +static_assert(cosize(typename DefaultTraits::SmemLayoutQ{}) > 0); +static_assert(cosize(typename DefaultTraits::SmemLayoutK{}) > 0); + +void check_cutlass_status(cutlass::Status status, const char* what) { + TORCH_CHECK(status == cutlass::Status::kSuccess, what, " failed with CUTLASS status ", static_cast(status)); +} + +template +void run_qwen35_chunk_qk_prefill_sm90_impl(const at::Tensor& q, const at::Tensor& k, const at::Tensor& out) { + using ElementA = cutlass::bfloat16_t; + using ElementB = cutlass::bfloat16_t; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + + using LayoutA = cute::tuple; + using LayoutB = cute::tuple; + using LayoutC = cute::tuple; + using LayoutD = LayoutC; + + constexpr int kAlignmentA = 16 / sizeof(ElementA); + constexpr int kAlignmentB = 16 / sizeof(ElementB); + constexpr int kAlignmentC = 16 / sizeof(ElementC); + constexpr int kAlignmentD = 16 / sizeof(ElementD); + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; +#if defined(CULA_SM100_ENABLED) || defined(CULA_SM103_ENABLED) + using ArchTag = cutlass::arch::Sm100; + using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; + using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; +#else + using ArchTag = cutlass::arch::Sm90; + using KernelSchedule = 
cutlass::gemm::KernelTmaWarpSpecialized; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; +#endif +#if defined(CULA_SM100_ENABLED) || defined(CULA_SM103_ENABLED) + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; +#else + using EpilogueTileType = decltype(cute::take<0, 2>(TileShape{})); +#endif + using FusionOperation = + typename cutlass::epilogue::fusion::LinearCombination; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + TileShape, + ClusterShape, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + LayoutC, + kAlignmentC, + ElementD, + LayoutD, + kAlignmentD, + EpilogueSchedule, + FusionOperation>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal, CollectiveMainloop, CollectiveEpilogue>; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + const int64_t B = q.size(0); + const int64_t T = q.size(1); + const int64_t HV = q.size(2); + constexpr int K = kHeadDimQK; + const int64_t L = B * HV; + + LayoutA stride_A{HV * K, cute::_1{}, K}; + LayoutB stride_B{HV * K, cute::_1{}, K}; + LayoutC stride_C{T, cute::_1{}, T * T}; + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {static_cast(T), static_cast(T), K, static_cast(L)}, + { + reinterpret_cast(q.data_ptr()), + stride_A, + reinterpret_cast(k.data_ptr()), + stride_B, + }, + { + {1.0f, 0.0f}, + out.data_ptr(), + stride_C, + out.data_ptr(), + stride_C, + }, + }; + + Gemm gemm; + const size_t workspace_size = Gemm::get_workspace_size(arguments); 
+ at::Tensor workspace = at::empty({static_cast(workspace_size)}, q.options().dtype(at::kByte)); + check_cutlass_status(gemm.can_implement(arguments), "qwen35_chunk_qk_prefill_sm90 can_implement"); + check_cutlass_status(gemm.initialize(arguments, workspace.data_ptr(), at::cuda::getCurrentCUDAStream(q.device().index())), "qwen35_chunk_qk_prefill_sm90 initialize"); + check_cutlass_status(gemm.run(at::cuda::getCurrentCUDAStream(q.device().index())), "qwen35_chunk_qk_prefill_sm90 run"); +} + +} // namespace + +void qwen35_chunk_qk_prefill_sm90(at::Tensor q, at::Tensor k, at::Tensor out) { + TORCH_CHECK(q.is_cuda(), "q must be CUDA"); + TORCH_CHECK(k.is_cuda(), "k must be CUDA"); + TORCH_CHECK(out.is_cuda(), "out must be CUDA"); + TORCH_CHECK(q.scalar_type() == at::kBFloat16, "q must be bfloat16"); + TORCH_CHECK(k.scalar_type() == at::kBFloat16, "k must be bfloat16"); + TORCH_CHECK(out.scalar_type() == at::kFloat, "out must be float32"); + TORCH_CHECK(q.is_contiguous(), "q must be contiguous [B,T,HV,128]"); + TORCH_CHECK(k.is_contiguous(), "k must be contiguous [B,T,HV,128]"); + TORCH_CHECK(out.is_contiguous(), "out must be contiguous [B,HV,T,T]"); + TORCH_CHECK(q.dim() == 4, "q must be [B,T,HV,128]"); + TORCH_CHECK(k.sizes() == q.sizes(), "k must match q"); + const int64_t B = q.size(0); + const int64_t T = q.size(1); + const int64_t HV = q.size(2); + TORCH_CHECK(decode::is_supported_local_v_heads(static_cast(HV)), "expected local HV in {48, 24, 12, 6}, got ", HV); + TORCH_CHECK(q.size(3) == kHeadDimQK, "expected D=128"); + TORCH_CHECK(out.sizes() == at::IntArrayRef({B, HV, T, T}), "out must be [B,HV,T,T]"); + + const at::cuda::OptionalCUDAGuard device_guard(q.device()); + run_qwen35_chunk_qk_prefill_sm90_impl(q, k, out); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +} // namespace cula::qwen35::prefill::sm90 diff --git a/csrc/qwen35/prefill/sm90/qwen35_chunk_prefill_traits_sm90.hpp b/csrc/qwen35/prefill/sm90/qwen35_chunk_prefill_traits_sm90.hpp new file mode 100644 index 
00000000..c39e926e
--- /dev/null
+++ b/csrc/qwen35/prefill/sm90/qwen35_chunk_prefill_traits_sm90.hpp
@@ -0,0 +1,112 @@
+// Copyright 2025-2026 Ant Group Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "qwen35/prefill/qwen35_prefill_common.cuh"
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace cula::qwen35::prefill::sm90 {
+
+using namespace cute;
+
+// First SM90 chunk shape for Qwen3.5 prefill.
+//
+// This intentionally only describes the TMA/WGMMA tiles. The full chunk
+// algorithm still needs a local chunk recurrence and inter-chunk state scan;
+// those should be built on top of these traits instead of extending the scalar
+// fallback kernel.
+template
+struct Qwen35ChunkPrefillSm90Traits {
+    static constexpr int kBlockT = kBlockT_;
+    static constexpr int kBlockV = kBlockV_;
+    static constexpr int kStages = kStages_;
+
+    static_assert(kBlockT == 64 || kBlockT == 128, "GMMA chunk tiles expect BT=64 or BT=128.");
+    static_assert(kBlockV == 64 || kBlockV == 128, "V chunk tiles expect BV=64 or BV=128.");
+    static_assert(kHeadDimQK == 128);
+    static_assert(kHeadDimV == 128);
+
+    using Element = cutlass::bfloat16_t;
+    using Accumulator = float;
+    static constexpr int kAlignment = 16 / sizeof(Element);  // elements per 16 bytes
+
+    using ClusterShape = Shape<_1, _1, _1>;
+    using StageCount = cutlass::gemm::collective::StageCount;
+
+    // q/k/v are materialized by qwen35_layout_prefill as contiguous
+    // [total_tokens, local_v_heads, 128] (local_v_heads in {48, 24, 12, 6}). The TMA
+    // tensor view below exposes them as (token, dim, head), with dynamic strides:
+    //   token stride = local_v_heads * 128
+    //   dim stride = 1
+    //   head stride = 128
+    using GmemStrideTDH = cute::tuple;
+
+    using TileShapeQK = decltype(make_shape(Int{}, Int{}, Int{}));
+    using TileShapeOV = decltype(make_shape(Int{}, Int{}, Int{}));
+
+    // Q @ K^T => [BT, BT]. CollectiveBuilder selects GMMA and TMA-compatible
+    // shared-memory layouts for SM90.
+    using CollectiveQK = typename cutlass::gemm::collective::CollectiveBuilder<
+        cutlass::arch::Sm90,
+        cutlass::arch::OpClassTensorOp,
+        Element,
+        GmemStrideTDH,
+        kAlignment,
+        Element,
+        GmemStrideTDH,
+        kAlignment,
+        Accumulator,
+        TileShapeQK,
+        ClusterShape,
+        StageCount,
+        cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
+
+    using TiledMmaQK = typename CollectiveQK::TiledMma;
+    using SmemLayoutQ = typename CollectiveQK::SmemLayoutA;
+    using SmemLayoutK = typename CollectiveQK::SmemLayoutB;
+    using TmaQ = typename CollectiveQK::Params::TMA_A;
+    using TmaK = typename CollectiveQK::Params::TMA_B;
+
+    // Q @ state / local_value => [BT, BV]. This is the second core WGMMA shape
+    // needed once chunk-local state summaries are available.
+    using CollectiveOV = typename cutlass::gemm::collective::CollectiveBuilder<
+        cutlass::arch::Sm90,
+        cutlass::arch::OpClassTensorOp,
+        Element,
+        GmemStrideTDH,
+        kAlignment,
+        Element,
+        GmemStrideTDH,
+        kAlignment,
+        Accumulator,
+        TileShapeOV,
+        ClusterShape,
+        StageCount,
+        cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
+
+    using TiledMmaOV = typename CollectiveOV::TiledMma;
+    using SmemLayoutOV_A = typename CollectiveOV::SmemLayoutA;
+    using SmemLayoutOV_B = typename CollectiveOV::SmemLayoutB;
+};
+
+using Qwen35ChunkPrefillSm90DefaultTraits = Qwen35ChunkPrefillSm90Traits<64, 64, 2>;  // BT = 64, BV = 64, 2 stages
+
+} // namespace cula::qwen35::prefill::sm90
diff --git a/cula/__init__.py b/cula/__init__.py
index 7272e289..2688dab2 100644
--- a/cula/__init__.py
+++ b/cula/__init__.py
@@ -14,8 +14,9 @@
 
 __version__ = "0.1.0"
 
-from cula.ops.lightning_attn_sm100 import LinearAttentionChunkwiseDecay
+try:
+    from cula.ops.lightning_attn_sm100 import LinearAttentionChunkwiseDecay
+except Exception:  # pragma: no cover - optional runtime dependency
+    LinearAttentionChunkwiseDecay = None
 
-__all__ = [
-    "LinearAttentionChunkwiseDecay",
-]
+__all__ = ["LinearAttentionChunkwiseDecay"]
diff --git a/cula/kda/__init__.py b/cula/kda/__init__.py
index ee1a2bb9..98f5a14d 100644
--- a/cula/kda/__init__.py
+++ b/cula/kda/__init__.py
@@ -13,9 +13,19 @@
 # limitations under the License.
try:
    from fla.modules.l2norm import l2norm_fwd
    from fla.ops.kda.gate import kda_gate_fwd
    from fla.ops.utils import chunk_local_cumsum
    from fla.ops.utils.constant import RCP_LN2
    from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
except ImportError:
    # `fla` is optional. The stand-ins below keep this module importable
    # without it; features that cannot be emulated raise ImportError on use.
    RCP_LN2 = 1.4426950408889634  # 1 / ln(2)

    def input_guard(fn):
        """Stand-in decorator: return ``fn`` unchanged."""
        return fn

    def autocast_custom_fwd(fn):
        """Stand-in decorator: return ``fn`` unchanged."""
        return fn

    def autocast_custom_bwd(fn):
        """Stand-in decorator: return ``fn`` unchanged."""
        return fn

    def l2norm_fwd(x: torch.Tensor):
        """L2-normalise ``x`` along its last dimension, computing in float32.

        Returns:
            ``(y, rstd)``: ``y`` is the normalised tensor cast back to
            ``x.dtype``; ``rstd`` is the float32 reciprocal norm, with the
            sum of squares clamped at ``1e-12`` so zero rows stay finite.
        """
        x_f32 = x.float()  # upcast once and reuse for both the norm and the product
        rstd = torch.rsqrt(x_f32.square().sum(dim=-1, keepdim=True).clamp_min(1.0e-12))
        return (x_f32 * rstd).to(x.dtype), rstd

    def kda_gate_fwd(*args, **kwargs):
        """Placeholder: the in-kernel gate has no fallback without ``fla``."""
        raise ImportError("fla is required for use_gate_in_kernel=True in blackwell_fused_fwd")

    def chunk_local_cumsum(
        g: torch.Tensor,
        chunk_size: int,
        scale: float = 1.0,
        cu_seqlens: torch.Tensor | None = None,
        chunk_indices: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Cumulative sum of ``g`` along dim 1, restarted every ``chunk_size`` steps.

        Args:
            g: Input whose dim 1 is the token axis.
            chunk_size: Positive chunk length.
            scale: Factor applied to each chunk's cumulative sum.
            cu_seqlens: Optional 1D sequence boundaries along dim 1. Chunks
                restart at every boundary and ``g`` must have batch size 1.
            chunk_indices: Not supported by this fallback.

        Returns:
            A float32 tensor shaped like ``g``. In ``cu_seqlens`` mode,
            positions outside the given boundaries are left uninitialised.

        Raises:
            ImportError: If ``chunk_indices`` is given.
            ValueError: If ``chunk_size`` is not positive, or ``cu_seqlens``
                is given with a batch size other than 1.
        """
        if chunk_indices is not None:
            raise ImportError("fla is required for chunk_indices support in blackwell_fused_fwd")
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")

        if cu_seqlens is not None:
            if g.shape[0] != 1:
                raise ValueError("cu_seqlens mode expects flattened g with batch size 1")
            out = torch.empty_like(g, dtype=torch.float32)
            # One host transfer for all boundaries instead of two per sequence.
            bounds = [int(bound) for bound in cu_seqlens.tolist()]
            for start, end in zip(bounds, bounds[1:]):
                for chunk_start in range(start, end, chunk_size):
                    chunk_end = min(chunk_start + chunk_size, end)
                    out[:, chunk_start:chunk_end] = g[:, chunk_start:chunk_end].float().cumsum(dim=1) * scale
            return out

        total = g.shape[1]
        if total == 0:
            # torch.cat rejects an empty list, so return the empty result directly.
            return torch.empty_like(g, dtype=torch.float32)
        g_f32 = g.float()
        chunks = [
            g_f32[:, chunk_start : chunk_start + chunk_size].cumsum(dim=1) * scale
            for chunk_start in range(0, total, chunk_size)
        ]
        return torch.cat(chunks, dim=1)
try:
    from fla.modules.l2norm import l2norm_fwd
except ImportError:

    def l2norm_fwd(x: torch.Tensor):
        """Fallback L2 normalisation over the last dimension (used without ``fla``).

        The computation runs in float32. Returns ``(y, rstd)`` where ``y`` is
        the normalised input cast back to ``x.dtype`` and ``rstd`` is the
        float32 reciprocal norm; the sum of squares is clamped at ``1e-12``
        before the reciprocal square root.
        """
        x_f32 = x.float()
        sum_sq = x_f32.square().sum(dim=-1, keepdim=True)
        inv_norm = torch.rsqrt(sum_sq.clamp_min(1.0e-12))
        normalised = x_f32 * inv_norm
        return normalised.to(x.dtype), inv_norm
def qwen35_conv1d_decode_reference(
    x_t: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pure torch reference for Qwen3.5 single-token depthwise conv decode.

    The convolution window is the previous state with its oldest tap dropped
    plus the new token. The returned state is that same window. Any kernel
    width ``K`` (last dim of ``conv_state`` and ``weight``) is accepted.

    Args:
        x_t: New token activations, ``[B, C]``.
        conv_state: Rolling window of previous inputs, ``[B, C, K]``.
        weight: Depthwise filter, ``[C, K]`` or ``[C, 1, K]``.

    Returns:
        ``(y, conv_state_out)``: the SiLU-activated output in ``x_t.dtype``
        and a new state tensor. ``conv_state`` itself is not modified.
    """
    if weight.ndim == 3:
        weight = weight.squeeze(1)

    state_tail = conv_state[..., 1:].to(torch.float32)
    x_last = x_t.unsqueeze(-1).to(torch.float32)
    window = torch.cat([state_tail, x_last], dim=-1)
    conv = (window * weight.to(torch.float32).unsqueeze(0)).sum(dim=-1)
    y = torch.nn.functional.silu(conv).to(dtype=x_t.dtype)

    # Shift the window left by one tap and append the new token. Slicing
    # keeps this correct for every kernel width, not only K == 4.
    conv_state_out = torch.empty_like(conv_state)
    conv_state_out[..., :-1] = conv_state[..., 1:]
    conv_state_out[..., -1] = x_t
    return y, conv_state_out


def qwen35_conv1d_decode_update(
    x_t: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    *,
    activation: str = "silu",
    backend: str = "auto",
) -> tuple[torch.Tensor, torch.Tensor]:
    """Single-token depthwise causal conv1d update.

    Expected shapes:
      - x_t: [B, C]
      - conv_state: [B, C, 4]
      - weight: [C, 1, 4] or [C, 4]

    Args:
        activation: Only ``"silu"`` is supported.
        backend: ``"auto"`` uses the compiled ``cula_cuda.qwen35_conv1d_decode``
            op when it is available and ``x_t`` is a CUDA tensor, otherwise
            the torch reference. ``"cudac"`` requires the compiled op.
            ``"reference"`` always uses the torch reference.

    Returns:
        ``(y, conv_state_out)``. The state is returned as a new tensor.

    Raises:
        ValueError: On unsupported activation/backend or mismatched shapes.
        RuntimeError: If ``backend="cudac"`` but the compiled op is unusable.
    """

    if activation != "silu":
        raise ValueError(f"Only silu activation is currently supported, got {activation}")
    if x_t.ndim != 2:
        raise ValueError(f"x_t must be 2D [batch, channels], got {tuple(x_t.shape)}")
    if conv_state.ndim != 3:
        raise ValueError(f"conv_state must be 3D [batch, channels, kernel], got {tuple(conv_state.shape)}")
    if conv_state.shape[:2] != x_t.shape:
        raise ValueError(f"conv_state batch/channel dims must match x_t, got x_t={tuple(x_t.shape)} conv_state={tuple(conv_state.shape)}")
    kernel_size = conv_state.shape[-1]
    if kernel_size != 4:
        raise ValueError(f"Expected kernel_size=4 for Qwen3.5, got {kernel_size}")

    if weight.ndim == 3:
        if weight.shape[1] != 1 or weight.shape[2] != kernel_size:
            raise ValueError(f"weight must be [channels,1,{kernel_size}], got {tuple(weight.shape)}")
        weight_2d = weight.squeeze(1)
    elif weight.ndim == 2:
        if weight.shape[1] != kernel_size:
            raise ValueError(f"weight must be [channels,{kernel_size}], got {tuple(weight.shape)}")
        weight_2d = weight
    else:
        raise ValueError(f"weight must be 2D or 3D, got {tuple(weight.shape)}")

    if weight_2d.shape[0] != x_t.shape[1]:
        raise ValueError(f"weight channels must match x_t channels, got weight={tuple(weight_2d.shape)} x_t={tuple(x_t.shape)}")

    x_t = x_t.contiguous()
    conv_state = conv_state.contiguous()
    weight_2d = weight_2d.contiguous()

    use_cudac = (
        backend in ("auto", "cudac")
        and cula_cuda is not None
        and hasattr(cula_cuda, "qwen35_conv1d_decode")
        and x_t.is_cuda
    )
    if backend == "cudac" and not use_cudac:
        raise RuntimeError("Requested backend='cudac' but qwen35_conv1d_decode is not available.")

    if use_cudac:
        # The compiled op is handed a [B, 1, C] view of the token and a clone
        # of the state, so the caller's conv_state tensor stays untouched.
        mixed_qkv_3d = x_t.unsqueeze(1).contiguous()
        out_3d = torch.empty_like(mixed_qkv_3d)
        conv_state_out = conv_state.clone()
        cula_cuda.qwen35_conv1d_decode(
            mixed_qkv_3d,
            conv_state_out,
            weight_2d,
            out_3d,
        )
        return out_3d.squeeze(1), conv_state_out

    if backend not in ("auto", "reference"):
        raise ValueError(f"Unsupported backend={backend}")
    return qwen35_conv1d_decode_reference(x_t, conv_state, weight_2d)


def qwen35_conv1d_prefill(
    x: torch.Tensor,
    weight: torch.Tensor,
    *,
    activation: str = "silu",
    cu_seqlens: torch.Tensor | None = None,
    output_final_state: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Depthwise causal conv1d over a full sequence.

    Expected shapes:
      - x: [B, C, S] or flattened [T, C]
      - weight: [C, 1, K] or [C, K]

    Args:
        activation: Only ``"silu"`` is supported.
        cu_seqlens: Optional 1D int32 sequence boundaries for flattened
            ``[T, C]`` input. Each sequence is convolved independently.
        output_final_state: When true, also return the per-sequence state
            ``[N, C, K]`` holding the last ``K`` inputs, left-padded with
            zeros when a sequence is shorter than ``K``.

    Returns:
        ``y`` shaped like ``x``, or ``(y, states)`` when
        ``output_final_state`` is true.

    Raises:
        ValueError: On unsupported activation or mismatched shapes/dtypes.
    """

    if activation != "silu":
        raise ValueError(f"Unsupported activation={activation}")
    if weight.ndim == 3:
        if weight.shape[1] != 1:
            raise ValueError(f"weight must be [C,1,K] or [C,K], got {tuple(weight.shape)}")
        weight_2d = weight[:, 0, :]
    elif weight.ndim == 2:
        weight_2d = weight
    else:
        raise ValueError(f"weight must be [C,1,K] or [C,K], got {tuple(weight.shape)}")

    channels, kernel_size = weight_2d.shape
    weight_f = weight_2d.float()  # upcast once, shared by every sequence

    def _conv_one(seq: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # seq: [S, C]
        if seq.ndim != 2 or seq.shape[1] != channels:
            raise ValueError(f"sequence must be [S,C={weight_2d.shape[0]}], got {tuple(seq.shape)}")
        steps = seq.shape[0]
        seq_f = seq.float()
        acc = torch.zeros(steps, channels, device=seq.device, dtype=torch.float32)
        # One vectorised update per tap instead of one per (token, tap).
        # Taps are applied in ascending order, so each output element sees
        # the same accumulation order as a per-token loop would give.
        for kk in range(kernel_size):
            lag = kernel_size - 1 - kk
            if lag < steps:
                acc[lag:] += seq_f[: steps - lag] * weight_f[:, kk]
        out = torch.nn.functional.silu(acc).to(seq.dtype)

        state = torch.zeros(channels, kernel_size, device=seq.device, dtype=seq.dtype)
        take = min(kernel_size, steps)
        if take > 0:
            state[:, kernel_size - take :] = seq[-take:].transpose(0, 1)
        return out, state

    if x.ndim == 3:
        # Public op shape follows the Qwen conv convention [B, C, S].
        if x.shape[1] != channels:
            raise ValueError(f"x channel dim must match weight, got x={tuple(x.shape)} weight={tuple(weight_2d.shape)}")
        y = torch.empty_like(x)
        states = torch.empty(x.shape[0], x.shape[1], kernel_size, device=x.device, dtype=x.dtype)
        for bidx in range(x.shape[0]):
            y_b, state_b = _conv_one(x[bidx].transpose(0, 1).contiguous())
            y[bidx] = y_b.transpose(0, 1)
            states[bidx] = state_b
        return (y, states) if output_final_state else y

    if x.ndim == 2:
        if cu_seqlens is None:
            y, state = _conv_one(x)
            return (y, state.unsqueeze(0)) if output_final_state else y
        if cu_seqlens.ndim != 1 or cu_seqlens.dtype != torch.int32:
            raise ValueError(f"cu_seqlens must be 1D int32, got {tuple(cu_seqlens.shape)} {cu_seqlens.dtype}")
        y = torch.empty_like(x)
        states = torch.empty(cu_seqlens.numel() - 1, x.shape[1], kernel_size, device=x.device, dtype=x.dtype)
        # Fetch every boundary in one host transfer rather than two per sequence.
        bounds = cu_seqlens.tolist()
        for sidx, (start, end) in enumerate(zip(bounds, bounds[1:])):
            y_s, state_s = _conv_one(x[start:end])
            y[start:end] = y_s
            states[sidx] = state_s
        return (y, states) if output_final_state else y

    raise ValueError(f"x must be [B,C,S] or [T,C], got {tuple(x.shape)}")
def _resolve_fused_kda_prefill(device: torch.device | str | int | None = None):
    """Look up the fused KDA prefill callable for ``device``.

    Both the import of the selector and the selection itself are wrapped so
    that any failure surfaces as a ``RuntimeError`` carrying the cause.
    """
    try:
        from cula.utils import get_kda_fused_fwd as select_fused_fwd
    except Exception as import_err:  # pragma: no cover - depends on optional runtime deps
        raise RuntimeError(f"Cannot import fused KDA selector: {import_err}") from import_err

    try:
        fused_fwd = select_fused_fwd(device)
    except Exception as resolve_err:
        raise RuntimeError(f"Cannot resolve fused KDA prefill for device={device}: {resolve_err}") from resolve_err
    return fused_fwd


def has_qwen35_fused_kda_prefill(device: torch.device | str | int | None = None) -> bool:
    """Report whether a fused KDA prefill implementation can be resolved for ``device``."""
    try:
        _resolve_fused_kda_prefill(device)
    except Exception:
        # Deliberate best-effort probe: any failure means "not available".
        return False
    else:
        return True


def _validate_inputs(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    A_log: torch.Tensor,
    dt_bias: torch.Tensor,
    initial_state: torch.Tensor | None,
    cu_seqlens: torch.Tensor | None,
) -> tuple[int, int, int, int, torch.Tensor, torch.Tensor]:
    """Check shapes for the fused prefill and normalise ``a``/``b`` to 3D.

    Returns ``(B, T, HV, K, a, b)`` where 2D ``a``/``b`` have gained a leading
    batch dimension. Raises ``ValueError`` on the first violated constraint.
    """
    if any(t.ndim != 4 for t in (q, k, v)):
        raise ValueError(f"q/k/v must be 4D [B,T,HV,D], got q={tuple(q.shape)} k={tuple(k.shape)} v={tuple(v.shape)}")
    if not (q.shape == k.shape and q.shape == v.shape):
        raise ValueError(f"q/k/v must have the same shape, got q={tuple(q.shape)} k={tuple(k.shape)} v={tuple(v.shape)}")

    batch, seq_len, heads, head_dim = q.shape
    if head_dim != 128:
        raise ValueError(f"Qwen3.5 fused prefill expects head dim 128, got {head_dim}")

    # 2D gate inputs are treated as a single flattened batch.
    a = a.unsqueeze(0) if a.ndim == 2 else a
    b = b.unsqueeze(0) if b.ndim == 2 else b
    gate_shape = (batch, seq_len, heads)
    if a.shape != gate_shape or b.shape != gate_shape:
        raise ValueError(f"a/b must be [B,T,HV], got a={tuple(a.shape)} b={tuple(b.shape)} expected={gate_shape}")
    if A_log.shape != (heads,) or dt_bias.shape != (heads,):
        raise ValueError(f"A_log/dt_bias must be [HV], got A_log={tuple(A_log.shape)} dt_bias={tuple(dt_bias.shape)}")

    if cu_seqlens is None:
        state_count = batch
    else:
        if batch != 1:
            raise ValueError("cu_seqlens mode expects flattened q/k/v with batch size 1")
        if cu_seqlens.ndim != 1 or cu_seqlens.dtype != torch.int32:
            raise ValueError(f"cu_seqlens must be 1D int32, got {tuple(cu_seqlens.shape)} {cu_seqlens.dtype}")
        state_count = cu_seqlens.numel() - 1

    if initial_state is not None and initial_state.shape != (state_count, heads, head_dim, head_dim):
        raise ValueError(f"initial_state must be [{state_count},{heads},128,128], got {tuple(initial_state.shape)}")
    return batch, seq_len, heads, head_dim, a, b


def qwen35_fused_kda_prefill(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    A_log: torch.Tensor,
    dt_bias: torch.Tensor,
    *,
    initial_state: torch.Tensor | None = None,
    cu_seqlens: torch.Tensor | None = None,
    output_final_state: bool = True,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Run Qwen3.5 scalar-gated KDA prefill through the fused KDA core.

    Qwen uses one scalar gate per token/head while the generic fused core
    takes a per-dimension gate, so the scalar log-gate is broadcast over the
    head dimension. ``initial_state`` is transposed on its last two dims
    before being handed to the core; the returned final state is passed
    through unchanged apart from being made contiguous.

    Returns:
        ``(out, final_state)``; ``final_state`` is ``None`` when the core
        returns no state.

    Raises:
        RuntimeError: If ``q`` is not a CUDA tensor or no fused core resolves.
        ValueError: If the inputs fail shape validation.
    """

    if not q.is_cuda:
        raise RuntimeError("qwen35_fused_kda_prefill requires CUDA tensors.")
    batch, seq_len, heads, head_dim, a, b = _validate_inputs(
        q, k, v, a, b, A_log, dt_bias, initial_state, cu_seqlens
    )
    core = _resolve_fused_kda_prefill(q.device)

    # log-gate = -exp(A_log) * softplus(a + dt_bias), one scalar per token/head.
    neg_decay = -torch.exp(A_log.float()).view(1, 1, heads, 1)
    dt = F.softplus(a.float().unsqueeze(-1) + dt_bias.float().view(1, 1, heads, 1))
    gate_scalar = neg_decay * dt
    gate_full = gate_scalar.expand(batch, seq_len, heads, head_dim).contiguous()
    beta = torch.sigmoid(b.float()).contiguous()

    if initial_state is None:
        state_in = None
    else:
        state_in = initial_state.float().transpose(-1, -2).contiguous()

    out, state_out = core(
        q=q.contiguous(),
        k=k.contiguous(),
        v=v.contiguous(),
        g=gate_full,
        beta=beta,
        scale=head_dim**-0.5,
        initial_state=state_in,
        output_final_state=output_final_state,
        use_qk_l2norm_in_kernel=True,
        use_gate_in_kernel=False,
        safe_gate=False,
        lower_bound=None,
        cu_seqlens=cu_seqlens,
    )
    if state_out is None:
        return out, None
    return out, state_out.contiguous()
def qwen35_layout_decode_reference(
    mixed_qkv_conv: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    *,
    config: Qwen35LinearAttentionConfig = DEFAULT_QWEN35_LINEAR_ATTN_CONFIG,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Pure torch split of the packed conv output into per-head q/k/v.

    The packed width is split as ``[q | k | v]``. q and k are viewed per key
    head and then repeated along the head axis so they have as many heads as
    v. ``a`` and ``b`` are passed through as contiguous tensors.

    Returns:
        ``(q_rep, k_rep, v, a, b)``, all contiguous.
    """
    tokens = mixed_qkv_conv.shape[0]
    local_num_v_heads = a.shape[1]
    local_key_dim, _, local_num_k_heads = infer_local_config(
        mixed_qkv_conv.shape[1],
        local_num_v_heads,
        config=config,
    )

    q_end = local_key_dim
    k_end = q_end + local_key_dim
    q_flat = mixed_qkv_conv[:, :q_end]
    k_flat = mixed_qkv_conv[:, q_end:k_end]
    v_flat = mixed_qkv_conv[:, k_end:]

    q = q_flat.view(tokens, local_num_k_heads, config.head_k_dim)
    k = k_flat.view(tokens, local_num_k_heads, config.head_k_dim)
    v = v_flat.view(tokens, local_num_v_heads, config.head_v_dim)

    # Each key head serves `repeat_factor` consecutive value heads.
    repeat_factor = local_num_v_heads // local_num_k_heads
    q_rep = q.repeat_interleave(repeat_factor, dim=1).contiguous()
    k_rep = k.repeat_interleave(repeat_factor, dim=1).contiguous()
    return q_rep, k_rep, v.contiguous(), a.contiguous(), b.contiguous()


def qwen35_layout_decode(
    mixed_qkv_conv: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    *,
    config: Qwen35LinearAttentionConfig = DEFAULT_QWEN35_LINEAR_ATTN_CONFIG,
    backend: str = "auto",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Split the packed conv output into q/k/v, using the compiled op if possible.

    Args:
        backend: ``"auto"`` uses ``cula_cuda.qwen35_layout_decode`` when it is
            available and the input is a CUDA tensor, otherwise the torch
            reference. ``"cudac"`` requires the compiled op. ``"reference"``
            always uses the torch reference.

    Returns:
        ``(q_rep, k_rep, v, a, b)``, all contiguous on both paths.

    Raises:
        RuntimeError: If ``backend="cudac"`` but the compiled op is unusable.
        ValueError: If ``backend`` is not recognised.
    """
    use_cudac = (
        backend in ("auto", "cudac")
        and cula_cuda is not None
        and hasattr(cula_cuda, "qwen35_layout_decode")
        and mixed_qkv_conv.is_cuda
    )
    if backend == "cudac" and not use_cudac:
        raise RuntimeError("Requested backend='cudac' but qwen35_layout_decode is not available.")

    if use_cudac:
        tokens = mixed_qkv_conv.shape[0]
        local_num_v_heads = a.shape[1]
        # Called for its validation only; the op derives the split itself.
        infer_local_config(mixed_qkv_conv.shape[1], local_num_v_heads, config=config)
        q_rep = torch.empty(
            tokens,
            local_num_v_heads,
            config.head_k_dim,
            device=mixed_qkv_conv.device,
            dtype=mixed_qkv_conv.dtype,
        )
        k_rep = torch.empty_like(q_rep)
        v = torch.empty(
            tokens,
            local_num_v_heads,
            config.head_v_dim,
            device=mixed_qkv_conv.device,
            dtype=mixed_qkv_conv.dtype,
        )
        # Force contiguous outputs. A plain empty_like would copy the strides
        # of a permuted-but-dense `a`/`b`, while the op receives contiguous
        # inputs and the reference path returns contiguous tensors.
        a_kernel = torch.empty_like(a, memory_format=torch.contiguous_format)
        b_kernel = torch.empty_like(b, memory_format=torch.contiguous_format)
        cula_cuda.qwen35_layout_decode(
            mixed_qkv_conv.contiguous(),
            a.contiguous(),
            b.contiguous(),
            q_rep,
            k_rep,
            v,
            a_kernel,
            b_kernel,
        )
        return q_rep, k_rep, v, a_kernel, b_kernel

    if backend not in ("auto", "reference"):
        raise ValueError(f"Unsupported backend={backend}")
    return qwen35_layout_decode_reference(mixed_qkv_conv, a, b, config=config)
def qwen35_layout_prefill_reference(
    mixed_qkv_conv: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    *,
    config: Qwen35LinearAttentionConfig = DEFAULT_QWEN35_LINEAR_ATTN_CONFIG,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Pure torch split of the packed prefill conv output into per-head q/k/v.

    The packed width is split as ``[q | k | v]``. q and k are viewed per key
    head and then repeated along the head axis so they have as many heads as
    v. ``a`` and ``b`` are passed through as contiguous tensors.

    Returns:
        ``(q_rep, k_rep, v, a, b)``, all contiguous.
    """
    tokens = mixed_qkv_conv.shape[0]
    local_num_v_heads = a.shape[1]
    local_key_dim, _, local_num_k_heads = infer_local_config(
        mixed_qkv_conv.shape[1],
        local_num_v_heads,
        config=config,
    )

    q_end = local_key_dim
    k_end = q_end + local_key_dim
    q = mixed_qkv_conv[:, :q_end].view(tokens, local_num_k_heads, config.head_k_dim)
    k = mixed_qkv_conv[:, q_end:k_end].view(tokens, local_num_k_heads, config.head_k_dim)
    v = mixed_qkv_conv[:, k_end:].view(tokens, local_num_v_heads, config.head_v_dim)

    # Each key head serves `repeat_factor` consecutive value heads.
    repeat_factor = local_num_v_heads // local_num_k_heads
    q_rep = q.repeat_interleave(repeat_factor, dim=1).contiguous()
    k_rep = k.repeat_interleave(repeat_factor, dim=1).contiguous()
    return q_rep, k_rep, v.contiguous(), a.contiguous(), b.contiguous()


def qwen35_layout_prefill(
    mixed_qkv_conv: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    *,
    config: Qwen35LinearAttentionConfig = DEFAULT_QWEN35_LINEAR_ATTN_CONFIG,
    backend: str = "auto",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Split the packed prefill conv output into q/k/v, using the compiled op if possible.

    Args:
        backend: ``"auto"`` uses ``cula_cuda.qwen35_layout_prefill`` when it is
            available and the input is a CUDA tensor, otherwise the torch
            reference. ``"cudac"`` requires the compiled op. ``"reference"``
            always uses the torch reference.

    Returns:
        ``(q_rep, k_rep, v, a, b)``, all contiguous on both paths.

    Raises:
        RuntimeError: If ``backend="cudac"`` but the compiled op is unusable.
        ValueError: If ``backend`` is not recognised.
    """
    use_cudac = (
        backend in ("auto", "cudac")
        and cula_cuda is not None
        and hasattr(cula_cuda, "qwen35_layout_prefill")
        and mixed_qkv_conv.is_cuda
    )
    if backend == "cudac" and not use_cudac:
        raise RuntimeError("Requested backend='cudac' but qwen35_layout_prefill is not available.")

    if use_cudac:
        tokens = mixed_qkv_conv.shape[0]
        local_num_v_heads = a.shape[1]
        # Called for its validation only; the op derives the split itself.
        infer_local_config(mixed_qkv_conv.shape[1], local_num_v_heads, config=config)
        q_rep = torch.empty(tokens, local_num_v_heads, config.head_k_dim, device=mixed_qkv_conv.device, dtype=mixed_qkv_conv.dtype)
        k_rep = torch.empty_like(q_rep)
        v = torch.empty(tokens, local_num_v_heads, config.head_v_dim, device=mixed_qkv_conv.device, dtype=mixed_qkv_conv.dtype)
        # Force contiguous outputs. A plain empty_like would copy the strides
        # of a permuted-but-dense `a`/`b`, while the op receives contiguous
        # inputs and the reference path returns contiguous tensors.
        a_kernel = torch.empty_like(a, memory_format=torch.contiguous_format)
        b_kernel = torch.empty_like(b, memory_format=torch.contiguous_format)
        cula_cuda.qwen35_layout_prefill(
            mixed_qkv_conv.contiguous(),
            a.contiguous(),
            b.contiguous(),
            q_rep,
            k_rep,
            v,
            a_kernel,
            b_kernel,
        )
        return q_rep, k_rep, v, a_kernel, b_kernel

    if backend not in ("auto", "reference"):
        raise ValueError(f"Unsupported backend={backend}")
    return qwen35_layout_prefill_reference(mixed_qkv_conv, a, b, config=config)
def has_qwen35_layout_scalar_kda_decode_cudac() -> bool:
    """Return True when the compiled fused layout + scalar KDA decode op is importable."""
    return cula_cuda is not None and hasattr(cula_cuda, "qwen35_layout_scalar_kda_decode")


def _validate_cudac_state_indices(state_indices: torch.Tensor, *, rows: int, pool_size: int) -> None:
    """Check that ``state_indices`` is safe to hand to the compiled decode ops.

    Requires a 1D tensor with one entry per decode row, every entry inside
    ``[0, pool_size)``, and no duplicates. Reading min/max forces a host
    sync, so this is only called on the compiled-op path.

    Raises:
        ValueError: If any of the constraints above is violated.
    """
    if state_indices.ndim != 1 or state_indices.numel() != rows:
        raise ValueError(f"state_indices must be 1D with {rows} entries, got {tuple(state_indices.shape)}")
    if rows == 0:
        return
    min_idx = int(state_indices.min().item())
    max_idx = int(state_indices.max().item())
    if min_idx < 0 or max_idx >= pool_size:
        raise ValueError(f"state_indices must be in [0, {pool_size}), got min={min_idx} max={max_idx}")
    if torch.unique(state_indices).numel() != rows:
        raise ValueError(
            "backend='cudac' requires unique state_indices within one decode launch; "
            "duplicate rows need a sequential decode path."
        )


def qwen35_scalar_kda_decode(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    A_log: torch.Tensor,
    dt_bias: torch.Tensor,
    recurrent_state: torch.Tensor,
    *,
    state_indices: torch.Tensor | None = None,
    backend: str = "auto",
) -> tuple[torch.Tensor, torch.Tensor]:
    """Single-token scalar-gated delta-rule decode for Qwen3.5.

    Args:
        q, k: ``[N, 1, HV, K]`` with identical shapes.
        v: ``[N, 1, HV, V]``; batch and head dims must match ``q``.
        a, b: ``[N, 1, HV]`` or ``[N, HV]``.
        A_log, dt_bias: ``[HV]``.
        recurrent_state: State pool indexed by ``state_indices`` on dim 0.
        state_indices: Row-to-state mapping; defaults to ``arange(N)``.
        backend: ``"auto"`` uses ``cula_cuda.qwen35_scalar_kda_decode`` when
            it is available and ``q`` is a CUDA tensor, otherwise the generic
            ``kda_decode`` op. ``"cudac"`` requires the compiled op.
            ``"generic_kda"`` always uses ``kda_decode``.

    Returns:
        ``(out, state)``. The compiled path returns an updated clone of
        ``recurrent_state``. The generic path returns the tensor that was
        passed in (presumably updated in place by ``kda_decode``; confirm
        against that op).

    Raises:
        ValueError: On shape mismatches, invalid ``state_indices`` on the
            compiled path, or an unknown backend.
        RuntimeError: If ``backend="cudac"`` but the compiled op is unusable.
    """
    if q.ndim != 4 or k.ndim != 4 or v.ndim != 4:
        raise ValueError(f"q/k/v must be 4D, got q={tuple(q.shape)} k={tuple(k.shape)} v={tuple(v.shape)}")
    if q.shape != k.shape:
        raise ValueError(f"q and k must have the same shape, got q={tuple(q.shape)} vs k={tuple(k.shape)}")
    if q.shape[1] != 1 or v.shape[1] != 1:
        raise ValueError(f"Decode expects single-token sequence dim, got q={tuple(q.shape)} v={tuple(v.shape)}")

    N, _, HV, K = q.shape
    # v is otherwise only checked for rank; reject a mismatched batch/head
    # count here instead of letting it reach the kernels.
    if v.shape[0] != N or v.shape[2] != HV:
        raise ValueError(f"v batch/head dims must match q, got q={tuple(q.shape)} v={tuple(v.shape)}")
    if a.ndim == 2:
        a = a.unsqueeze(1)
    if b.ndim == 2:
        b = b.unsqueeze(1)
    if a.shape != (N, 1, HV) or b.shape != (N, 1, HV):
        raise ValueError(f"a/b must be [N,1,HV], got a={tuple(a.shape)} b={tuple(b.shape)}")
    if A_log.shape != (HV,) or dt_bias.shape != (HV,):
        raise ValueError(f"A_log/dt_bias must be [HV], got A_log={tuple(A_log.shape)} dt_bias={tuple(dt_bias.shape)}")

    state_indices = (
        torch.arange(N, device=q.device, dtype=torch.int32)
        if state_indices is None
        else state_indices.to(device=q.device, dtype=torch.int32)
    )

    use_cudac = (
        backend in ("auto", "cudac")
        and cula_cuda is not None
        and hasattr(cula_cuda, "qwen35_scalar_kda_decode")
        and q.is_cuda
    )
    if backend == "cudac" and not use_cudac:
        raise RuntimeError("Requested backend='cudac' but qwen35_scalar_kda_decode is not available.")

    if use_cudac:
        _validate_cudac_state_indices(state_indices, rows=N, pool_size=recurrent_state.shape[0])
        # The compiled op takes token-major 3D tensors without the seq dim.
        q_rep = q.squeeze(1).contiguous()
        k_rep = k.squeeze(1).contiguous()
        v_rep = v.squeeze(1).contiguous()
        a_kernel = a.squeeze(1).contiguous()
        b_kernel = b.squeeze(1).contiguous()
        out = torch.empty_like(v_rep)
        # Work on a clone so the caller's state pool is left untouched.
        recurrent_state_out = recurrent_state.clone()
        cula_cuda.qwen35_scalar_kda_decode(
            q_rep,
            k_rep,
            v_rep,
            a_kernel,
            b_kernel,
            A_log.contiguous(),
            dt_bias.contiguous(),
            recurrent_state_out,
            state_indices,
            out,
        )
        return out.unsqueeze(1), recurrent_state_out

    if backend not in ("auto", "generic_kda"):
        raise ValueError(f"Unsupported backend={backend}")

    from cula.ops.kda_decode import kda_decode

    # The generic op takes per-dimension gates, so broadcast the scalar
    # gate and its bias over the K head dimension.
    a_expanded = a.unsqueeze(-1).expand(N, 1, HV, K)
    dt_bias_expanded = dt_bias[:, None].expand(HV, K).contiguous()
    o = kda_decode(
        A_log=A_log.contiguous(),
        dt_bias=dt_bias_expanded,
        q=q.contiguous(),
        k=k.contiguous(),
        v=v.contiguous(),
        a=a_expanded.contiguous(),
        b=b.contiguous(),
        initial_state_source=recurrent_state,
        initial_state_indices=state_indices,
        scale=K**-0.5,
        use_qk_l2norm_in_kernel=True,
        state_layout="kv",
    )
    return o, recurrent_state


def qwen35_layout_scalar_kda_decode(
    mixed_qkv_conv: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    A_log: torch.Tensor,
    dt_bias: torch.Tensor,
    recurrent_state: torch.Tensor,
    *,
    state_indices: torch.Tensor | None = None,
    config: Qwen35LinearAttentionConfig = DEFAULT_QWEN35_LINEAR_ATTN_CONFIG,
    backend: str = "auto",
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fused Qwen3.5 layout decode + scalar-gated KDA decode.

    Args:
        mixed_qkv_conv: Packed conv output, 2D with one row per decode token.
        a, b: ``[N, HV]`` or ``[N, 1, HV]``.
        A_log, dt_bias: ``[HV]``.
        recurrent_state: State pool indexed by ``state_indices`` on dim 0.
        state_indices: Row-to-state mapping; defaults to ``arange(N)``.
        config: Head-dimension configuration used by the reference layout split.
        backend: ``"auto"`` uses ``cula_cuda.qwen35_layout_scalar_kda_decode``
            when it is available and the input is a CUDA tensor; otherwise
            the torch layout split followed by the generic decode is used.
            ``"cudac"`` requires the compiled op. ``"generic_kda"`` always
            uses the fallback.

    Returns:
        ``(out, state)`` with ``out`` shaped ``[N, 1, HV, V]`` on the compiled
        path. State ownership follows ``qwen35_scalar_kda_decode``.

    Raises:
        ValueError: On shape mismatches, invalid ``state_indices`` on the
            compiled path, or an unknown backend.
        RuntimeError: If ``backend="cudac"`` but the compiled op is unusable.
    """
    if mixed_qkv_conv.ndim != 2:
        raise ValueError(f"mixed_qkv_conv must be 2D, got {tuple(mixed_qkv_conv.shape)}")
    if a.ndim == 3:
        if a.shape[1] != 1:
            raise ValueError(f"a sequence dim must be 1 for decode, got {tuple(a.shape)}")
        a = a.squeeze(1)
    if b.ndim == 3:
        if b.shape[1] != 1:
            raise ValueError(f"b sequence dim must be 1 for decode, got {tuple(b.shape)}")
        b = b.squeeze(1)

    N = mixed_qkv_conv.shape[0]
    if a.ndim != 2 or b.ndim != 2 or a.shape != b.shape or a.shape[0] != N:
        raise ValueError(f"a/b must be [N, HV], got a={tuple(a.shape)} b={tuple(b.shape)}")
    HV = a.shape[1]
    if A_log.shape != (HV,) or dt_bias.shape != (HV,):
        raise ValueError(f"A_log/dt_bias must be [HV], got A_log={tuple(A_log.shape)} dt_bias={tuple(dt_bias.shape)}")

    state_indices = (
        torch.arange(N, device=mixed_qkv_conv.device, dtype=torch.int32)
        if state_indices is None
        else state_indices.to(device=mixed_qkv_conv.device, dtype=torch.int32)
    )

    use_cudac = (
        backend in ("auto", "cudac")
        and cula_cuda is not None
        and hasattr(cula_cuda, "qwen35_layout_scalar_kda_decode")
        and mixed_qkv_conv.is_cuda
    )
    if backend == "cudac" and not use_cudac:
        raise RuntimeError("Requested backend='cudac' but qwen35_layout_scalar_kda_decode is not available.")

    if use_cudac:
        _validate_cudac_state_indices(state_indices, rows=N, pool_size=recurrent_state.shape[0])
        # The output value dim is taken from the state's last dimension.
        out = torch.empty(
            N,
            HV,
            recurrent_state.shape[-1],
            device=mixed_qkv_conv.device,
            dtype=mixed_qkv_conv.dtype,
        )
        # Work on a clone so the caller's state pool is left untouched.
        recurrent_state_out = recurrent_state.clone()
        cula_cuda.qwen35_layout_scalar_kda_decode(
            mixed_qkv_conv.contiguous(),
            a.contiguous(),
            b.contiguous(),
            A_log.contiguous(),
            dt_bias.contiguous(),
            recurrent_state_out,
            state_indices,
            out,
        )
        return out.unsqueeze(1), recurrent_state_out

    if backend not in ("auto", "generic_kda"):
        raise ValueError(f"Unsupported backend={backend}")

    from cula.ops.qwen35_layout_decode import qwen35_layout_decode_reference

    q_rep, k_rep, v, a_kernel, b_kernel = qwen35_layout_decode_reference(mixed_qkv_conv, a, b, config=config)
    return qwen35_scalar_kda_decode(
        q=q_rep.unsqueeze(1).contiguous(),
        k=k_rep.unsqueeze(1).contiguous(),
        v=v.unsqueeze(1).contiguous(),
        a=a_kernel,
        b=b_kernel,
        A_log=A_log,
        dt_bias=dt_bias,
        recurrent_state=recurrent_state,
        state_indices=state_indices,
        backend="generic_kda",
    )
# Tail of the flattened patch hunk header for cula/ops/qwen35_scalar_kda_prefill.py: "+1,143 @@"
# Copyright 2025-2026 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Qwen3.5 scalar-gated KDA prefill wrapper."""

from __future__ import annotations

import torch

try:
    import cula.cudac as cula_cuda
except ImportError:
    cula_cuda = None


def qwen35_scalar_kda_prefill(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    A_log: torch.Tensor,
    dt_bias: torch.Tensor,
    *,
    initial_state: torch.Tensor | None = None,
    cu_seqlens: torch.Tensor | None = None,
    backend: str = "auto",
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Chunked scalar-gated delta-rule prefill for Qwen3.5.

    Args:
        q, k, v: 4D tensors ``[B, T, HV, 128]`` with identical shapes.
        a, b: gate inputs ``[B, T, HV]``; a 2D ``[T, HV]`` tensor is treated
            as batch size 1.
        A_log, dt_bias: per-head parameters of shape ``[HV]``.
        initial_state: optional starting state ``[N, HV, 128, 128]`` where
            ``N`` is the number of sequences (``B``, or
            ``len(cu_seqlens) - 1`` in varlen mode).
        cu_seqlens: optional 1D int32 sequence offsets; requires ``B == 1``.
        backend: ``"auto"``, ``"cudac"`` or ``"reference"``.

    Returns:
        ``(out, final_state)`` where ``out`` has the shape/dtype of ``v`` and
        ``final_state`` is float32 ``[N, HV, 128, 128]``.

    Raises:
        ValueError: on any shape/dtype mismatch, an unsupported backend, or
            (reference path only, where the offsets are read on the host
            anyway) ``cu_seqlens`` offsets that are decreasing or fall
            outside ``[0, T]``.
        RuntimeError: if ``backend="cudac"`` is requested but unavailable.
    """

    if q.ndim != 4 or k.ndim != 4 or v.ndim != 4:
        raise ValueError(f"q/k/v must be 4D [B,T,HV,D], got q={tuple(q.shape)} k={tuple(k.shape)} v={tuple(v.shape)}")
    if q.shape != k.shape or q.shape != v.shape:
        raise ValueError(f"q/k/v must have the same shape, got q={tuple(q.shape)} k={tuple(k.shape)} v={tuple(v.shape)}")
    B, T, HV, K = q.shape
    if K != 128 or v.shape[-1] != 128:
        raise ValueError(f"Qwen3.5 prefill expects K=V=128, got q={tuple(q.shape)} v={tuple(v.shape)}")
    if a.ndim == 2:
        a = a.unsqueeze(0)
    if b.ndim == 2:
        b = b.unsqueeze(0)
    if a.shape != (B, T, HV) or b.shape != (B, T, HV):
        raise ValueError(f"a/b must be [B,T,HV], got a={tuple(a.shape)} b={tuple(b.shape)} expected={(B, T, HV)}")
    if A_log.shape != (HV,) or dt_bias.shape != (HV,):
        raise ValueError(f"A_log/dt_bias must be [HV], got A_log={tuple(A_log.shape)} dt_bias={tuple(dt_bias.shape)}")
    if cu_seqlens is not None:
        if B != 1:
            raise ValueError("cu_seqlens mode expects flattened q/k/v with batch size 1")
        if cu_seqlens.ndim != 1 or cu_seqlens.dtype != torch.int32:
            raise ValueError(f"cu_seqlens must be 1D int32, got {tuple(cu_seqlens.shape)} {cu_seqlens.dtype}")
        if cu_seqlens.numel() == 0:
            # An empty tensor is the "no cu_seqlens" sentinel handed to the
            # native op, and would otherwise yield state_count == -1.
            raise ValueError("cu_seqlens must contain at least one offset")

    # One recurrent state per sequence.
    state_count = B if cu_seqlens is None else cu_seqlens.numel() - 1

    if initial_state is not None:
        if initial_state.shape[1:] != (HV, K, K):
            raise ValueError(f"initial_state must be [N,HV,128,128], got {tuple(initial_state.shape)}")
        if initial_state.shape[0] != state_count:
            # Both backends produce exactly `state_count` final states, so the
            # initial states have to line up one-to-one with the sequences.
            raise ValueError(
                f"initial_state must provide one state per sequence ({state_count}), got {tuple(initial_state.shape)}"
            )

    use_cudac = (
        backend in ("auto", "cudac")
        and cula_cuda is not None
        and hasattr(cula_cuda, "qwen35_scalar_kda_prefill")
        and q.is_cuda
    )
    if backend == "cudac" and not use_cudac:
        raise RuntimeError("Requested backend='cudac' but qwen35_scalar_kda_prefill is not available.")

    if use_cudac:
        if HV not in (48, 24, 12, 6):
            raise ValueError(f"backend='cudac' supports Qwen3.5 local HV in (48, 24, 12, 6), got {HV}")
        out = torch.empty_like(v)
        final_state = torch.empty(state_count, HV, K, K, device=q.device, dtype=torch.float32)
        # Empty tensors stand in for "not provided". A provided initial state
        # is normalised to the same float32/device as `final_state`, matching
        # the `.float()` conversion done by the reference path.
        initial_state_arg = (
            torch.empty(0, device=q.device, dtype=torch.float32)
            if initial_state is None
            else initial_state.to(device=q.device, dtype=torch.float32).contiguous()
        )
        cu_seqlens_arg = (
            torch.empty(0, device=q.device, dtype=torch.int32)
            if cu_seqlens is None
            else cu_seqlens.to(device=q.device, dtype=torch.int32).contiguous()
        )
        cula_cuda.qwen35_scalar_kda_prefill(
            q.contiguous(),
            k.contiguous(),
            v.contiguous(),
            a.contiguous(),
            b.contiguous(),
            A_log.contiguous(),
            dt_bias.contiguous(),
            initial_state_arg,
            cu_seqlens_arg,
            out,
            final_state,
        )
        return out, final_state

    if backend not in ("auto", "reference"):
        raise ValueError(f"Unsupported backend={backend}")

    # ---- pure-torch reference path -------------------------------------
    # Work out which token range feeds which state slot. The offsets are
    # copied to the host once instead of two `.item()` syncs per sequence.
    if cu_seqlens is None:
        segments = [(bidx, bidx, 0, T) for bidx in range(B)]
    else:
        offsets = cu_seqlens.tolist()
        if offsets[0] < 0 or offsets[-1] > T or any(lo > hi for lo, hi in zip(offsets, offsets[1:])):
            raise ValueError(f"cu_seqlens must be non-decreasing offsets within [0, {T}], got {offsets}")
        segments = [(0, sidx, start, end) for sidx, (start, end) in enumerate(zip(offsets, offsets[1:]))]

    state = (
        torch.zeros(state_count, HV, K, K, device=q.device, dtype=torch.float32)
        if initial_state is None
        else initial_state.float().clone()
    )
    # Zero-filled so tokens not covered by any cu_seqlens range are
    # deterministic rather than uninitialised memory.
    out = torch.zeros_like(v)
    q_f = torch.nn.functional.normalize(q.float(), dim=-1) * (K**-0.5)
    k_f = torch.nn.functional.normalize(k.float(), dim=-1)
    v_f = v.float()
    # The gates depend only on (token, head), so compute them for every token
    # up front instead of once per inner-loop iteration.
    decay_all = torch.exp(-torch.exp(A_log.float()) * torch.nn.functional.softplus(a.float() + dt_bias.float()))
    beta_all = torch.sigmoid(b.float())

    for batch_idx, state_idx, start, end in segments:
        for t in range(start, end):
            for hv in range(HV):
                state_kv = state[state_idx, hv]
                decay = decay_all[batch_idx, t, hv]
                beta = beta_all[batch_idx, t, hv]
                k_vec = k_f[batch_idx, t, hv]
                q_vec = q_f[batch_idx, t, hv]
                # Delta rule: correct the decayed state's prediction for k
                # towards v, then read the updated state out with q.
                proj = decay * (state_kv.transpose(0, 1) @ k_vec)
                v_new = beta * (v_f[batch_idx, t, hv] - proj)
                state_kv_new = decay * state_kv + k_vec.unsqueeze(1) * v_new.unsqueeze(0)
                out[batch_idx, t, hv] = (state_kv_new.transpose(0, 1) @ q_vec).to(out.dtype)
                state[state_idx, hv] = state_kv_new
    return out, state


# ---------------------------------------------------------------------------
# Patch metadata and licence text that followed this module on the same
# physical lines of the flattened diff (start of the next file), kept verbatim:
# diff --git a/cula/qwen35/__init__.py b/cula/qwen35/__init__.py
# new file mode 100644
# index 00000000..4cb88c3d
# --- /dev/null
# +++ b/cula/qwen35/__init__.py
# @@ -0,0 +1,28 @@
# Copyright 2025-2026 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Qwen3.5-specific linear attention support built on top of cuLA primitives.""" + +from cula.qwen35.common import Qwen35LinearAttentionConfig + +try: + from cula.qwen35.runtime import ( + qwen35_linear_attention_decode, + qwen35_linear_attention_prefill, + ) +except Exception: # pragma: no cover - optional runtime dependency during partial imports + qwen35_linear_attention_decode = None + qwen35_linear_attention_prefill = None + +__all__ = ["Qwen35LinearAttentionConfig", "qwen35_linear_attention_prefill", "qwen35_linear_attention_decode"] diff --git a/cula/qwen35/common.py b/cula/qwen35/common.py new file mode 100644 index 00000000..63299c06 --- /dev/null +++ b/cula/qwen35/common.py @@ -0,0 +1,135 @@ +# Copyright 2025-2026 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Shared constants and validation helpers for Qwen3.5 linear attention."""

from __future__ import annotations

import math
from dataclasses import dataclass

import torch


@dataclass(frozen=True)
class Qwen35LinearAttentionConfig:
    """Minimal runtime config for Qwen3.5 linear-attention kernels.

    The packed ``mixed_qkv`` channel layout is ``[q | k | v]`` with widths
    ``key_dim``, ``key_dim`` and ``value_dim``; the derived properties below
    are computed from the head counts and per-head widths.
    """

    hidden_size: int = 5120
    conv_kernel_size: int = 4
    # Head counts and per-head widths of the packed q/k/v layout.
    num_k_heads: int = 16
    num_v_heads: int = 48
    head_k_dim: int = 128
    head_v_dim: int = 128
    # Dtypes enforced by the validate_* helpers in this module.
    qkv_dtype: torch.dtype = torch.bfloat16
    state_dtype: torch.dtype = torch.float32

    @property
    def key_dim(self) -> int:
        """Total channel width of q (and of k)."""
        return self.num_k_heads * self.head_k_dim

    @property
    def value_dim(self) -> int:
        """Total channel width of v."""
        return self.num_v_heads * self.head_v_dim

    @property
    def conv_dim(self) -> int:
        """Channel width of the packed q/k/v tensor."""
        return self.key_dim * 2 + self.value_dim

    @property
    def qk_repeat_factor(self) -> int:
        """Number of value heads that share one q/k head.

        Raises:
            ValueError: if ``num_v_heads`` is not a multiple of ``num_k_heads``.
        """
        # Explicit raise instead of `assert`: assertions vanish under
        # `python -O`, which would silently floor the ratio.
        if self.num_v_heads % self.num_k_heads != 0:
            raise ValueError(
                f"num_v_heads={self.num_v_heads} must be divisible by num_k_heads={self.num_k_heads}"
            )
        return self.num_v_heads // self.num_k_heads


DEFAULT_QWEN35_LINEAR_ATTN_CONFIG = Qwen35LinearAttentionConfig()


def _packed_shard_unit(config: Qwen35LinearAttentionConfig) -> int:
    """Return the smallest packed width that keeps the q:k:v head ratio."""
    head_groups = math.gcd(config.num_k_heads, config.num_v_heads)
    if head_groups <= 0 or config.conv_dim <= 0:
        raise ValueError(
            f"config must have positive head counts and widths, got num_k_heads={config.num_k_heads} "
            f"num_v_heads={config.num_v_heads} conv_dim={config.conv_dim}"
        )
    # conv_dim is a sum of terms that are each multiples of a head count, so
    # it divides evenly by the gcd of the two head counts.
    return config.conv_dim // head_groups


def validate_mixed_qkv(
    mixed_qkv: torch.Tensor,
    config: Qwen35LinearAttentionConfig = DEFAULT_QWEN35_LINEAR_ATTN_CONFIG,
) -> None:
    """Check dtype, rank and channel alignment of a packed q/k/v tensor.

    The channel dimension may be the full ``config.conv_dim`` or a
    tensor-parallel local shard of it, so it only has to be a positive
    multiple of the smallest width that preserves the packed q:k:v ratio.
    The exact split is checked later by :func:`infer_local_config`.

    Raises:
        TypeError: if the dtype is not ``config.qkv_dtype``.
        ValueError: if the tensor is not 2D or its last dim is empty or
            not aligned with the packed layout.
    """
    if mixed_qkv.dtype != config.qkv_dtype:
        raise TypeError(f"mixed_qkv must be {config.qkv_dtype}, got {mixed_qkv.dtype}")
    if mixed_qkv.ndim != 2:
        raise ValueError(f"mixed_qkv must be 2D [tokens, conv_dim_local], got {tuple(mixed_qkv.shape)}")
    local_dim = mixed_qkv.shape[-1]
    if local_dim <= 0:
        raise ValueError("mixed_qkv must have a non-zero channel dimension")
    # The previous check compared against the full conv_dim twice, which
    # rejected every shard narrower than the un-sharded layout.
    if local_dim % _packed_shard_unit(config) != 0:
        raise ValueError(f"mixed_qkv last dim must match packed local conv dim, got {local_dim}")


def validate_scalar_gate_inputs(
    a: torch.Tensor,
    b: torch.Tensor,
    config: Qwen35LinearAttentionConfig = DEFAULT_QWEN35_LINEAR_ATTN_CONFIG,
) -> None:
    """Check that the gate inputs ``a`` and ``b`` are matching 2D tensors.

    Raises:
        ValueError: if the shapes differ or the tensors are not 2D.
        TypeError: if either dtype is not ``config.qkv_dtype``.
    """
    if a.shape != b.shape:
        raise ValueError(f"a and b must have the same shape, got a={tuple(a.shape)} vs b={tuple(b.shape)}")
    if a.ndim != 2:
        raise ValueError(f"a and b must be 2D [tokens, num_v_heads_local], got {tuple(a.shape)}")
    if a.dtype != config.qkv_dtype or b.dtype != config.qkv_dtype:
        raise TypeError(f"a and b must be {config.qkv_dtype}, got a={a.dtype}, b={b.dtype}")


def validate_state_tensors(
    conv_state: torch.Tensor | None,
    recurrent_state: torch.Tensor | None,
    config: Qwen35LinearAttentionConfig = DEFAULT_QWEN35_LINEAR_ATTN_CONFIG,
) -> None:
    """Check rank and dtype of the optional conv and recurrent states.

    ``None`` states are skipped. Only rank and dtype are verified here.

    Raises:
        ValueError: if a provided state has the wrong rank.
        TypeError: if a provided state has the wrong dtype.
    """
    if conv_state is not None:
        if conv_state.ndim != 3:
            raise ValueError(f"conv_state must be 3D [batch, channels, {config.conv_kernel_size}], got {tuple(conv_state.shape)}")
        if conv_state.dtype != config.qkv_dtype:
            raise TypeError(f"conv_state must be {config.qkv_dtype}, got {conv_state.dtype}")
    if recurrent_state is not None:
        if recurrent_state.ndim != 4:
            raise ValueError(f"recurrent_state must be 4D [batch, hv, k, v], got {tuple(recurrent_state.shape)}")
        if recurrent_state.dtype != config.state_dtype:
            raise TypeError(f"recurrent_state must be {config.state_dtype}, got {recurrent_state.dtype}")


def infer_local_config(
    mixed_qkv_dim: int,
    local_num_v_heads: int,
    *,
    config: Qwen35LinearAttentionConfig = DEFAULT_QWEN35_LINEAR_ATTN_CONFIG,
) -> tuple[int, int, int]:
    """Infer local packed dims from runtime shard sizes.

    Returns:
        - local_key_dim
        - local_value_dim
        - local_num_k_heads

    Raises:
        ValueError: if the widths cannot be split into an integer number of
            q/k heads, or the local value heads are not a multiple of the
            local key heads.
    """

    if local_num_v_heads <= 0:
        raise ValueError(f"local_num_v_heads must be positive, got {local_num_v_heads}")
    local_value_dim = local_num_v_heads * config.head_v_dim
    # Whatever is not v is shared equally between q and k.
    remaining = mixed_qkv_dim - local_value_dim
    if remaining <= 0 or remaining % 2 != 0:
        raise ValueError(
            f"Cannot infer local q/k dims from mixed_qkv_dim={mixed_qkv_dim}, local_num_v_heads={local_num_v_heads}"
        )
    local_key_dim = remaining // 2
    if local_key_dim % config.head_k_dim != 0:
        raise ValueError(f"Local key dim must be divisible by head_k_dim={config.head_k_dim}, got {local_key_dim}")
    local_num_k_heads = local_key_dim // config.head_k_dim
    if local_num_v_heads % local_num_k_heads != 0:
        raise ValueError(
            f"Local num_v_heads={local_num_v_heads} must be divisible by local num_k_heads={local_num_k_heads}"
        )
    return local_key_dim, local_value_dim, local_num_k_heads


# ---------------------------------------------------------------------------
# Patch metadata and licence text that followed this module on the same
# physical lines of the flattened diff (start of the next file), kept verbatim:
# diff --git a/cula/qwen35/runtime.py b/cula/qwen35/runtime.py
# new file mode 100644
# index 00000000..f3f55091
# --- /dev/null
# +++ b/cula/qwen35/runtime.py
# @@ -0,0 +1,368 @@
# Copyright 2025-2026 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ +"""Runtime dispatch for Qwen3.5 linear-attention kernels.""" + +from __future__ import annotations + +import torch + +try: + import cuda.bindings.driver as cuda +except ImportError: # pragma: no cover - optional runtime dependency + cuda = None + +from cula.qwen35.common import ( + DEFAULT_QWEN35_LINEAR_ATTN_CONFIG, + Qwen35LinearAttentionConfig, + infer_local_config, + validate_mixed_qkv, + validate_scalar_gate_inputs, + validate_state_tensors, +) +from cula.ops.qwen35_conv1d_decode import qwen35_conv1d_decode_update +from cula.ops.qwen35_conv1d_prefill import qwen35_conv1d_prefill +from cula.ops.qwen35_layout_decode import qwen35_layout_decode +from cula.ops.qwen35_layout_prefill import qwen35_layout_prefill +from cula.ops.qwen35_scalar_kda_decode import ( + has_qwen35_layout_scalar_kda_decode_cudac, + qwen35_layout_scalar_kda_decode, + qwen35_scalar_kda_decode, +) +from cula.ops.qwen35_scalar_kda_prefill import qwen35_scalar_kda_prefill + +_stream_cache: dict[tuple[str, int], object] = {} + + +def _get_cached_stream(device: torch.device) -> object: + if cuda is None: + raise RuntimeError("cuda.bindings.driver is not available in this environment.") + stream_id = int(torch.cuda.current_stream(device=device).cuda_stream) + cache_key = (str(device), stream_id) + if cache_key not in _stream_cache: + _stream_cache[cache_key] = cuda.CUstream(stream_id) + return _stream_cache[cache_key] + + +def _torch_qwen35_scalar_kda_decode_reference( + q_rep: torch.Tensor, + k_rep: torch.Tensor, + v: torch.Tensor, + a_kernel: torch.Tensor, + b_kernel: torch.Tensor, + A_log: torch.Tensor, + dt_bias: torch.Tensor, + recurrent_state: torch.Tensor, + state_indices: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Pure torch reference for Qwen3.5 scalar-gated decode.""" + tokens, num_v_heads, head_k_dim = q_rep.shape + head_v_dim = v.shape[-1] + state_out = recurrent_state.clone() + out = torch.empty(tokens, num_v_heads, head_v_dim, device=q_rep.device, dtype=q_rep.dtype) 
+ + scale = q_rep.shape[-1] ** -0.5 + q_f = torch.nn.functional.normalize(q_rep.float(), dim=-1) * scale + k_f = torch.nn.functional.normalize(k_rep.float(), dim=-1) + v_f = v.float() + a_f = a_kernel.float() + b_f = b_kernel.float() + + for token_idx in range(tokens): + pool_idx = int(state_indices[token_idx].item()) + for hv in range(num_v_heads): + state_kv = state_out[pool_idx, hv] + state_vk = state_kv.transpose(0, 1).contiguous() + + decay_pre = a_f[token_idx, hv] + dt_bias[hv] + decay = torch.exp(-torch.exp(A_log[hv]) * torch.nn.functional.softplus(decay_pre)) + beta = torch.sigmoid(b_f[token_idx, hv]) + + k_vec = k_f[token_idx, hv] + q_vec = q_f[token_idx, hv] + + proj = decay * (state_vk @ k_vec) + v_new = beta * (v_f[token_idx, hv] - proj) + state_vk_new = decay * state_vk + v_new.unsqueeze(1) * k_vec.unsqueeze(0) + out[token_idx, hv] = (state_vk_new @ q_vec).to(out.dtype) + state_out[pool_idx, hv] = state_vk_new.transpose(0, 1).contiguous() + + return out, state_out + + +def qwen35_linear_attention_decode_reference( + mixed_qkv: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + conv_weight: torch.Tensor, + A_log: torch.Tensor, + dt_bias: torch.Tensor, + *, + config: Qwen35LinearAttentionConfig = DEFAULT_QWEN35_LINEAR_ATTN_CONFIG, + conv_state: torch.Tensor, + recurrent_state: torch.Tensor, + state_indices: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Pure torch reference for the full Qwen3.5 decode chain.""" + tokens = mixed_qkv.shape[0] + if state_indices is None: + state_indices = torch.arange(tokens, device=mixed_qkv.device, dtype=torch.int32) + else: + state_indices = state_indices.to(device=mixed_qkv.device, dtype=torch.int32) + + conv_out, conv_state_out = qwen35_conv1d_decode_update( + mixed_qkv, + conv_state, + conv_weight, + activation="silu", + backend="reference", + ) + q_rep, k_rep, v, a_kernel, b_kernel = qwen35_layout_prefill( + conv_out, + a, + b, + config=config, + backend="reference", + ) + 
core_attn_out, recurrent_state_out = _torch_qwen35_scalar_kda_decode_reference( + q_rep, + k_rep, + v, + a_kernel, + b_kernel, + A_log.float(), + dt_bias.float(), + recurrent_state.float(), + state_indices, + ) + return core_attn_out.reshape(tokens, -1), conv_state_out, recurrent_state_out + + +def qwen35_linear_attention_prefill( + mixed_qkv: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + conv_weight: torch.Tensor, + A_log: torch.Tensor, + dt_bias: torch.Tensor, + *, + config: Qwen35LinearAttentionConfig = DEFAULT_QWEN35_LINEAR_ATTN_CONFIG, + cu_seqlens: torch.Tensor | None = None, + recurrent_state: torch.Tensor | None = None, + conv_state: torch.Tensor | None = None, + backend: str = "auto", +) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + """Qwen3.5 prefill wrapper. + + Args: + mixed_qkv: flattened [tokens, local_conv_dim] + a, b: [tokens, local_num_v_heads] + conv_weight: [local_conv_dim, 1, 4] or [local_conv_dim, 4] + A_log, dt_bias: [local_num_v_heads] + cu_seqlens: optional int32 sequence offsets for flattened input + recurrent_state: optional initial recurrent state [num_sequences, HV, 128, 128] + + Returns: + - core_attn_out_flat: [tokens, local_value_dim] + - final conv_state: [num_sequences, local_conv_dim, 4] + - final recurrent_state: [num_sequences, local_num_v_heads, 128, 128] + """ + + validate_mixed_qkv(mixed_qkv, config) + validate_scalar_gate_inputs(a, b, config) + validate_state_tensors(conv_state, recurrent_state, config) + if mixed_qkv.is_cuda: + _get_cached_stream(mixed_qkv.device) + if conv_state is not None: + raise NotImplementedError("Qwen3.5 prefill with non-empty conv_state is not implemented yet.") + if mixed_qkv.shape[0] != a.shape[0]: + raise ValueError(f"Token dimension mismatch, got mixed_qkv={tuple(mixed_qkv.shape)} a={tuple(a.shape)}") + if A_log.ndim != 1 or dt_bias.ndim != 1 or A_log.shape != dt_bias.shape: + raise ValueError(f"A_log and dt_bias must be matching 1D tensors, got {tuple(A_log.shape)} 
and {tuple(dt_bias.shape)}") + if cu_seqlens is not None and (cu_seqlens.ndim != 1 or cu_seqlens.dtype != torch.int32): + raise ValueError(f"cu_seqlens must be 1D int32, got {tuple(cu_seqlens.shape)} {cu_seqlens.dtype}") + + tokens = mixed_qkv.shape[0] + local_num_v_heads = a.shape[1] + _, local_value_dim, _ = infer_local_config( + mixed_qkv.shape[1], + local_num_v_heads, + config=config, + ) + if A_log.numel() != local_num_v_heads: + raise ValueError(f"A_log must match local_num_v_heads={local_num_v_heads}, got {A_log.numel()}") + + conv_out, conv_state_out = qwen35_conv1d_prefill( + mixed_qkv, + conv_weight, + activation="silu", + cu_seqlens=cu_seqlens, + output_final_state=True, + ) + q_rep, k_rep, v, a_kernel, b_kernel = qwen35_layout_prefill( + conv_out, + a, + b, + config=config, + backend="reference" if backend == "reference" else "auto", + ) + core_attn_out, recurrent_state_out = qwen35_scalar_kda_prefill( + q=q_rep.unsqueeze(0).contiguous(), + k=k_rep.unsqueeze(0).contiguous(), + v=v.unsqueeze(0).contiguous(), + a=a_kernel.unsqueeze(0).contiguous(), + b=b_kernel.unsqueeze(0).contiguous(), + A_log=A_log, + dt_bias=dt_bias, + initial_state=recurrent_state, + cu_seqlens=cu_seqlens, + backend=backend, + ) + return core_attn_out.reshape(tokens, local_value_dim), conv_state_out, recurrent_state_out + + +def qwen35_linear_attention_decode( + mixed_qkv: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + conv_weight: torch.Tensor, + A_log: torch.Tensor, + dt_bias: torch.Tensor, + *, + config: Qwen35LinearAttentionConfig = DEFAULT_QWEN35_LINEAR_ATTN_CONFIG, + conv_state: torch.Tensor, + recurrent_state: torch.Tensor, + state_indices: torch.Tensor | None = None, + backend: str = "auto", +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Qwen3.5 decode wrapper. 
+ + Args: + mixed_qkv: [tokens, local_conv_dim] + a, b: [tokens, local_num_v_heads] + conv_weight: [local_conv_dim, 1, 4] or [local_conv_dim, 4] + A_log, dt_bias: [local_num_v_heads] + conv_state: [tokens, local_conv_dim, 4] + recurrent_state: [pool, local_num_v_heads, 128, 128] + + Returns: + - core_attn_out_flat: [tokens, local_value_dim] + - updated_conv_state + - updated_recurrent_state + """ + + validate_mixed_qkv(mixed_qkv, config) + validate_scalar_gate_inputs(a, b, config) + validate_state_tensors(conv_state, recurrent_state, config) + if mixed_qkv.is_cuda: + _get_cached_stream(mixed_qkv.device) + + if mixed_qkv.shape[0] != a.shape[0]: + raise ValueError(f"Token dimension mismatch, got mixed_qkv={tuple(mixed_qkv.shape)} a={tuple(a.shape)}") + if A_log.ndim != 1 or dt_bias.ndim != 1: + raise ValueError(f"A_log and dt_bias must be 1D, got {tuple(A_log.shape)} and {tuple(dt_bias.shape)}") + if A_log.shape != dt_bias.shape: + raise ValueError(f"A_log and dt_bias must have the same shape, got {tuple(A_log.shape)} vs {tuple(dt_bias.shape)}") + + tokens = mixed_qkv.shape[0] + local_num_v_heads = a.shape[1] + local_key_dim, local_value_dim, local_num_k_heads = infer_local_config( + mixed_qkv.shape[1], + local_num_v_heads, + config=config, + ) + if conv_state.shape != (tokens, mixed_qkv.shape[1], config.conv_kernel_size): + raise ValueError( + f"conv_state must be [tokens, local_conv_dim, {config.conv_kernel_size}], got {tuple(conv_state.shape)}" + ) + if recurrent_state.shape[1:] != (local_num_v_heads, config.head_k_dim, config.head_v_dim): + raise ValueError( + "recurrent_state must be [pool, local_num_v_heads, head_k_dim, head_v_dim], " + f"got {tuple(recurrent_state.shape)}" + ) + if A_log.numel() != local_num_v_heads: + raise ValueError(f"A_log must match local_num_v_heads={local_num_v_heads}, got {A_log.numel()}") + + if backend == "auto" and not mixed_qkv.is_cuda: + backend = "reference" + + if backend == "reference": + return 
qwen35_linear_attention_decode_reference( + mixed_qkv, + a, + b, + conv_weight, + A_log, + dt_bias, + config=config, + conv_state=conv_state, + recurrent_state=recurrent_state, + state_indices=state_indices, + ) + + conv_out, conv_state_out = qwen35_conv1d_decode_update( + mixed_qkv, + conv_state, + conv_weight, + activation="silu", + backend=backend, + ) + use_fused_layout_kda = ( + backend in ("auto", "cudac") + and mixed_qkv.is_cuda + and has_qwen35_layout_scalar_kda_decode_cudac() + ) + if backend == "cudac" and not use_fused_layout_kda: + raise RuntimeError("Requested backend='cudac' but qwen35_layout_scalar_kda_decode is not available.") + + if use_fused_layout_kda: + core_attn_out, recurrent_state_out = qwen35_layout_scalar_kda_decode( + mixed_qkv_conv=conv_out, + a=a, + b=b, + A_log=A_log, + dt_bias=dt_bias, + recurrent_state=recurrent_state, + state_indices=state_indices, + config=config, + backend=backend, + ) + core_attn_out = core_attn_out.reshape(tokens, local_value_dim) + return core_attn_out, conv_state_out, recurrent_state_out + + q_rep, k_rep, v, a_kernel, b_kernel = qwen35_layout_decode( + conv_out, + a, + b, + config=config, + backend=backend, + ) + q = q_rep.unsqueeze(1).contiguous() + k = k_rep.unsqueeze(1).contiguous() + v = v.unsqueeze(1).contiguous() + + core_attn_out, recurrent_state_out = qwen35_scalar_kda_decode( + q=q, + k=k, + v=v, + a=a_kernel, + b=b_kernel, + A_log=A_log, + dt_bias=dt_bias, + recurrent_state=recurrent_state, + state_indices=state_indices, + backend=backend, + ) + core_attn_out = core_attn_out.reshape(tokens, local_value_dim) + return core_attn_out, conv_state_out, recurrent_state_out diff --git a/cula/utils.py b/cula/utils.py index 8b8e0ab1..ed020f73 100644 --- a/cula/utils.py +++ b/cula/utils.py @@ -94,11 +94,11 @@ def get_kda_fused_fwd(device: torch.device | str | int | None = None) -> Callabl """ major, minor = get_device_sm_version(device) if major == 10 and minor in (0, 3): - from cula.kda import 
kda_prefill_blackwell + from cula.kda.blackwell_fused_fwd import flash_kda_prefill as kda_prefill_blackwell return kda_prefill_blackwell elif major == 9 and minor == 0: - from cula.kda import kda_prefill_hopper + from cula.kda.hopper_fused_fwd import cula_kda_prefill as kda_prefill_hopper return kda_prefill_hopper else: diff --git a/docs/qwen35_kernel_plan.md b/docs/qwen35_kernel_plan.md new file mode 100644 index 00000000..5aeaba4d --- /dev/null +++ b/docs/qwen35_kernel_plan.md @@ -0,0 +1,40 @@ +# Qwen3.5 Kernel Landing Plan Inside cuLA + +This note records the internal landing structure for Qwen3.5 linear-attention +support added directly inside `cuLA`. + +## New Python package surface + +- `cula/qwen35/__init__.py` +- `cula/qwen35/common.py` +- `cula/qwen35/runtime.py` + +## New CuTe op entry files + +- `cula/ops/qwen35_conv1d_prefill.py` +- `cula/ops/qwen35_conv1d_decode.py` +- `cula/ops/qwen35_scalar_kda_prefill.py` +- `cula/ops/qwen35_scalar_kda_decode.py` + +## Intended ownership + +- `common.py` + shared constants, local-head config, shape validation +- `runtime.py` + compile-cache, stream-cache, prefill/decode dispatch boundaries +- `qwen35_conv1d_*` + depthwise causal conv1d + silu +- `qwen35_scalar_kda_*` + scalar-gated delta-rule prefill/decode kernels + +## What should be reused from existing cuLA code + +- runtime compile-cache patterns from `cula/ops/kda_decode.py` +- device helpers from `cula/utils.py` +- operator boundary style from `cula/kda/chunk.py` + +## What should stay isolated at first + +- no direct mutation of the generic `chunk_kda` public entry +- no pybind work until Python/CuTe path is numerically correct +- no conv + kda fusion until standalone kernels are validated diff --git a/setup.py b/setup.py index f7b11b95..11d7f654 100644 --- a/setup.py +++ b/setup.py @@ -147,6 +147,12 @@ def get_nvcc_thread_args(): cuda_sources = [ "csrc/api/pybind.cu", + "csrc/qwen35/decode/qwen35_conv1d_decode.cu", + 
"csrc/qwen35/decode/qwen35_layout_decode.cu", + "csrc/qwen35/decode/qwen35_scalar_kda_decode.cu", + "csrc/qwen35/prefill/qwen35_layout_prefill.cu", + "csrc/qwen35/prefill/qwen35_scalar_kda_prefill.cu", + "csrc/qwen35/prefill/sm90/qwen35_chunk_prefill_sm90.cu", ] if not DISABLE_SM100 or not DISABLE_SM103: cuda_sources.extend( diff --git a/tests/test_qwen35_decode.py b/tests/test_qwen35_decode.py new file mode 100644 index 00000000..274ff85d --- /dev/null +++ b/tests/test_qwen35_decode.py @@ -0,0 +1,607 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pathlib +import sys + +import pytest +import torch + +sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent.parent)) + +from cula.ops.qwen35_layout_decode import qwen35_layout_decode, qwen35_layout_decode_reference +from cula.ops.qwen35_scalar_kda_decode import qwen35_layout_scalar_kda_decode, qwen35_scalar_kda_decode +from cula.ops.qwen35_conv1d_decode import qwen35_conv1d_decode_reference, qwen35_conv1d_decode_update +from cula.qwen35.common import DEFAULT_QWEN35_LINEAR_ATTN_CONFIG, Qwen35LinearAttentionConfig +from cula.qwen35.runtime import qwen35_linear_attention_decode + +try: + from cula.ops.kda_decode_fla import fused_sigmoid_gating_delta_rule_update as triton_fused_sigmoid_update +except ImportError: + triton_fused_sigmoid_update = None + +try: + import cula.cudac as cula_cuda +except ImportError: + cula_cuda = None + + +def _device(): + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def _has_qwen35_cudac(): + return ( + torch.cuda.is_available() + and cula_cuda is not None + and hasattr(cula_cuda, "qwen35_conv1d_decode") + and hasattr(cula_cuda, "qwen35_layout_decode") + and hasattr(cula_cuda, "qwen35_scalar_kda_decode") + ) + + +def _has_qwen35_fused_layout_kda_cudac(): + return _has_qwen35_cudac() and hasattr(cula_cuda, "qwen35_layout_scalar_kda_decode") + + +def make_inputs( + tokens: int = 2, + pool_size: int = 3, + device: torch.device | None = None, + config: Qwen35LinearAttentionConfig = DEFAULT_QWEN35_LINEAR_ATTN_CONFIG, +): + device = _device() if device is None else device + torch.manual_seed(0) + mixed_qkv = torch.randn(tokens, config.conv_dim, device=device, dtype=config.qkv_dtype) + a = torch.randn(tokens, config.num_v_heads, device=device, dtype=config.qkv_dtype) + b = torch.randn(tokens, config.num_v_heads, device=device, dtype=config.qkv_dtype) + conv_weight = torch.randn(config.conv_dim, config.conv_kernel_size, device=device, dtype=config.qkv_dtype) + conv_state = torch.randn(tokens, 
config.conv_dim, config.conv_kernel_size, device=device, dtype=config.qkv_dtype) + recurrent_state = torch.randn( + pool_size, + config.num_v_heads, + config.head_k_dim, + config.head_v_dim, + device=device, + dtype=config.state_dtype, + ) * 0.01 + A_log = -torch.rand(config.num_v_heads, device=device, dtype=torch.float32) + dt_bias = torch.randn(config.num_v_heads, device=device, dtype=torch.float32) * 0.1 + state_indices = torch.arange(tokens, device=device, dtype=torch.int32) % pool_size + return mixed_qkv, a, b, conv_weight, conv_state, recurrent_state, A_log, dt_bias, state_indices + + +def manual_conv_decode(x_t: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor): + state_tail = conv_state[..., 1:].float() + window = torch.cat([state_tail, x_t.unsqueeze(-1).float()], dim=-1) + conv = (window * weight.float().unsqueeze(0)).sum(dim=-1) + y = torch.nn.functional.silu(conv).to(dtype=x_t.dtype) + state_new = conv_state.clone() + state_new[..., 0] = conv_state[..., 1] + state_new[..., 1] = conv_state[..., 2] + state_new[..., 2] = conv_state[..., 3] + state_new[..., 3] = x_t + return y, state_new + + +def manual_qwen35_decode_reference( + mixed_qkv: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + conv_weight: torch.Tensor, + A_log: torch.Tensor, + dt_bias: torch.Tensor, + conv_state: torch.Tensor, + recurrent_state: torch.Tensor, + state_indices: torch.Tensor, + *, + config: Qwen35LinearAttentionConfig = DEFAULT_QWEN35_LINEAR_ATTN_CONFIG, +): + conv_out, conv_state_out = manual_conv_decode(mixed_qkv, conv_state, conv_weight) + q_end = config.key_dim + k_end = q_end + config.key_dim + q = conv_out[:, :q_end].view(mixed_qkv.shape[0], config.num_k_heads, config.head_k_dim) + k = conv_out[:, q_end:k_end].view(mixed_qkv.shape[0], config.num_k_heads, config.head_k_dim) + v = conv_out[:, k_end:].view(mixed_qkv.shape[0], config.num_v_heads, config.head_v_dim) + q_rep = q.repeat_interleave(config.qk_repeat_factor, dim=1) + k_rep = 
k.repeat_interleave(config.qk_repeat_factor, dim=1) + + scale = config.head_k_dim**-0.5 + q_f = torch.nn.functional.normalize(q_rep.float(), dim=-1) * scale + k_f = torch.nn.functional.normalize(k_rep.float(), dim=-1) + v_f = v.float() + state_out = recurrent_state.clone() + out = torch.empty(mixed_qkv.shape[0], config.value_dim, device=mixed_qkv.device, dtype=mixed_qkv.dtype) + + for token_idx in range(mixed_qkv.shape[0]): + per_token = [] + pool_idx = int(state_indices[token_idx].item()) + for hv in range(config.num_v_heads): + state_kv = state_out[pool_idx, hv] + decay = torch.exp(-torch.exp(A_log[hv]) * torch.nn.functional.softplus(a[token_idx, hv].float() + dt_bias[hv])) + beta = torch.sigmoid(b[token_idx, hv].float()) + k_vec = k_f[token_idx, hv] + q_vec = q_f[token_idx, hv] + proj = decay * (state_kv.transpose(0, 1) @ k_vec) + v_new = beta * (v_f[token_idx, hv] - proj) + state_new_kv = decay * state_kv + k_vec.unsqueeze(1) * v_new.unsqueeze(0) + per_token.append((state_new_kv.transpose(0, 1) @ q_vec).to(mixed_qkv.dtype)) + state_out[pool_idx, hv] = state_new_kv + out[token_idx] = torch.cat(per_token, dim=0) + return out, conv_state_out, state_out + + +def manual_qwen35_scalar_kda_reference( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + A_log: torch.Tensor, + dt_bias: torch.Tensor, + recurrent_state: torch.Tensor, + state_indices: torch.Tensor, +): + if a.ndim == 2: + a = a.unsqueeze(1) + if b.ndim == 2: + b = b.unsqueeze(1) + N, _, HV, K = q.shape + scale = K**-0.5 + q_f = torch.nn.functional.normalize(q.squeeze(1).float(), dim=-1) * scale + k_f = torch.nn.functional.normalize(k.squeeze(1).float(), dim=-1) + v_f = v.squeeze(1).float() + state_out = recurrent_state.clone() + out = torch.empty(N, 1, HV, v.shape[-1], device=q.device, dtype=v.dtype) + + for token_idx in range(N): + pool_idx = int(state_indices[token_idx].item()) + for hv in range(HV): + state_kv = state_out[pool_idx, hv] + decay = 
torch.exp(-torch.exp(A_log[hv]) * torch.nn.functional.softplus(a[token_idx, 0, hv].float() + dt_bias[hv])) + beta = torch.sigmoid(b[token_idx, 0, hv].float()) + k_vec = k_f[token_idx, hv] + q_vec = q_f[token_idx, hv] + proj = decay * (state_kv.transpose(0, 1) @ k_vec) + v_new = beta * (v_f[token_idx, hv] - proj) + state_new_kv = decay * state_kv + k_vec.unsqueeze(1) * v_new.unsqueeze(0) + out[token_idx, 0, hv] = (state_new_kv.transpose(0, 1) @ q_vec).to(v.dtype) + state_out[pool_idx, hv] = state_new_kv + return out, state_out + + +def manual_qwen35_layout_scalar_kda_reference( + mixed_qkv_conv: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + A_log: torch.Tensor, + dt_bias: torch.Tensor, + recurrent_state: torch.Tensor, + state_indices: torch.Tensor, + *, + config: Qwen35LinearAttentionConfig = DEFAULT_QWEN35_LINEAR_ATTN_CONFIG, +): + q_rep, k_rep, v, a_ref, b_ref = qwen35_layout_decode_reference(mixed_qkv_conv, a, b, config=config) + + scale = config.head_k_dim**-0.5 + q_f = torch.nn.functional.normalize(q_rep.float(), dim=-1) * scale + k_f = torch.nn.functional.normalize(k_rep.float(), dim=-1) + v_f = v.float() + state_out = recurrent_state.clone() + out = torch.empty( + mixed_qkv_conv.shape[0], + q_rep.shape[1], + config.head_v_dim, + device=mixed_qkv_conv.device, + dtype=mixed_qkv_conv.dtype, + ) + + for token_idx in range(mixed_qkv_conv.shape[0]): + pool_idx = int(state_indices[token_idx].item()) + for hv in range(q_rep.shape[1]): + state_kv = state_out[pool_idx, hv] + decay = torch.exp(-torch.exp(A_log[hv]) * torch.nn.functional.softplus(a_ref[token_idx, hv].float() + dt_bias[hv])) + beta = torch.sigmoid(b_ref[token_idx, hv].float()) + k_vec = k_f[token_idx, hv] + q_vec = q_f[token_idx, hv] + proj = decay * (state_kv.transpose(0, 1) @ k_vec) + v_new = beta * (v_f[token_idx, hv] - proj) + state_new_kv = decay * state_kv + k_vec.unsqueeze(1) * v_new.unsqueeze(0) + out[token_idx, hv] = (state_new_kv.transpose(0, 1) @ q_vec).to(mixed_qkv_conv.dtype) + 
state_out[pool_idx, hv] = state_new_kv + return out.unsqueeze(1), state_out + + +def _local_config(local_v_heads: int) -> Qwen35LinearAttentionConfig: + return Qwen35LinearAttentionConfig(num_k_heads=local_v_heads // 3, num_v_heads=local_v_heads) + + +@pytest.mark.parametrize("tokens", [1, 2]) +def test_qwen35_conv_decode_reference(tokens: int): + mixed_qkv, _, _, conv_weight, conv_state, _, _, _, _ = make_inputs(tokens=tokens) + y_ref, state_ref = manual_conv_decode(mixed_qkv, conv_state, conv_weight) + y_op, state_op = qwen35_conv1d_decode_update(mixed_qkv, conv_state, conv_weight, backend="reference") + assert torch.equal(y_ref, y_op) + assert torch.equal(state_ref, state_op) + y_ref2, state_ref2 = qwen35_conv1d_decode_reference(mixed_qkv, conv_state, conv_weight) + assert torch.equal(y_ref, y_ref2) + assert torch.equal(state_ref, state_ref2) + + +def test_qwen35_layout_decode_reference(): + mixed_qkv, a, b, _, _, _, _, _, _ = make_inputs(tokens=2) + q_rep_ref, k_rep_ref, v_ref, a_ref, b_ref = qwen35_layout_decode_reference(mixed_qkv, a, b) + q_rep, k_rep, v, a_kernel, b_kernel = qwen35_layout_decode(mixed_qkv, a, b, backend="reference") + assert torch.equal(q_rep_ref, q_rep) + assert torch.equal(k_rep_ref, k_rep) + assert torch.equal(v_ref, v) + assert torch.equal(a_ref, a_kernel) + assert torch.equal(b_ref, b_kernel) + + +@pytest.mark.parametrize("tokens", [1, 2]) +def test_qwen35_decode_reference_chain(tokens: int): + mixed_qkv, a, b, conv_weight, conv_state, recurrent_state, A_log, dt_bias, state_indices = make_inputs(tokens=tokens) + out_ref, conv_state_ref, recurrent_state_ref = manual_qwen35_decode_reference( + mixed_qkv, + a, + b, + conv_weight, + A_log, + dt_bias, + conv_state, + recurrent_state, + state_indices, + ) + out, conv_state_out, recurrent_state_out = qwen35_linear_attention_decode( + mixed_qkv, + a, + b, + conv_weight, + A_log, + dt_bias, + conv_state=conv_state, + recurrent_state=recurrent_state, + state_indices=state_indices, + 
backend="reference", + ) + + assert torch.allclose(out_ref.float(), out.float(), atol=1e-5, rtol=1e-5) + assert torch.equal(conv_state_ref, conv_state_out) + assert torch.allclose(recurrent_state_ref, recurrent_state_out, atol=1e-6, rtol=1e-6) + + +@pytest.mark.skipif(not _has_qwen35_cudac(), reason="Qwen3.5 CUDA decode backend is not available") +@pytest.mark.parametrize("tokens", [1, 2, 4]) +def test_qwen35_decode_cudac_matches_reference(tokens: int): + # Decode batches represent distinct active sequences, so keep state rows unique + # to avoid intentionally racing multiple token updates against one cache row. + mixed_qkv, a, b, conv_weight, conv_state, recurrent_state, A_log, dt_bias, state_indices = make_inputs( + tokens=tokens, + pool_size=max(tokens, 3), + device=torch.device("cuda"), + ) + out_ref, conv_state_ref, recurrent_state_ref = qwen35_linear_attention_decode( + mixed_qkv, + a, + b, + conv_weight, + A_log, + dt_bias, + conv_state=conv_state, + recurrent_state=recurrent_state, + state_indices=state_indices, + backend="reference", + ) + out, conv_state_out, recurrent_state_out = qwen35_linear_attention_decode( + mixed_qkv, + a, + b, + conv_weight, + A_log, + dt_bias, + conv_state=conv_state, + recurrent_state=recurrent_state, + state_indices=state_indices, + backend="cudac", + ) + + torch.cuda.synchronize() + assert torch.allclose(out_ref.float(), out.float(), atol=3e-2, rtol=3e-2) + assert torch.equal(conv_state_ref, conv_state_out) + assert torch.allclose(recurrent_state_ref, recurrent_state_out, atol=3e-5, rtol=3e-5) + + +@pytest.mark.skipif(not _has_qwen35_cudac(), reason="Qwen3.5 CUDA decode backend is not available") +@pytest.mark.parametrize("local_v_heads", [48, 24, 12, 6]) +def test_qwen35_conv_decode_cudac_supports_local_tp_shapes(local_v_heads: int): + config = _local_config(local_v_heads) + mixed_qkv, _, _, conv_weight, conv_state, _, _, _, _ = make_inputs( + tokens=3, + pool_size=3, + device=torch.device("cuda"), + config=config, + ) + 
y_ref, state_ref = qwen35_conv1d_decode_update( + mixed_qkv, + conv_state, + conv_weight, + backend="reference", + ) + y, state = qwen35_conv1d_decode_update( + mixed_qkv, + conv_state, + conv_weight, + backend="cudac", + ) + + torch.cuda.synchronize() + torch.testing.assert_close(y, y_ref) + torch.testing.assert_close(state, state_ref) + + +@pytest.mark.skipif(not _has_qwen35_cudac(), reason="Qwen3.5 CUDA decode backend is not available") +@pytest.mark.parametrize("local_v_heads", [48, 24, 12, 6]) +def test_qwen35_scalar_kda_decode_cudac_supports_local_tp_shapes(local_v_heads: int): + torch.manual_seed(3) + config = _local_config(local_v_heads) + tokens = 3 + device = torch.device("cuda") + q = torch.randn(tokens, 1, config.num_v_heads, config.head_k_dim, device=device, dtype=config.qkv_dtype) + k = torch.randn_like(q) + v = torch.randn(tokens, 1, config.num_v_heads, config.head_v_dim, device=device, dtype=config.qkv_dtype) + a = torch.randn(tokens, 1, config.num_v_heads, device=device, dtype=config.qkv_dtype) + b = torch.randn(tokens, 1, config.num_v_heads, device=device, dtype=config.qkv_dtype) + A_log = -torch.rand(config.num_v_heads, device=device, dtype=torch.float32) + dt_bias = torch.randn(config.num_v_heads, device=device, dtype=torch.float32) * 0.1 + recurrent_state = torch.randn( + tokens, + config.num_v_heads, + config.head_k_dim, + config.head_v_dim, + device=device, + dtype=config.state_dtype, + ) * 0.01 + state_indices = torch.arange(tokens, device=device, dtype=torch.int32) + + out_ref, state_ref = manual_qwen35_scalar_kda_reference( + q, + k, + v, + a, + b, + A_log, + dt_bias, + recurrent_state, + state_indices, + ) + out, state = qwen35_scalar_kda_decode( + q, + k, + v, + a, + b, + A_log, + dt_bias, + recurrent_state, + state_indices=state_indices, + backend="cudac", + ) + + torch.cuda.synchronize() + torch.testing.assert_close(out.float(), out_ref.float(), atol=3e-2, rtol=3e-2) + torch.testing.assert_close(state, state_ref, atol=3e-5, rtol=3e-5) 
+ + +@pytest.mark.skipif(not _has_qwen35_cudac(), reason="Qwen3.5 CUDA decode backend is not available") +@pytest.mark.parametrize("local_v_heads", [48, 24, 12, 6]) +def test_qwen35_decode_cudac_supports_local_tp_shapes(local_v_heads: int): + config = _local_config(local_v_heads) + mixed_qkv, a, b, conv_weight, conv_state, recurrent_state, A_log, dt_bias, state_indices = make_inputs( + tokens=3, + pool_size=3, + device=torch.device("cuda"), + config=config, + ) + out_ref, conv_state_ref, recurrent_state_ref = qwen35_linear_attention_decode( + mixed_qkv, + a, + b, + conv_weight, + A_log, + dt_bias, + config=config, + conv_state=conv_state, + recurrent_state=recurrent_state, + state_indices=state_indices, + backend="reference", + ) + out, conv_state_out, recurrent_state_out = qwen35_linear_attention_decode( + mixed_qkv, + a, + b, + conv_weight, + A_log, + dt_bias, + config=config, + conv_state=conv_state, + recurrent_state=recurrent_state, + state_indices=state_indices, + backend="cudac", + ) + + torch.cuda.synchronize() + torch.testing.assert_close(out.float(), out_ref.float(), atol=3e-2, rtol=3e-2) + torch.testing.assert_close(conv_state_out, conv_state_ref) + torch.testing.assert_close(recurrent_state_out, recurrent_state_ref, atol=3e-5, rtol=3e-5) + + +@pytest.mark.skipif(not _has_qwen35_fused_layout_kda_cudac(), reason="Qwen3.5 fused layout+KDA CUDA backend is not available") +@pytest.mark.parametrize("tokens", [1, 2, 4]) +def test_qwen35_fused_layout_kda_cudac_matches_reference_unfused_and_triton(tokens: int): + mixed_qkv, a, b, conv_weight, conv_state, recurrent_state, A_log, dt_bias, state_indices = make_inputs( + tokens=tokens, + pool_size=max(tokens, 3), + device=torch.device("cuda"), + ) + conv_out, _ = qwen35_conv1d_decode_update( + mixed_qkv, + conv_state, + conv_weight, + activation="silu", + backend="cudac", + ) + q_rep, k_rep, v, a_kernel, b_kernel = qwen35_layout_decode(conv_out, a, b, backend="cudac") + out_ref, state_ref = 
manual_qwen35_layout_scalar_kda_reference( + conv_out, + a, + b, + A_log, + dt_bias, + recurrent_state, + state_indices, + ) + out_unfused, state_unfused = qwen35_scalar_kda_decode( + q=q_rep.unsqueeze(1), + k=k_rep.unsqueeze(1), + v=v.unsqueeze(1), + a=a_kernel, + b=b_kernel, + A_log=A_log, + dt_bias=dt_bias, + recurrent_state=recurrent_state, + state_indices=state_indices, + backend="cudac", + ) + if triton_fused_sigmoid_update is not None: + state_triton = recurrent_state.clone() + out_triton = triton_fused_sigmoid_update( + A_log=A_log, + a=a_kernel.unsqueeze(1).contiguous(), + dt_bias=dt_bias, + softplus_beta=1.0, + softplus_threshold=20.0, + q=q_rep.unsqueeze(1).contiguous(), + k=k_rep.unsqueeze(1).contiguous(), + v=v.unsqueeze(1).contiguous(), + b=b_kernel.unsqueeze(1).contiguous(), + initial_state_source=state_triton, + initial_state_indices=state_indices, + scale=DEFAULT_QWEN35_LINEAR_ATTN_CONFIG.head_k_dim**-0.5, + use_qk_l2norm_in_kernel=True, + cu_seqlens=None, + is_kda=False, + ) + out_fused, state_fused = qwen35_layout_scalar_kda_decode( + mixed_qkv_conv=conv_out, + a=a, + b=b, + A_log=A_log, + dt_bias=dt_bias, + recurrent_state=recurrent_state, + state_indices=state_indices, + config=DEFAULT_QWEN35_LINEAR_ATTN_CONFIG, + backend="cudac", + ) + + torch.cuda.synchronize() + assert torch.allclose(out_ref.float(), out_fused.float(), atol=3e-2, rtol=3e-2) + assert torch.allclose(state_ref, state_fused, atol=3e-5, rtol=3e-5) + assert torch.equal(out_unfused, out_fused) + assert torch.equal(state_unfused, state_fused) + if triton_fused_sigmoid_update is not None: + assert torch.allclose(out_triton.float(), out_fused.float(), atol=3e-2, rtol=3e-2) + assert torch.allclose(state_triton, state_fused, atol=3e-5, rtol=3e-5) + + +@pytest.mark.skipif(not _has_qwen35_fused_layout_kda_cudac(), reason="Qwen3.5 fused layout+KDA CUDA backend is not available") +@pytest.mark.parametrize("local_v_heads", [48, 24, 12, 6]) +def 
test_qwen35_layout_scalar_kda_cudac_supports_local_tp_shards(local_v_heads: int): + config = _local_config(local_v_heads) + mixed_qkv, a, b, _, _, recurrent_state, A_log, dt_bias, state_indices = make_inputs( + tokens=2, + pool_size=3, + device=torch.device("cuda"), + config=config, + ) + out_ref, state_ref = manual_qwen35_layout_scalar_kda_reference( + mixed_qkv, + a, + b, + A_log, + dt_bias, + recurrent_state, + state_indices, + config=config, + ) + q_rep_ref, k_rep_ref, v_ref, a_ref, b_ref = qwen35_layout_decode_reference(mixed_qkv, a, b, config=config) + q_rep, k_rep, v, a_kernel, b_kernel = qwen35_layout_decode(mixed_qkv, a, b, config=config, backend="cudac") + out, state = qwen35_layout_scalar_kda_decode( + mixed_qkv, + a, + b, + A_log, + dt_bias, + recurrent_state, + state_indices=state_indices, + config=config, + backend="cudac", + ) + out_3d_gate, state_3d_gate = qwen35_layout_scalar_kda_decode( + mixed_qkv, + a.unsqueeze(1), + b.unsqueeze(1), + A_log, + dt_bias, + recurrent_state, + state_indices=state_indices, + config=config, + backend="cudac", + ) + + torch.cuda.synchronize() + torch.testing.assert_close(q_rep, q_rep_ref) + torch.testing.assert_close(k_rep, k_rep_ref) + torch.testing.assert_close(v, v_ref) + torch.testing.assert_close(a_kernel, a_ref) + torch.testing.assert_close(b_kernel, b_ref) + torch.testing.assert_close(out.float(), out_ref.float(), atol=3e-2, rtol=3e-2) + torch.testing.assert_close(state, state_ref, atol=3e-5, rtol=3e-5) + torch.testing.assert_close(out_3d_gate, out) + torch.testing.assert_close(state_3d_gate, state) + + +@pytest.mark.skipif(not _has_qwen35_cudac(), reason="Qwen3.5 CUDA decode backend is not available") +def test_qwen35_decode_cudac_rejects_duplicate_state_indices(): + mixed_qkv, a, b, conv_weight, conv_state, recurrent_state, A_log, dt_bias, _ = make_inputs( + tokens=2, + pool_size=3, + device=torch.device("cuda"), + ) + state_indices = torch.zeros(2, device=mixed_qkv.device, dtype=torch.int32) + + with 
pytest.raises(ValueError, match="requires unique state_indices"): + qwen35_linear_attention_decode( + mixed_qkv, + a, + b, + conv_weight, + A_log, + dt_bias, + conv_state=conv_state, + recurrent_state=recurrent_state, + state_indices=state_indices, + backend="cudac", + ) diff --git a/tests/test_qwen35_prefill.py b/tests/test_qwen35_prefill.py new file mode 100644 index 00000000..445c9369 --- /dev/null +++ b/tests/test_qwen35_prefill.py @@ -0,0 +1,386 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pathlib +import sys + +import torch +import pytest + +sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent.parent)) + +from cula.ops.qwen35_conv1d_prefill import qwen35_conv1d_prefill +from cula.ops.qwen35_fused_kda_prefill import has_qwen35_fused_kda_prefill, qwen35_fused_kda_prefill +from cula.ops.qwen35_layout_prefill import qwen35_layout_prefill, qwen35_layout_prefill_reference +from cula.ops.qwen35_scalar_kda_prefill import qwen35_scalar_kda_prefill +from cula.qwen35.common import Qwen35LinearAttentionConfig +from cula.qwen35.runtime import qwen35_linear_attention_prefill + +try: + import cula.cudac as cula_cuda +except ImportError: + cula_cuda = None + + +def _manual_scalar_prefill(q, k, v, a, b, A_log, dt_bias, initial_state=None, cu_seqlens=None): + B, T, HV, K = q.shape + state_count = B if cu_seqlens is None else cu_seqlens.numel() - 1 + state = torch.zeros(state_count, HV, K, K, device=q.device, dtype=torch.float32) + if initial_state is not None: + state = initial_state.float().clone() + out = torch.empty_like(v) + q_f = torch.nn.functional.normalize(q.float(), dim=-1) * (K**-0.5) + k_f = torch.nn.functional.normalize(k.float(), dim=-1) + + def run_seq(batch_idx, state_idx, start, end): + for t in range(start, end): + for hv in range(HV): + state_kv = state[state_idx, hv] + decay = torch.exp(-torch.exp(A_log[hv].float()) * torch.nn.functional.softplus(a[batch_idx, t, hv].float() + dt_bias[hv].float())) + beta = torch.sigmoid(b[batch_idx, t, hv].float()) + k_vec = k_f[batch_idx, t, hv] + q_vec = q_f[batch_idx, t, hv] + proj = decay * (state_kv.transpose(0, 1) @ k_vec) + v_new = beta * (v[batch_idx, t, hv].float() - proj) + state_new = decay * state_kv + k_vec.unsqueeze(1) * v_new.unsqueeze(0) + out[batch_idx, t, hv] = (state_new.transpose(0, 1) @ q_vec).to(out.dtype) + state[state_idx, hv] = state_new + + if cu_seqlens is None: + for batch_idx in range(B): + run_seq(batch_idx, batch_idx, 0, T) + else: + for state_idx in 
range(state_count): + run_seq(0, state_idx, int(cu_seqlens[state_idx].item()), int(cu_seqlens[state_idx + 1].item())) + return out, state + + +def _local_config(local_v_heads: int) -> Qwen35LinearAttentionConfig: + return Qwen35LinearAttentionConfig(num_k_heads=local_v_heads // 3, num_v_heads=local_v_heads) + + +def test_qwen35_scalar_kda_prefill_reference_matches_manual(): + torch.manual_seed(0) + B, T, HV, K = 2, 3, 2, 128 + q = torch.randn(B, T, HV, K, dtype=torch.bfloat16) + k = torch.randn_like(q) + v = torch.randn_like(q) + a = torch.randn(B, T, HV, dtype=torch.bfloat16) + b = torch.randn(B, T, HV, dtype=torch.bfloat16) + A_log = -torch.rand(HV, dtype=torch.float32) + dt_bias = torch.randn(HV, dtype=torch.float32) * 0.1 + initial_state = torch.randn(B, HV, K, K, dtype=torch.float32) * 0.01 + + out_ref, state_ref = _manual_scalar_prefill(q, k, v, a, b, A_log, dt_bias, initial_state) + out, state = qwen35_scalar_kda_prefill( + q, + k, + v, + a, + b, + A_log, + dt_bias, + initial_state=initial_state, + backend="reference", + ) + + torch.testing.assert_close(out.float(), out_ref.float(), atol=1e-3, rtol=1e-3) + torch.testing.assert_close(state, state_ref, atol=1e-4, rtol=1e-4) + + +def test_qwen35_scalar_kda_prefill_varlen_reference_matches_manual(): + torch.manual_seed(1) + T, HV, K = 4, 2, 128 + q = torch.randn(1, T, HV, K, dtype=torch.bfloat16) + k = torch.randn_like(q) + v = torch.randn_like(q) + a = torch.randn(1, T, HV, dtype=torch.bfloat16) + b = torch.randn(1, T, HV, dtype=torch.bfloat16) + A_log = -torch.rand(HV, dtype=torch.float32) + dt_bias = torch.randn(HV, dtype=torch.float32) * 0.1 + cu_seqlens = torch.tensor([0, 2, 4], dtype=torch.int32) + + out_ref, state_ref = _manual_scalar_prefill(q, k, v, a, b, A_log, dt_bias, cu_seqlens=cu_seqlens) + out, state = qwen35_scalar_kda_prefill( + q, + k, + v, + a, + b, + A_log, + dt_bias, + cu_seqlens=cu_seqlens, + backend="reference", + ) + + torch.testing.assert_close(out.float(), out_ref.float(), atol=1e-3, 
rtol=1e-3) + torch.testing.assert_close(state, state_ref, atol=1e-4, rtol=1e-4) + + +def test_qwen35_scalar_kda_prefill_cuda_matches_reference(): + if not torch.cuda.is_available() or cula_cuda is None or not hasattr(cula_cuda, "qwen35_scalar_kda_prefill"): + import pytest + + pytest.skip("qwen35_scalar_kda_prefill CUDA extension is not available") + + torch.manual_seed(10) + device = torch.device("cuda") + B, T, HV, K = 1, 8, 48, 128 + q = torch.randn(B, T, HV, K, device=device, dtype=torch.bfloat16) + k = torch.randn_like(q) + v = torch.randn_like(q) + a = torch.randn(B, T, HV, device=device, dtype=torch.bfloat16) + b = torch.randn(B, T, HV, device=device, dtype=torch.bfloat16) + A_log = -torch.rand(HV, device=device, dtype=torch.float32) + dt_bias = torch.randn(HV, device=device, dtype=torch.float32) * 0.1 + initial_state = torch.randn(B, HV, K, K, device=device, dtype=torch.float32) * 0.01 + + out_ref, state_ref = qwen35_scalar_kda_prefill( + q, + k, + v, + a, + b, + A_log, + dt_bias, + initial_state=initial_state, + backend="reference", + ) + out, state = qwen35_scalar_kda_prefill( + q, + k, + v, + a, + b, + A_log, + dt_bias, + initial_state=initial_state, + backend="cudac", + ) + + torch.cuda.synchronize() + torch.testing.assert_close(out.float(), out_ref.float(), atol=2e-2, rtol=2e-2) + torch.testing.assert_close(state, state_ref, atol=2e-2, rtol=2e-2) + + +@pytest.mark.parametrize("local_v_heads", [48, 24, 12, 6]) +def test_qwen35_layout_prefill_cuda_supports_local_tp_shards(local_v_heads: int): + if not torch.cuda.is_available() or cula_cuda is None or not hasattr(cula_cuda, "qwen35_layout_prefill"): + pytest.skip("qwen35_layout_prefill CUDA extension is not available") + + torch.manual_seed(20 + local_v_heads) + device = torch.device("cuda") + config = _local_config(local_v_heads) + tokens = 5 + mixed_qkv = torch.randn(tokens, config.conv_dim, device=device, dtype=torch.bfloat16) + a = torch.randn(tokens, config.num_v_heads, device=device, 
dtype=torch.bfloat16) + b = torch.randn(tokens, config.num_v_heads, device=device, dtype=torch.bfloat16) + + ref = qwen35_layout_prefill_reference(mixed_qkv, a, b, config=config) + out = qwen35_layout_prefill(mixed_qkv, a, b, config=config, backend="cudac") + + torch.cuda.synchronize() + for out_tensor, ref_tensor in zip(out, ref, strict=True): + torch.testing.assert_close(out_tensor, ref_tensor) + + +@pytest.mark.parametrize("local_v_heads", [48, 24, 12, 6]) +def test_qwen35_scalar_kda_prefill_cuda_supports_local_tp_shards(local_v_heads: int): + if not torch.cuda.is_available() or cula_cuda is None or not hasattr(cula_cuda, "qwen35_scalar_kda_prefill"): + pytest.skip("qwen35_scalar_kda_prefill CUDA extension is not available") + + torch.manual_seed(30 + local_v_heads) + device = torch.device("cuda") + B, T, HV, K = 1, 4, local_v_heads, 128 + q = torch.randn(B, T, HV, K, device=device, dtype=torch.bfloat16) + k = torch.randn_like(q) + v = torch.randn_like(q) + a = torch.randn(B, T, HV, device=device, dtype=torch.bfloat16) + b = torch.randn(B, T, HV, device=device, dtype=torch.bfloat16) + A_log = -torch.rand(HV, device=device, dtype=torch.float32) + dt_bias = torch.randn(HV, device=device, dtype=torch.float32) * 0.1 + initial_state = torch.randn(B, HV, K, K, device=device, dtype=torch.float32) * 0.01 + + out_ref, state_ref = qwen35_scalar_kda_prefill( + q, + k, + v, + a, + b, + A_log, + dt_bias, + initial_state=initial_state, + backend="reference", + ) + out, state = qwen35_scalar_kda_prefill( + q, + k, + v, + a, + b, + A_log, + dt_bias, + initial_state=initial_state, + backend="cudac", + ) + + torch.cuda.synchronize() + torch.testing.assert_close(out.float(), out_ref.float(), atol=2e-2, rtol=2e-2) + torch.testing.assert_close(state, state_ref, atol=2e-2, rtol=2e-2) + + +def test_qwen35_chunk_qk_prefill_sm90_matches_torch(): + if not torch.cuda.is_available() or cula_cuda is None or not hasattr(cula_cuda, "qwen35_chunk_qk_prefill_sm90"): + import pytest + + 
pytest.skip("qwen35_chunk_qk_prefill_sm90 CUDA extension is not available") + + torch.manual_seed(11) + device = torch.device("cuda") + B, T, HV, K = 1, 64, 48, 128 + q = torch.randn(B, T, HV, K, device=device, dtype=torch.bfloat16) + k = torch.randn_like(q) + out = torch.empty(B, HV, T, T, device=device, dtype=torch.float32) + + cula_cuda.qwen35_chunk_qk_prefill_sm90(q.contiguous(), k.contiguous(), out) + torch.cuda.synchronize() + + ref = torch.einsum("bthd,bshd->bhts", q.float(), k.float()) + torch.testing.assert_close(out, ref, atol=2e-1, rtol=2e-2) + + +@pytest.mark.parametrize("local_v_heads", [48, 24, 12, 6]) +def test_qwen35_chunk_qk_prefill_sm90_supports_local_tp_shards(local_v_heads: int): + if not torch.cuda.is_available() or cula_cuda is None or not hasattr(cula_cuda, "qwen35_chunk_qk_prefill_sm90"): + pytest.skip("qwen35_chunk_qk_prefill_sm90 CUDA extension is not available") + + torch.manual_seed(40 + local_v_heads) + device = torch.device("cuda") + B, T, HV, K = 1, 32, local_v_heads, 128 + q = torch.randn(B, T, HV, K, device=device, dtype=torch.bfloat16) + k = torch.randn_like(q) + out = torch.empty(B, HV, T, T, device=device, dtype=torch.float32) + + cula_cuda.qwen35_chunk_qk_prefill_sm90(q.contiguous(), k.contiguous(), out) + torch.cuda.synchronize() + + ref = torch.einsum("bthd,bshd->bhts", q.float(), k.float()) + torch.testing.assert_close(out, ref, atol=2e-1, rtol=2e-2) + + +def test_qwen35_fused_kda_prefill_matches_reference(): + if not torch.cuda.is_available(): + import pytest + + pytest.skip("CUDA is not available") + if not has_qwen35_fused_kda_prefill(torch.device("cuda")): + import pytest + + pytest.skip("Qwen3.5 fused KDA prefill backend is not available") + + torch.manual_seed(12) + device = torch.device("cuda") + B, T, HV, K = 1, 64, 48, 128 + q = torch.randn(B, T, HV, K, device=device, dtype=torch.bfloat16) + k = torch.randn_like(q) + v = torch.randn_like(q) + a = torch.randn(B, T, HV, device=device, dtype=torch.bfloat16) + b = 
torch.randn(B, T, HV, device=device, dtype=torch.bfloat16) + A_log = -torch.rand(HV, device=device, dtype=torch.float32) + dt_bias = torch.randn(HV, device=device, dtype=torch.float32) * 0.1 + initial_state = torch.randn(B, HV, K, K, device=device, dtype=torch.float32) * 0.01 + + out_ref, state_ref = qwen35_scalar_kda_prefill( + q, + k, + v, + a, + b, + A_log, + dt_bias, + initial_state=initial_state, + backend="reference", + ) + out, state = qwen35_fused_kda_prefill( + q, + k, + v, + a, + b, + A_log, + dt_bias, + initial_state=initial_state, + ) + + torch.cuda.synchronize() + torch.testing.assert_close(out.float(), out_ref.float(), atol=3e-2, rtol=3e-2) + torch.testing.assert_close(state, state_ref, atol=3e-2, rtol=3e-2) + + +def test_qwen35_conv1d_prefill_flattened_state(): + x = torch.arange(5 * 3, dtype=torch.bfloat16).reshape(5, 3) + weight = torch.ones(3, 4, dtype=torch.bfloat16) + cu_seqlens = torch.tensor([0, 2, 5], dtype=torch.int32) + + y, state = qwen35_conv1d_prefill(x, weight, cu_seqlens=cu_seqlens, output_final_state=True) + + assert y.shape == x.shape + assert state.shape == (2, 3, 4) + torch.testing.assert_close(state[0, :, -2:], x[:2].transpose(0, 1)) + torch.testing.assert_close(state[1, :, -3:], x[2:5].transpose(0, 1)) + + +def test_qwen35_layout_prefill_reference(): + torch.manual_seed(2) + config = Qwen35LinearAttentionConfig(num_k_heads=1, num_v_heads=2) + tokens = 3 + mixed_qkv = torch.randn(tokens, config.conv_dim, dtype=torch.bfloat16) + a = torch.randn(tokens, config.num_v_heads, dtype=torch.bfloat16) + b = torch.randn(tokens, config.num_v_heads, dtype=torch.bfloat16) + + ref = qwen35_layout_prefill_reference(mixed_qkv, a, b, config=config) + out = qwen35_layout_prefill(mixed_qkv, a, b, config=config, backend="reference") + + for out_tensor, ref_tensor in zip(out, ref, strict=True): + assert torch.equal(out_tensor, ref_tensor) + + +def test_qwen35_linear_attention_prefill_reference_shapes(): + torch.manual_seed(2) + config = 
Qwen35LinearAttentionConfig(num_k_heads=1, num_v_heads=2) + tokens = 3 + mixed_qkv = torch.randn(tokens, config.conv_dim, dtype=torch.bfloat16) + a = torch.randn(tokens, config.num_v_heads, dtype=torch.bfloat16) + b = torch.randn(tokens, config.num_v_heads, dtype=torch.bfloat16) + conv_weight = torch.randn(config.conv_dim, config.conv_kernel_size, dtype=torch.bfloat16) + A_log = -torch.rand(config.num_v_heads, dtype=torch.float32) + dt_bias = torch.randn(config.num_v_heads, dtype=torch.float32) * 0.1 + cu_seqlens = torch.tensor([0, 2, 3], dtype=torch.int32) + + out, conv_state, recurrent_state = qwen35_linear_attention_prefill( + mixed_qkv, + a, + b, + conv_weight, + A_log, + dt_bias, + config=config, + cu_seqlens=cu_seqlens, + backend="reference", + ) + + assert out.shape == (tokens, config.value_dim) + assert conv_state.shape == (2, config.conv_dim, config.conv_kernel_size) + assert recurrent_state.shape == (2, config.num_v_heads, config.head_k_dim, config.head_v_dim)