Skip to content

Conversation

@Jack-Khuu
Copy link
Contributor

While experimenting with gpt5.2, I encountered an interesting case where the generated kernel for gpumode/trimul consistently targets bf16 and fails to generalize to fp32 inputs even though the reference input is provided as fp32.

  • I suspect this is a result of the existing guidance, whose intent is to promote speedup by leveraging bf16 whenever a fp32 input is provided
    - AVOID FP32 for inputs and outputs - use BF16 instead when users specify FP32 problems
  • Curious enough, this seems to only happen with GPT5.2. Both gpt5 and claude-opus generated kernels fully ignore the bf16 guidance and just operate in fp32

This PR updates the test generation instructions, to prompt for testing both input dtypes bf16 and fp32, when the reference input is fp32

  • The kernel generated with the updated template is able to pass leaderboard test for GPUMode submission, while the non-updated template consistently fails

This PR also adds a convenience Relay for GPT5.2 to allow generation via a local provider


Example without Change

Test
import math
import sys
import time
import traceback
import inspect
from typing import Dict, Tuple

import torch
from torch import nn


# Summary: Tests the "outgoing" Triangle Multiplicative Update (TriMul) forward op on
# 4D tensors [B, N, N, C] with optional mask [B, N, N], matching the provided PyTorch reference:
# LayerNorm -> gated projections -> masked -> einsum over k -> LayerNorm -> output gate -> Linear.


class TriMulRef(nn.Module):
    # Based on the reference in the prompt
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

        self.left_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.right_proj = nn.Linear(dim, hidden_dim, bias=False)

        self.left_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.right_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.out_gate = nn.Linear(dim, hidden_dim, bias=False)

        self.to_out_norm = nn.LayerNorm(hidden_dim)
        self.to_out = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        x = self.norm(x)

        left = self.left_proj(x)
        right = self.right_proj(x)

        mask_f = mask.unsqueeze(-1)
        left = left * mask_f
        right = right * mask_f

        left_gate = self.left_gate(x).sigmoid()
        right_gate = self.right_gate(x).sigmoid()
        out_gate = self.out_gate(x).sigmoid()

        left = left * left_gate
        right = right * right_gate

        out = torch.einsum("... i k d, ... j k d -> ... i j d", left, right)

        out = self.to_out_norm(out)
        out = out * out_gate
        return self.to_out(out)


def _set_tf32(enabled: bool) -> Tuple[bool, bool]:
    prev_mm = torch.backends.cuda.matmul.allow_tf32
    prev_cudnn = torch.backends.cudnn.allow_tf32
    torch.backends.cuda.matmul.allow_tf32 = enabled
    torch.backends.cudnn.allow_tf32 = enabled
    return prev_mm, prev_cudnn


def _restore_tf32(prev: Tuple[bool, bool]) -> None:
    torch.backends.cuda.matmul.allow_tf32 = prev[0]
    torch.backends.cudnn.allow_tf32 = prev[1]


def _make_case(
    *,
    seqlen: int,
    bs: int,
    dim: int,
    hiddendim: int,
    seed: int,
    nomask: bool,
    distribution: str,
    device: str,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], Dict]:
    gen = torch.Generator(device=device)
    gen.manual_seed(seed)

    shape = (bs, seqlen, seqlen, dim)

    if distribution == "cauchy":
        x_f32 = torch.empty(shape, device=device, dtype=torch.float32).cauchy_(
            median=0.0, sigma=2.0, generator=gen
        )
    else:
        x_f32 = torch.randn(shape, device=device, dtype=torch.float32, generator=gen)

    x = x_f32.to(dtype=dtype).contiguous()

    if nomask:
        mask = torch.ones((bs, seqlen, seqlen), device=device, dtype=dtype)
    else:
        mask = torch.randint(
            0, 2, (bs, seqlen, seqlen), device=device, generator=gen, dtype=torch.int32
        ).to(dtype=dtype)

    # Weights (bf16) - same key names as prompt
    weights: Dict[str, torch.Tensor] = {}
    weights["norm.weight"] = torch.randn(dim, device=device, dtype=torch.float32, generator=gen).to(dtype)
    weights["norm.bias"] = torch.randn(dim, device=device, dtype=torch.float32, generator=gen).to(dtype)

    # scale like prompt
    weights["left_proj.weight"] = (
        torch.randn(hiddendim, dim, device=device, dtype=torch.float32, generator=gen) / math.sqrt(hiddendim)
    ).to(dtype)
    weights["right_proj.weight"] = (
        torch.randn(hiddendim, dim, device=device, dtype=torch.float32, generator=gen) / math.sqrt(hiddendim)
    ).to(dtype)
    weights["left_gate.weight"] = (
        torch.randn(hiddendim, dim, device=device, dtype=torch.float32, generator=gen) / math.sqrt(hiddendim)
    ).to(dtype)
    weights["right_gate.weight"] = (
        torch.randn(hiddendim, dim, device=device, dtype=torch.float32, generator=gen) / math.sqrt(hiddendim)
    ).to(dtype)
    weights["out_gate.weight"] = (
        torch.randn(hiddendim, dim, device=device, dtype=torch.float32, generator=gen) / math.sqrt(hiddendim)
    ).to(dtype)

    weights["to_out_norm.weight"] = torch.randn(
        hiddendim, device=device, dtype=torch.float32, generator=gen
    ).to(dtype)
    weights["to_out_norm.bias"] = torch.randn(
        hiddendim, device=device, dtype=torch.float32, generator=gen
    ).to(dtype)

    weights["to_out.weight"] = (
        torch.randn(dim, hiddendim, device=device, dtype=torch.float32, generator=gen) / math.sqrt(dim)
    ).to(dtype)

    config = {"hidden_dim": hiddendim, "dim": dim}
    return x, mask.contiguous(), weights, config


def _load_weights_into_ref(m: TriMulRef, weights: Dict[str, torch.Tensor]) -> None:
    # copy_ to avoid creating new Parameters (and keep dtype/device consistent)
    with torch.no_grad():
        m.norm.weight.copy_(weights["norm.weight"])
        m.norm.bias.copy_(weights["norm.bias"])
        m.left_proj.weight.copy_(weights["left_proj.weight"])
        m.right_proj.weight.copy_(weights["right_proj.weight"])
        m.left_gate.weight.copy_(weights["left_gate.weight"])
        m.right_gate.weight.copy_(weights["right_gate.weight"])
        m.out_gate.weight.copy_(weights["out_gate.weight"])
        m.to_out_norm.weight.copy_(weights["to_out_norm.weight"])
        m.to_out_norm.bias.copy_(weights["to_out_norm.bias"])
        m.to_out.weight.copy_(weights["to_out.weight"])


def _call_kernel(kernel_function, x, mask, weights, config):
    # Requirement: call like a normal Python function (no Triton launch syntax).
    # Kernel APIs vary; try common call patterns.
    try:
        sig = None
        try:
            sig = inspect.signature(kernel_function)
        except Exception:
            sig = None

        if sig is not None:
            nparams = len(sig.parameters)
            if nparams == 4:
                return kernel_function(x, mask, weights, config)
            if nparams == 1:
                return kernel_function((x, mask, weights, config))

        # Fallbacks
        try:
            return kernel_function(x, mask, weights, config)
        except TypeError:
            return kernel_function((x, mask, weights, config))
    except Exception:
        raise


def _as_tensor(out):
    if isinstance(out, torch.Tensor):
        return out
    if isinstance(out, (tuple, list)) and len(out) > 0 and isinstance(out[0], torch.Tensor):
        return out[0]
    return None


def test_kernel() -> bool:
    try:
        from kernel import kernel_function
        if not callable(kernel_function):
            print("kernel_function is not callable")
            return False

        if not torch.cuda.is_available():
            raise RuntimeError("CUDA not available")

        device = "cuda"
        dtype = torch.bfloat16  # requirement: avoid FP32; use BF16 instead

        # Use ONLY shapes from the prompt. We keep this suite feasible by excluding the extreme
        # N=768/1024 cases that are typically too large for a correctness unit test.
        cases = [
            # normal, small, nomask/mask
            dict(seqlen=32, bs=1, dim=128, hiddendim=128, seed=9371, nomask=True, distribution="normal"),
            dict(seqlen=32, bs=1, dim=128, hiddendim=128, seed=1092, nomask=False, distribution="normal"),
            # normal, moderate, bs=2
            dict(seqlen=64, bs=2, dim=256, hiddendim=128, seed=210284, nomask=False, distribution="normal"),
            # normal, larger dim
            dict(seqlen=128, bs=1, dim=768, hiddendim=128, seed=81934, nomask=True, distribution="normal"),
            # normal, stress (largest feasible from prompt list)
            dict(seqlen=256, bs=1, dim=128, hiddendim=128, seed=10432, nomask=False, distribution="normal"),
            # cauchy, heavy-tailed
            dict(seqlen=32, bs=1, dim=128, hiddendim=128, seed=937321, nomask=True, distribution="cauchy"),
        ]

        # BF16 + large reductions (einsum over k up to 256) can accumulate noticeable error.
        # Loosen tolerance accordingly.
        rtol, atol = 2e-2, 2e-2

        prev_tf32 = _set_tf32(False)
        try:
            for idx, spec in enumerate(cases):
                print(f"\n[Case {idx+1}/{len(cases)}] {spec}")
                torch.cuda.synchronize()
                t0 = time.time()

                x, mask, weights, config = _make_case(device=device, dtype=dtype, **spec)

                # Reference
                ref = TriMulRef(dim=config["dim"], hidden_dim=config["hidden_dim"]).to(device=device, dtype=dtype)
                _load_weights_into_ref(ref, weights)

                torch.cuda.synchronize()
                t_ref0 = time.time()
                y_ref = ref(x, mask)
                torch.cuda.synchronize()
                t_ref1 = time.time()

                # Kernel
                torch.cuda.synchronize()
                t_k0 = time.time()
                y_out_raw = _call_kernel(kernel_function, x, mask, weights, config)
                torch.cuda.synchronize()
                t_k1 = time.time()

                y_out = _as_tensor(y_out_raw)
                if y_out is None:
                    print(f"kernel_function returned unsupported type: {type(y_out_raw)}")
                    print(f"Value repr: {repr(y_out_raw)[:500]}")
                    return False

                # Basic checks
                if y_out.shape != y_ref.shape:
                    print("SHAPE MISMATCH")
                    print(f"Expected shape: {tuple(y_ref.shape)}")
                    print(f"Got shape:      {tuple(y_out.shape)}")
                    return False

                if y_out.device != x.device:
                    print("DEVICE MISMATCH")
                    print(f"Input device: {x.device}, Output device: {y_out.device}")
                    return False

                # Numerical check
                try:
                    y_out_f = y_out.float()
                    y_ref_f = y_ref.float()

                    ok = torch.allclose(y_out_f, y_ref_f, rtol=rtol, atol=atol, equal_nan=True)
                    if not ok:
                        diff = (y_out_f - y_ref_f).abs()
                        max_abs = diff.max().item()
                        mean_abs = diff.mean().item()

                        denom = y_ref_f.abs().clamp_min(1e-8)
                        max_rel = (diff / denom).max().item()

                        print("NUMERICAL MISMATCH")
                        print(f"Input shape: {tuple(x.shape)}, dtype: {x.dtype}")
                        print(f"Mask shape:  {tuple(mask.shape)}, dtype: {mask.dtype}, nomask={spec['nomask']}")
                        print(f"Output dtype: {y_out.dtype}, Ref dtype: {y_ref.dtype}")
                        print(f"rtol={rtol}, atol={atol}, equal_nan=True")
                        print(f"max_abs_diff={max_abs:.6g}, mean_abs_diff={mean_abs:.6g}, max_rel_diff={max_rel:.6g}")

                        # Print a small sample
                        flat_ref = y_ref_f.flatten()
                        flat_out = y_out_f.flatten()
                        flat_diff = diff.flatten()
                        k = min(10, flat_diff.numel())
                        topk = torch.topk(flat_diff, k=k)
                        print("\nTop diffs (index, ref, out, absdiff):")
                        for i in range(k):
                            fi = topk.indices[i].item()
                            print(
                                f"  {fi}: ref={flat_ref[fi].item():.6g}, out={flat_out[fi].item():.6g}, "
                                f"absdiff={topk.values[i].item():.6g}"
                            )

                        print("\nSmall prefix sample:")
                        n = min(10, flat_diff.numel())
                        print("  ref[:n] =", flat_ref[:n].tolist())
                        print("  out[:n] =", flat_out[:n].tolist())
                        return False

                except AssertionError as ae:
                    print(f"AssertionError during comparison: {ae}")
                    return False

                t1 = time.time()
                print(
                    f"Timings: ref={t_ref1 - t_ref0:.4f}s, kernel={t_k1 - t_k0:.4f}s, total_case={t1 - t0:.4f}s"
                )

                # cleanup to reduce peak memory across cases
                del x, mask, weights, config, ref, y_ref, y_out, y_out_raw
                torch.cuda.empty_cache()

        finally:
            _restore_tf32(prev_tf32)

        print("\nAll selected TriMul test cases passed.")
        return True

    except Exception as e:
        if isinstance(e, NameError):
            print(f"Test failed: NameError (likely undefined helper in kernel.py): {e}")
        else:
            print(f"Test failed: {e}")
        print("Traceback:\n" + traceback.format_exc())
        return False


if __name__ == "__main__":
    ok = test_kernel()
    sys.exit(0 if ok else 1)
Kernel
from __future__ import annotations
from typing import Any, Dict, Tuple, Union

import torch
import triton
import triton.language as tl


# -----------------------------
# Math helpers (Triton)
# -----------------------------
@triton.jit
def _sigmoid_fp32(x):
    # Stable sigmoid with a single exp; x is fp32.
    ax = tl.abs(x)
    e = tl.math.exp(-ax)
    denom = 1.0 + e
    return tl.where(x >= 0.0, 1.0 / denom, e / denom)


# -----------------------------
# Kernel 1: LayerNorm over last dim (per row of [M, C] contiguous view)
# Welford-style combine; fixed masking for partial blocks.
# -----------------------------
@triton.jit
def _layernorm_lastdim_kernel(
    x_ptr,
    y_ptr,
    gamma_ptr,
    beta_ptr,
    M,  # runtime
    C: tl.constexpr,  # compile-time row width
    eps: tl.constexpr,
    BLOCK_C: tl.constexpr,
):
    pid = tl.program_id(0)
    row = pid
    row_off = row * C

    offs = tl.arange(0, BLOCK_C)

    mean = tl.zeros((), dtype=tl.float32)
    m2 = tl.zeros((), dtype=tl.float32)
    n = tl.zeros((), dtype=tl.float32)

    # First pass: mean/var
    for c0 in tl.static_range(0, C, BLOCK_C):
        cols = c0 + offs
        m = cols < C

        x = tl.load(x_ptr + row_off + cols, mask=m, other=0.0).to(tl.float32)

        cnt = tl.sum(m.to(tl.float32), axis=0)
        block_sum = tl.sum(x, axis=0)
        block_mean = block_sum / cnt

        diff = tl.where(m, x - block_mean, 0.0)
        block_m2 = tl.sum(diff * diff, axis=0)

        delta = block_mean - mean
        new_n = n + cnt
        mean = mean + delta * (cnt / new_n)
        m2 = m2 + block_m2 + delta * delta * (n * cnt / new_n)
        n = new_n

    var = m2 / n
    rstd = tl.rsqrt(var + eps)

    # Second pass: normalize + affine, store BF16
    for c0 in tl.static_range(0, C, BLOCK_C):
        cols = c0 + offs
        m = cols < C
        x = tl.load(x_ptr + row_off + cols, mask=m, other=0.0).to(tl.float32)
        g = tl.load(gamma_ptr + cols, mask=m, other=0.0).to(tl.float32)
        b = tl.load(beta_ptr + cols, mask=m, other=0.0).to(tl.float32)
        y = (x - mean) * rstd
        y = y * g + b
        tl.store(y_ptr + row_off + cols, y.to(tl.bfloat16), mask=m)


# -----------------------------
# Kernel 2 (fused): 5 projections + (mask * gate) application to left/right.
# Inputs:
#   x_norm: [M, DIM] BF16 contiguous
#   W_*:   [H, DIM]  BF16 contiguous (row-major)
#   mask:  [M] BF16 (0/1)
# Outputs:
#   left:    [M, H] BF16 = (xW_left^T * mask) * sigmoid(bf16(xW_left_gate^T))
#   right:   [M, H] BF16 = (xW_right^T * mask) * sigmoid(bf16(xW_right_gate^T))
#   out_gate:[M, H] BF16 = sigmoid(bf16(xW_out_gate^T))
# -----------------------------
@triton.jit
def _fused_proj_gates_mask_kernel(
    x_ptr,
    mask_ptr,
    w_left_ptr,
    w_right_ptr,
    w_lg_ptr,
    w_rg_ptr,
    w_og_ptr,
    left_ptr,
    right_ptr,
    out_gate_ptr,
    M,
    stride_xm,
    stride_xk,
    stride_wn,
    stride_wk,
    stride_om,
    stride_on,
    DIM: tl.constexpr,
    H: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    acc_l = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    acc_r = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    acc_lg = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    acc_rg = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    acc_og = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k0 in tl.static_range(0, DIM, BLOCK_K):
        k = k0 + offs_k

        a_ptrs = x_ptr + offs_m[:, None] * stride_xm + k[None, :] * stride_xk
        a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (k[None, :] < DIM), other=0.0).to(tl.bfloat16)

        # For output = A @ W^T, we need b[k,n] = W[n,k].
        b_ptrs_l = w_left_ptr + offs_n[None, :] * stride_wn + k[:, None] * stride_wk
        b_ptrs_r = w_right_ptr + offs_n[None, :] * stride_wn + k[:, None] * stride_wk
        b_ptrs_lg = w_lg_ptr + offs_n[None, :] * stride_wn + k[:, None] * stride_wk
        b_ptrs_rg = w_rg_ptr + offs_n[None, :] * stride_wn + k[:, None] * stride_wk
        b_ptrs_og = w_og_ptr + offs_n[None, :] * stride_wn + k[:, None] * stride_wk

        b_l = tl.load(b_ptrs_l, mask=(offs_n[None, :] < H) & (k[:, None] < DIM), other=0.0).to(tl.bfloat16)
        b_r = tl.load(b_ptrs_r, mask=(offs_n[None, :] < H) & (k[:, None] < DIM), other=0.0).to(tl.bfloat16)
        b_lg = tl.load(b_ptrs_lg, mask=(offs_n[None, :] < H) & (k[:, None] < DIM), other=0.0).to(tl.bfloat16)
        b_rg = tl.load(b_ptrs_rg, mask=(offs_n[None, :] < H) & (k[:, None] < DIM), other=0.0).to(tl.bfloat16)
        b_og = tl.load(b_ptrs_og, mask=(offs_n[None, :] < H) & (k[:, None] < DIM), other=0.0).to(tl.bfloat16)

        acc_l = tl.dot(a, b_l, acc_l)
        acc_r = tl.dot(a, b_r, acc_r)
        acc_lg = tl.dot(a, b_lg, acc_lg)
        acc_rg = tl.dot(a, b_rg, acc_rg)
        acc_og = tl.dot(a, b_og, acc_og)

    # Match PyTorch gate path: sigmoid consumes BF16-rounded linear output
    lg = _sigmoid_fp32(acc_lg.to(tl.bfloat16).to(tl.float32)).to(tl.bfloat16)
    rg = _sigmoid_fp32(acc_rg.to(tl.bfloat16).to(tl.float32)).to(tl.bfloat16)
    og = _sigmoid_fp32(acc_og.to(tl.bfloat16).to(tl.float32)).to(tl.bfloat16)

    l = acc_l.to(tl.bfloat16)
    r = acc_r.to(tl.bfloat16)

    mask_row = tl.load(mask_ptr + offs_m, mask=offs_m < M, other=0.0).to(tl.bfloat16)
    mask_row = mask_row[:, None]

    # Match PyTorch op boundaries (BF16 rounding after each elementwise op)
    l = (l * mask_row).to(tl.bfloat16)
    r = (r * mask_row).to(tl.bfloat16)
    l = (l * lg).to(tl.bfloat16)
    r = (r * rg).to(tl.bfloat16)

    o_ptrs = left_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    tl.store(o_ptrs, l, mask=(offs_m[:, None] < M) & (offs_n[None, :] < H))
    o_ptrs = right_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    tl.store(o_ptrs, r, mask=(offs_m[:, None] < M) & (offs_n[None, :] < H))
    o_ptrs = out_gate_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    tl.store(o_ptrs, og, mask=(offs_m[:, None] < M) & (offs_n[None, :] < H))


# -----------------------------
# Kernel 3: Einsum via batched matmul over hidden dim (parallelized over H).
# out[b,i,j,d] = sum_k left[b,i,k,d] * right[b,j,k,d]
# -----------------------------
@triton.jit
def _einsum_batched_mm_kernel(
    left_ptr,
    right_ptr,
    out_ptr,
    N: tl.constexpr,
    H: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    pid_bh = tl.program_id(2)  # batch over (b, d)

    b = pid_bh // H
    d = pid_bh - b * H

    stride_b = N * N * H
    stride_i = N * H
    stride_j = H
    stride_k = H

    base_l = left_ptr + b * stride_b + d
    base_r = right_ptr + b * stride_b + d
    base_o = out_ptr + b * stride_b + d

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)  # i
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)  # j
    offs_k = tl.arange(0, BLOCK_K)                    # k tile

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k0 in tl.static_range(0, N, BLOCK_K):
        k = k0 + offs_k
        a_ptrs = base_l + offs_m[:, None] * stride_i + k[None, :] * stride_k
        b_ptrs = base_r + offs_n[None, :] * stride_i + k[:, None] * stride_k  # right[j,k] as [k,j]
        a = tl.load(a_ptrs, mask=(offs_m[:, None] < N) & (k[None, :] < N), other=0.0).to(tl.bfloat16)
        bmat = tl.load(b_ptrs, mask=(offs_n[None, :] < N) & (k[:, None] < N), other=0.0).to(tl.bfloat16)
        acc = tl.dot(a, bmat, acc)

    o_ptrs = base_o + offs_m[:, None] * stride_i + offs_n[None, :] * stride_j
    tl.store(o_ptrs, acc.to(tl.bfloat16), mask=(offs_m[:, None] < N) & (offs_n[None, :] < N))


# -----------------------------
# Kernel 4 (fused): LayerNorm(hidden_dim) over out + multiply out_gate + final linear to dim
# Mimics PyTorch dtype boundaries:
#   ln_out_bf16 = LayerNorm(out_in)  (output BF16)
#   gated_bf16  = ln_out_bf16 * out_gate_bf16
#   y = Linear(gated_bf16) (BF16 input, FP32 accum, BF16 output)
# -----------------------------
@triton.jit
def _fused_out_ln_gate_linear_kernel(
    out_in_ptr,
    out_gate_ptr,
    ln_gamma_ptr,
    ln_beta_ptr,
    w_ptr,   # [DIM, H] row-major
    y_ptr,   # [M, DIM]
    M,
    DIM,
    H: tl.constexpr,
    eps: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, H)

    x_ptrs = out_in_ptr + offs_m[:, None] * H + offs_k[None, :]
    g_ptrs = out_gate_ptr + offs_m[:, None] * H + offs_k[None, :]

    x = tl.load(x_ptrs, mask=(offs_m[:, None] < M), other=0.0).to(tl.float32)
    gate = tl.load(g_ptrs, mask=(offs_m[:, None] < M), other=0.0).to(tl.bfloat16)

    mean = tl.sum(x, axis=1) / H
    xc = x - mean[:, None]
    var = tl.sum(xc * xc, axis=1) / H
    rstd = tl.rsqrt(var + eps)

    gamma = tl.load(ln_gamma_ptr + offs_k, mask=offs_k < H, other=0.0).to(tl.float32)
    beta = tl.load(ln_beta_ptr + offs_k, mask=offs_k < H, other=0.0).to(tl.float32)

    ln_out = xc * rstd[:, None]
    ln_out = ln_out * gamma[None, :] + beta[None, :]

    ln_out_bf16 = ln_out.to(tl.bfloat16)
    gated_bf16 = (ln_out_bf16 * gate).to(tl.bfloat16)

    # W^T tile: B[k,n] = w[n,k] where w is [DIM,H]
    b_ptrs = w_ptr + offs_n[None, :] * H + offs_k[:, None]
    b = tl.load(b_ptrs, mask=(offs_n[None, :] < DIM) & (offs_k[:, None] < H), other=0.0).to(tl.bfloat16)

    y = tl.dot(gated_bf16, b, tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32))

    y_ptrs = y_ptr + offs_m[:, None] * DIM + offs_n[None, :]
    tl.store(y_ptrs, y.to(tl.bfloat16), mask=(offs_m[:, None] < M) & (offs_n[None, :] < DIM))


# -----------------------------
# Public wrapper
# -----------------------------
def kernel_function(
    x: Union[torch.Tensor, Tuple[Any, ...]],
    mask: torch.Tensor | None = None,
    weights: Dict[str, torch.Tensor] | None = None,
    config: Dict[str, Any] | None = None,
) -> torch.Tensor:
    # Support tester fallback: kernel_function((x,mask,weights,config))
    if isinstance(x, (tuple, list)) and mask is None and weights is None and config is None:
        x, mask, weights, config = x  # type: ignore[misc]

    assert isinstance(x, torch.Tensor) and isinstance(mask, torch.Tensor)
    assert isinstance(weights, dict) and isinstance(config, dict)

    assert x.is_cuda and mask.is_cuda, "CUDA tensors required"
    assert x.dtype == torch.bfloat16, "Test requires BF16 input"
    assert mask.dtype == torch.bfloat16, "Mask is BF16 in the test"
    assert x.is_contiguous(), "x must be contiguous"
    assert mask.is_contiguous(), "mask must be contiguous"

    B, N1, N2, DIM = x.shape
    assert N1 == N2, "Expected x shape [B,N,N,C]"
    N = N1
    H = int(config["hidden_dim"])
    assert H == 128, "This implementation is specialized for hidden_dim=128 (as in the tests)"

    # Weights
    norm_w = weights["norm.weight"]
    norm_b = weights["norm.bias"]
    w_left = weights["left_proj.weight"]
    w_right = weights["right_proj.weight"]
    w_left_gate = weights["left_gate.weight"]
    w_right_gate = weights["right_gate.weight"]
    w_out_gate = weights["out_gate.weight"]
    out_ln_w = weights["to_out_norm.weight"]
    out_ln_b = weights["to_out_norm.bias"]
    w_to_out = weights["to_out.weight"]

    # Basic checks
    assert norm_w.shape == (DIM,) and norm_b.shape == (DIM,)
    assert w_left.shape == (H, DIM) and w_right.shape == (H, DIM)
    assert w_left_gate.shape == (H, DIM) and w_right_gate.shape == (H, DIM)
    assert w_out_gate.shape == (H, DIM)
    assert out_ln_w.shape == (H,) and out_ln_b.shape == (H,)
    assert w_to_out.shape == (DIM, H)

    for t in (
        norm_w, norm_b,
        w_left, w_right, w_left_gate, w_right_gate, w_out_gate,
        out_ln_w, out_ln_b, w_to_out,
    ):
        assert t.is_cuda and t.dtype == torch.bfloat16 and t.is_contiguous()

    # Flatten [B,N,N,*] -> [M,*]
    M = B * N * N
    x2d = x.view(M, DIM)
    mask1d = mask.view(M)

    # Intermediates
    x_norm = torch.empty_like(x)
    x_norm2d = x_norm.view(M, DIM)

    left = torch.empty((M, H), device=x.device, dtype=torch.bfloat16)
    right = torch.empty((M, H), device=x.device, dtype=torch.bfloat16)
    out_gate = torch.empty((M, H), device=x.device, dtype=torch.bfloat16)

    # 1) LayerNorm on x (per row DIM)
    _layernorm_lastdim_kernel[(M,)](
        x2d,
        x_norm2d,
        norm_w,
        norm_b,
        M,
        C=DIM,
        eps=1e-5,
        BLOCK_C=256,
        num_warps=4,
    )

    # 2) Fused projections + sigmoids + mask/gate for left/right
    def grid_proj(meta):
        return (triton.cdiv(M, meta["BLOCK_M"]), triton.cdiv(H, meta["BLOCK_N"]))

    proj_meta = dict(BLOCK_M=128, BLOCK_N=64, BLOCK_K=32, num_warps=8, num_stages=4)
    _fused_proj_gates_mask_kernel[grid_proj(proj_meta)](
        x_norm2d,
        mask1d,
        w_left,
        w_right,
        w_left_gate,
        w_right_gate,
        w_out_gate,
        left,
        right,
        out_gate,
        M,
        stride_xm=x_norm2d.stride(0),
        stride_xk=x_norm2d.stride(1),
        stride_wn=w_left.stride(0),
        stride_wk=w_left.stride(1),
        stride_om=left.stride(0),
        stride_on=left.stride(1),
        DIM=DIM,
        H=H,
        **proj_meta,
    )

    # 3) Einsum via batched matmul over (b,d)
    out = torch.empty((B, N, N, H), device=x.device, dtype=torch.bfloat16)
    left4 = left.view(B, N, N, H)
    right4 = right.view(B, N, N, H)

    grid_einsum = (triton.cdiv(N, 64), triton.cdiv(N, 64), B * H)
    _einsum_batched_mm_kernel[grid_einsum](
        left4,
        right4,
        out,
        N=N,
        H=H,
        BLOCK_M=64,
        BLOCK_N=64,
        BLOCK_K=32,
        num_warps=8,
        num_stages=4,
    )

    # 4) Fused: LayerNorm(hidden_dim) + out_gate multiply + to_out linear -> y
    y = torch.empty((M, DIM), device=x.device, dtype=torch.bfloat16)
    out2d = out.view(M, H)
    out_gate2d = out_gate.view(M, H)

    def grid_fused(meta):
        return (triton.cdiv(M, meta["BLOCK_M"]), triton.cdiv(DIM, meta["BLOCK_N"]))

    fused_meta = dict(BLOCK_M=16, BLOCK_N=64, num_warps=4, num_stages=3)
    _fused_out_ln_gate_linear_kernel[grid_fused(fused_meta)](
        out2d,
        out_gate2d,
        out_ln_w,
        out_ln_b,
        w_to_out,
        y,
        M,
        DIM,
        H=H,
        eps=1e-5,
        **fused_meta,
    )

    return y.view(B, N, N, DIM)

def custom_kernel(input):
    return kernel_function(*input)

Example with Change

Test
import os
import sys
import math
import time
from typing import Dict, Tuple

import torch
from torch import nn, einsum


# Summary:
# Tests a Triton implementation of AlphaFold-style Triangle Multiplicative Update (TriMul, outgoing)
# over input tensors shaped [B, N, N, C], comparing against a PyTorch reference implementation.
def test_kernel() -> bool:
    try:
        from kernel import kernel_function
        if not callable(kernel_function):
            print("ERROR: kernel_function is not callable")
            return False

        if not torch.cuda.is_available():
            raise RuntimeError("CUDA not available")

        device = torch.device("cuda")

        # Make results more comparable (kernel is Triton; reference uses PyTorch ops).
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False

        class TriMul(nn.Module):
            # Reference code (matches prompt)
            def __init__(self, dim: int, hidden_dim: int):
                super().__init__()
                self.norm = nn.LayerNorm(dim)
                self.left_proj = nn.Linear(dim, hidden_dim, bias=False)
                self.right_proj = nn.Linear(dim, hidden_dim, bias=False)
                self.left_gate = nn.Linear(dim, hidden_dim, bias=False)
                self.right_gate = nn.Linear(dim, hidden_dim, bias=False)
                self.out_gate = nn.Linear(dim, hidden_dim, bias=False)
                self.to_out_norm = nn.LayerNorm(hidden_dim)
                self.to_out = nn.Linear(hidden_dim, dim, bias=False)

            def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
                x = self.norm(x)

                left = self.left_proj(x)
                right = self.right_proj(x)

                mask = mask.unsqueeze(-1)
                left = left * mask
                right = right * mask

                left_gate = self.left_gate(x).sigmoid()
                right_gate = self.right_gate(x).sigmoid()
                out_gate = self.out_gate(x).sigmoid()

                left = left * left_gate
                right = right * right_gate

                out = einsum("... i k d, ... j k d -> ... i j d", left, right)

                out = self.to_out_norm(out)
                out = out * out_gate
                return self.to_out(out)

        def _set_all_seeds(seed: int) -> None:
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

        def generate_input(
            *,
            seqlen: int,
            bs: int,
            dim: int,
            hiddendim: int,
            seed: int,
            nomask: bool,
            distribution: str,
            dtype: torch.dtype,
        ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], Dict]:
            _set_all_seeds(seed)
            gen = torch.Generator(device="cuda")
            gen.manual_seed(seed)

            config = {"hidden_dim": hiddendim, "dim": dim}

            if distribution == "cauchy":
                # Match prompt (heavier-tail distribution)
                x = torch.distributions.Cauchy(0.0, 2.0).sample((bs, seqlen, seqlen, dim))
                x = x.to(device=device, dtype=dtype).contiguous()
            else:
                x = torch.randn((bs, seqlen, seqlen, dim), device=device, dtype=dtype, generator=gen).contiguous()

            if nomask:
                mask = torch.ones((bs, seqlen, seqlen), device=device, dtype=dtype)
            else:
                # Use float mask (0/1) for clearer semantics and kernel compatibility.
                mask = torch.randint(0, 2, (bs, seqlen, seqlen), device=device, generator=gen).to(dtype)

            # Weights (match names from prompt)
            # Use generator for reproducibility (both ref and kernel receive identical tensors).
            def randn(shape):
                return torch.randn(shape, device=device, dtype=dtype, generator=gen)

            weights: Dict[str, torch.Tensor] = {}
            weights["norm.weight"] = randn((dim,))
            weights["norm.bias"] = randn((dim,))
            weights["left_proj.weight"] = randn((hiddendim, dim)) / math.sqrt(hiddendim)
            weights["right_proj.weight"] = randn((hiddendim, dim)) / math.sqrt(hiddendim)
            weights["left_gate.weight"] = randn((hiddendim, dim)) / math.sqrt(hiddendim)
            weights["right_gate.weight"] = randn((hiddendim, dim)) / math.sqrt(hiddendim)
            weights["out_gate.weight"] = randn((hiddendim, dim)) / math.sqrt(hiddendim)
            weights["to_out_norm.weight"] = randn((hiddendim,))
            weights["to_out_norm.bias"] = randn((hiddendim,))
            weights["to_out.weight"] = randn((dim, hiddendim)) / math.sqrt(dim)

            return x, mask, weights, config

        def reference_forward(
            x: torch.Tensor, mask: torch.Tensor, weights: Dict[str, torch.Tensor], config: Dict
        ) -> torch.Tensor:
            model = TriMul(dim=config["dim"], hidden_dim=config["hidden_dim"]).to(device=x.device, dtype=x.dtype)
            model.eval()
            with torch.no_grad():
                # Copy weights into module
                model.norm.weight.copy_(weights["norm.weight"])
                model.norm.bias.copy_(weights["norm.bias"])
                model.left_proj.weight.copy_(weights["left_proj.weight"])
                model.right_proj.weight.copy_(weights["right_proj.weight"])
                model.left_gate.weight.copy_(weights["left_gate.weight"])
                model.right_gate.weight.copy_(weights["right_gate.weight"])
                model.out_gate.weight.copy_(weights["out_gate.weight"])
                model.to_out_norm.weight.copy_(weights["to_out_norm.weight"])
                model.to_out_norm.bias.copy_(weights["to_out_norm.bias"])
                model.to_out.weight.copy_(weights["to_out.weight"])
                y = model(x, mask)
            return y

        def call_kernel(
            x: torch.Tensor, mask: torch.Tensor, weights: Dict[str, torch.Tensor], config: Dict
        ) -> torch.Tensor:
            # Per instructions: call as a normal Python function.
            data_tuple = (x, mask, weights, config)

            last_err = None
            for attempt in range(4):
                try:
                    if attempt == 0:
                        out = kernel_function(data_tuple)
                    elif attempt == 1:
                        out = kernel_function(x, mask, weights, config)
                    elif attempt == 2:
                        out = kernel_function(x, mask, weights)
                    else:
                        out = kernel_function(x, mask)
                    break
                except TypeError as e:
                    last_err = e
                    out = None
            else:
                print("ERROR: Could not call kernel_function with common signatures.")
                print(f"Last TypeError: {last_err}")
                raise last_err

            if isinstance(out, (tuple, list)):
                if len(out) == 0:
                    raise RuntimeError("kernel_function returned an empty tuple/list")
                out0 = out[0]
            else:
                out0 = out

            if not isinstance(out0, torch.Tensor):
                raise RuntimeError(f"kernel_function returned non-tensor output type: {type(out0)}")

            return out0

        def compare_tensors(
            y_ref: torch.Tensor,
            y: torch.Tensor,
            *,
            rtol: float,
            atol: float,
            case_desc: str,
            x: torch.Tensor,
            mask: torch.Tensor,
        ) -> bool:
            if y.device != x.device:
                print(f"ERROR[{case_desc}]: device mismatch: result.device={y.device}, input.device={x.device}")
                return False
            if y.shape != y_ref.shape:
                print(f"ERROR[{case_desc}]: shape mismatch: got {tuple(y.shape)} expected {tuple(y_ref.shape)}")
                return False

            y_f = y.float()
            yref_f = y_ref.float()

            ref_finite = torch.isfinite(yref_f)
            out_finite = torch.isfinite(y_f)

            if not torch.equal(ref_finite, out_finite):
                ref_nf = (~ref_finite).sum().item()
                out_nf = (~out_finite).sum().item()
                print(f"ERROR[{case_desc}]: non-finite pattern mismatch")
                print(f"  non-finite count: expected={ref_nf}, got={out_nf}")
                # Show a few indices where mismatch occurs
                bad = (ref_finite != out_finite).flatten().nonzero()[:10].flatten()
                print(f"  first mismatch flat indices (up to 10): {bad.tolist()}")
                return False

            # Compare only finite values (if any non-finite exist).
            finite_mask = ref_finite
            if finite_mask.all():
                ok = torch.allclose(y_f, yref_f, rtol=rtol, atol=atol)
            else:
                ok = torch.allclose(y_f[finite_mask], yref_f[finite_mask], rtol=rtol, atol=atol)

            if not ok:
                diff = (y_f - yref_f).abs()
                max_abs = diff[finite_mask].max().item() if finite_mask.any() else float("nan")
                denom = (yref_f.abs() + 1e-8)
                rel = (diff / denom)
                max_rel = rel[finite_mask].max().item() if finite_mask.any() else float("nan")

                print(f"NUMERICAL MISMATCH[{case_desc}] (rtol={rtol}, atol={atol}):")
                print(f"  input: shape={tuple(x.shape)} dtype={x.dtype} device={x.device}")
                print(f"  mask:  shape={tuple(mask.shape)} dtype={mask.dtype} device={mask.device}")
                print(f"  out:   shape={tuple(y.shape)} dtype={y.dtype} device={y.device}")
                print(f"  ref:   shape={tuple(y_ref.shape)} dtype={y_ref.dtype} device={y_ref.device}")
                print(f"  max_abs_diff={max_abs:.6g}, max_rel_diff={max_rel:.6g}")

                flat_ref = yref_f.flatten()
                flat_out = y_f.flatten()
                flat_diff = diff.flatten()
                k = min(10, flat_diff.numel())
                topk = torch.topk(flat_diff, k=k).indices
                print("  worst elements (index, ref, out, absdiff):")
                for idx in topk.tolist():
                    print(f"    {idx}: {flat_ref[idx].item():.6g}, {flat_out[idx].item():.6g}, {flat_diff[idx].item():.6g}")

                print("  samples:")
                print(f"    ref[:10]={flat_ref[:10].tolist()}")
                print(f"    out[:10]={flat_out[:10].tolist()}")

                return False

            return True

        # Representative subset of the provided specs (exact sizes from prompt),
        # avoiding extremely large N that would make the reference O(N^3) too slow for a unit test.
        cases = [
            # normal distribution
            dict(bs=1, dim=128, hiddendim=128, nomask=True, seed=9371, seqlen=32, distribution="normal"),
            dict(bs=1, dim=128, hiddendim=128, nomask=False, seed=1092, seqlen=32, distribution="normal"),
            dict(bs=2, dim=256, hiddendim=128, nomask=True, seed=2291, seqlen=64, distribution="normal"),
            dict(bs=1, dim=768, hiddendim=128, nomask=True, seed=81934, seqlen=128, distribution="normal"),
            dict(bs=1, dim=128, hiddendim=128, nomask=False, seed=10432, seqlen=256, distribution="normal"),
            # cauchy distribution
            dict(bs=1, dim=128, hiddendim=128, nomask=True, seed=937321, seqlen=32, distribution="cauchy"),
            dict(bs=2, dim=256, hiddendim=128, nomask=True, seed=2291, seqlen=64, distribution="cauchy"),
        ]

        all_ok = True
        with torch.no_grad():
            for i, spec in enumerate(cases):
                # Use BF16 for normal (per requirement to avoid FP32); use FP32 for cauchy to reduce overflow risk.
                # Tolerances adjusted accordingly.
                if spec["distribution"] == "cauchy":
                    dtype = torch.float32
                    rtol, atol = 1e-3, 1e-3
                else:
                    dtype = torch.bfloat16
                    # BF16 + large reductions (TriMul) needs looser tolerance.
                    rtol, atol = 2e-2, 2e-2  # documented: bf16 precision + accumulation

                case_desc = (
                    f"case{i}: dist={spec['distribution']} bs={spec['bs']} N={spec['seqlen']} "
                    f"dim={spec['dim']} hidden={spec['hiddendim']} nomask={spec['nomask']} seed={spec['seed']} dtype={dtype}"
                )
                print(f"Running {case_desc}")

                try:
                    x, mask, weights, config = generate_input(dtype=dtype, **spec)

                    torch.cuda.synchronize()
                    t0 = time.perf_counter()
                    y_ref = reference_forward(x, mask, weights, config)
                    torch.cuda.synchronize()
                    t_ref = time.perf_counter() - t0

                    torch.cuda.synchronize()
                    t0 = time.perf_counter()
                    y = call_kernel(x, mask, weights, config)
                    torch.cuda.synchronize()
                    t_kernel = time.perf_counter() - t0

                    ok = compare_tensors(y_ref, y, rtol=rtol, atol=atol, case_desc=case_desc, x=x, mask=mask)
                    print(f"  ref_time={t_ref:.4f}s kernel_time={t_kernel:.4f}s -> {'PASS' if ok else 'FAIL'}")
                    all_ok = all_ok and ok

                except torch.cuda.OutOfMemoryError as e:
                    print(f"ERROR[{case_desc}]: CUDA OOM: {e}")
                    return False
                except NameError as e:
                    print(f"ERROR[{case_desc}]: NameError (likely undefined helper in kernel.py): {e}")
                    return False
                except Exception as e:
                    print(f"ERROR[{case_desc}]: Exception: {type(e).__name__}: {e}")
                    return False
                finally:
                    # Reduce peak memory between cases
                    try:
                        del x, mask, weights, config, y_ref, y
                    except Exception:
                        pass
                    torch.cuda.empty_cache()

        return all_ok

    except Exception as e:
        if isinstance(e, NameError):
            print(f"Test failed: NameError (likely undefined helper in kernel.py): {e}")
        else:
            print(f"Test failed: {type(e).__name__}: {e}")
        return False


if __name__ == "__main__":
    success = test_kernel()
    sys.exit(0 if success else 1)
Kernel
from __future__ import annotations

from typing import Any, Dict, Tuple

import torch
import triton
import triton.language as tl


def _parse_args(*args):
    # Supports call patterns used by the test harness.
    if len(args) == 1 and isinstance(args[0], (tuple, list)):
        t = args[0]
        if len(t) < 2:
            raise TypeError("Expected at least (x, mask) in the input tuple/list.")
        x = t[0]
        mask = t[1]
        weights = t[2] if len(t) > 2 else None
        config = t[3] if len(t) > 3 else None
        return x, mask, weights, config

    if len(args) >= 2:
        x = args[0]
        mask = args[1]
        weights = args[2] if len(args) > 2 else None
        config = args[3] if len(args) > 3 else None
        return x, mask, weights, config

    raise TypeError("Unsupported kernel_function signature.")


@triton.jit
def _layernorm_2d_kernel(
    x_ptr, w_ptr, b_ptr, y_ptr,
    M, N,
    stride_xm, stride_xn,
    stride_ym, stride_yn,
    eps: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """
    LayerNorm over the last dim for a 2D tensor [M, N]. Affine params w,b are length N.

    IMPORTANT: When BLOCK_N > N, masked lanes must not contribute to variance. We zero them out
    explicitly during var accumulation.
    """
    row = tl.program_id(0)
    row_mask = row < M

    cols = tl.arange(0, BLOCK_N)
    col_mask = cols < N

    x = tl.load(
        x_ptr + row * stride_xm + cols * stride_xn,
        mask=row_mask & col_mask,
        other=0.0,
    ).to(tl.float32)

    n_f = tl.full((), N, tl.float32)
    mean = tl.sum(x, axis=0) / n_f

    x_centered = tl.where(col_mask, x - mean, 0.0)
    var = tl.sum(x_centered * x_centered, axis=0) / n_f
    rstd = tl.math.rsqrt(var + eps)

    w = tl.load(w_ptr + cols, mask=col_mask, other=0.0).to(tl.float32)
    b = tl.load(b_ptr + cols, mask=col_mask, other=0.0).to(tl.float32)

    y = (x_centered * rstd) * w + b
    tl.store(
        y_ptr + row * stride_ym + cols * stride_yn,
        y.to(y_ptr.dtype.element_ty),
        mask=row_mask & col_mask,
    )


@triton.jit
def _sigmoid_fp32(x_f32):
    # sigmoid(x) = 1 / (1 + exp(-x)); use exp2 for portability
    log2e = 1.4426950408889634
    e = tl.math.exp2(-x_f32 * log2e)
    return 1.0 / (1.0 + e)


@triton.jit
def _proj_gate_kernel(
    x_ptr, mask_ptr,
    w_lp_ptr, w_rp_ptr, w_lg_ptr, w_rg_ptr, w_og_ptr,
    left_ptr, right_ptr, og_ptr,
    M, K, H,
    stride_xm, stride_xk,
    stride_wm, stride_wk,   # weights are [H, K]
    stride_om, stride_oh,   # outputs are [M, H]
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,  # block over H
    BLOCK_K: tl.constexpr,
):
    """
    Fused stage:
      left  = (x @ Wlp^T) * mask * sigmoid(x @ Wlg^T)
      right = (x @ Wrp^T) * mask * sigmoid(x @ Wrg^T)
      og    = sigmoid(x @ Wog^T)
    for x in [M, K], weights in [H, K], outputs in [M, H].
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    m_mask = offs_m < M
    n_mask = offs_n < H

    acc_lp = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    acc_rp = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    acc_lg = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    acc_rg = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    acc_og = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Runtime K-loop without tl.any/tl.break.
    for k0 in tl.range(0, K, BLOCK_K):
        k0 = tl.multiple_of(k0, BLOCK_K)
        offs_k = k0 + tl.arange(0, BLOCK_K)
        k_mask = offs_k < K

        a = tl.load(
            x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk,
            mask=m_mask[:, None] & k_mask[None, :],
            other=0.0,
        )

        # Load blocks of W^T as [BLOCK_K, BLOCK_N] where element (k,n) = W[n,k].
        w_lp = tl.load(
            w_lp_ptr + offs_n[None, :] * stride_wm + offs_k[:, None] * stride_wk,
            mask=k_mask[:, None] & n_mask[None, :],
            other=0.0,
        )
        w_rp = tl.load(
            w_rp_ptr + offs_n[None, :] * stride_wm + offs_k[:, None] * stride_wk,
            mask=k_mask[:, None] & n_mask[None, :],
            other=0.0,
        )
        w_lg = tl.load(
            w_lg_ptr + offs_n[None, :] * stride_wm + offs_k[:, None] * stride_wk,
            mask=k_mask[:, None] & n_mask[None, :],
            other=0.0,
        )
        w_rg = tl.load(
            w_rg_ptr + offs_n[None, :] * stride_wm + offs_k[:, None] * stride_wk,
            mask=k_mask[:, None] & n_mask[None, :],
            other=0.0,
        )
        w_og = tl.load(
            w_og_ptr + offs_n[None, :] * stride_wm + offs_k[:, None] * stride_wk,
            mask=k_mask[:, None] & n_mask[None, :],
            other=0.0,
        )

        acc_lp = tl.dot(a, w_lp, acc_lp, allow_tf32=False)
        acc_rp = tl.dot(a, w_rp, acc_rp, allow_tf32=False)
        acc_lg = tl.dot(a, w_lg, acc_lg, allow_tf32=False)
        acc_rg = tl.dot(a, w_rg, acc_rg, allow_tf32=False)
        acc_og = tl.dot(a, w_og, acc_og, allow_tf32=False)

    mask_val = tl.load(mask_ptr + offs_m, mask=m_mask, other=0.0).to(tl.float32)[:, None]

    lg = _sigmoid_fp32(acc_lg)
    rg = _sigmoid_fp32(acc_rg)
    og = _sigmoid_fp32(acc_og)

    left = acc_lp * mask_val * lg
    right = acc_rp * mask_val * rg

    out_mask = m_mask[:, None] & n_mask[None, :]
    tl.store(left_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_oh,
             left.to(left_ptr.dtype.element_ty), mask=out_mask)
    tl.store(right_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_oh,
             right.to(right_ptr.dtype.element_ty), mask=out_mask)
    tl.store(og_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_oh,
             og.to(og_ptr.dtype.element_ty), mask=out_mask)


@triton.jit
def _trimul_out_kernel(
    left_ptr, right_ptr, out_ptr,
    B, N, H,
    stride_lb, stride_li, stride_lj, stride_lh,
    stride_rb, stride_ri, stride_rj, stride_rh,
    stride_ob, stride_oi, stride_oj, stride_oh,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """
    Triangle multiplicative update (outgoing):
      out[b,i,j,h] = sum_k left[b,i,k,h] * right[b,j,k,h]
    One program computes a tile (i-block, j-block) for a single (b,h).
    """
    pid_i = tl.program_id(0)
    pid_j = tl.program_id(1)
    pid_bh = tl.program_id(2)

    b = pid_bh // H
    h = pid_bh - b * H

    offs_i = pid_i * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_j = pid_j * BLOCK_N + tl.arange(0, BLOCK_N)

    i_mask = offs_i < N
    j_mask = offs_j < N

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k0 in tl.range(0, N, BLOCK_K):
        k0 = tl.multiple_of(k0, BLOCK_K)
        offs_k = k0 + tl.arange(0, BLOCK_K)
        k_mask = offs_k < N

        left_ptrs = (
            left_ptr
            + b * stride_lb
            + offs_i[:, None] * stride_li
            + offs_k[None, :] * stride_lj
            + h * stride_lh
        )
        right_ptrs = (
            right_ptr
            + b * stride_rb
            + offs_j[:, None] * stride_ri
            + offs_k[None, :] * stride_rj
            + h * stride_rh
        )

        l = tl.load(left_ptrs, mask=i_mask[:, None] & k_mask[None, :], other=0.0)
        r = tl.load(right_ptrs, mask=j_mask[:, None] & k_mask[None, :], other=0.0)

        acc = tl.dot(l, r.T, acc, allow_tf32=False)

    out_ptrs = (
        out_ptr
        + b * stride_ob
        + offs_i[:, None] * stride_oi
        + offs_j[None, :] * stride_oj
        + h * stride_oh
    )
    out_mask = i_mask[:, None] & j_mask[None, :]
    tl.store(out_ptrs, acc.to(out_ptr.dtype.element_ty), mask=out_mask)


@triton.jit
def _out_ln_gate_linear_kernel(
    x_ptr, gate_ptr,
    gamma_ptr, beta_ptr,
    w_to_out_ptr,
    y_ptr,
    M, C,
    stride_xm, stride_xk,       # x is [M, H]
    stride_gm, stride_gk,       # gate is [M, H]
    stride_wm, stride_wk,       # w_to_out is [C, H]
    stride_ym, stride_yn,       # y is [M, C]
    eps: tl.constexpr,
    H: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """
    Fused stage:
      a = LayerNorm(x) over H with affine (gamma,beta)
      a = a * gate
      y = a @ W_to_out^T, where W_to_out is [C, H] (PyTorch linear weight layout)
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, H)

    m_mask = offs_m < M
    n_mask = offs_n < C

    x = tl.load(
        x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk,
        mask=m_mask[:, None],
        other=0.0,
    ).to(tl.float32)

    h_f = tl.full((), H, tl.float32)
    mean = tl.sum(x, axis=1) / h_f
    x0 = x - mean[:, None]
    var = tl.sum(x0 * x0, axis=1) / h_f
    rstd = tl.math.rsqrt(var + eps)

    gamma = tl.load(gamma_ptr + offs_k, mask=offs_k < H, other=0.0).to(tl.float32)
    beta = tl.load(beta_ptr + offs_k, mask=offs_k < H, other=0.0).to(tl.float32)

    a = (x0 * rstd[:, None]) * gamma[None, :] + beta[None, :]

    gate = tl.load(
        gate_ptr + offs_m[:, None] * stride_gm + offs_k[None, :] * stride_gk,
        mask=m_mask[:, None],
        other=0.0,
    ).to(tl.float32)

    a = a * gate

    # Match common GEMM behavior: bf16 inputs for bf16 path; fp32 stays fp32.
    a_dot = a.to(x_ptr.dtype.element_ty)

    # Load W_to_out^T block: w[k, n] = W_to_out[n, k]
    w = tl.load(
        w_to_out_ptr + offs_n[None, :] * stride_wm + offs_k[:, None] * stride_wk,
        mask=n_mask[None, :],
        other=0.0,
    ).to(w_to_out_ptr.dtype.element_ty)

    acc = tl.dot(a_dot, w, allow_tf32=False)
    y_ptrs = y_ptr + offs_m[:, None] * stride_ym + offs_n[None, :] * stride_yn
    tl.store(y_ptrs, acc.to(y_ptr.dtype.element_ty), mask=m_mask[:, None] & n_mask[None, :])


def kernel_function(*args: Any) -> torch.Tensor:
    """
    Triton implementation of AlphaFold-style Triangle Multiplicative Update (outgoing).

    Wrapper responsibilities only:
      - parse arguments
      - validate shapes/dtypes/devices
      - allocate intermediates/outputs
      - configure and launch Triton kernels
    """
    x, mask, weights, _config = _parse_args(*args)

    if not isinstance(x, torch.Tensor) or not isinstance(mask, torch.Tensor):
        raise TypeError("x and mask must be torch.Tensors.")
    if weights is None or not isinstance(weights, dict):
        raise TypeError("weights must be provided as a dict.")

    if not x.is_cuda or not mask.is_cuda:
        raise ValueError("x and mask must be CUDA tensors.")
    if x.dtype not in (torch.bfloat16, torch.float32):
        raise ValueError(f"Unsupported dtype: {x.dtype} (expected bf16 or fp32).")
    if mask.dtype != x.dtype:
        raise ValueError(f"mask dtype must match x dtype; got mask={mask.dtype}, x={x.dtype}")

    if x.ndim != 4:
        raise ValueError(f"Expected x.ndim == 4 [B,N,N,C], got shape={tuple(x.shape)}")
    if mask.ndim != 3:
        raise ValueError(f"Expected mask.ndim == 3 [B,N,N], got shape={tuple(mask.shape)}")

    B, N1, N2, C = x.shape
    if N1 != N2:
        raise ValueError(f"Expected square pair dims N,N, got {N1} and {N2}")
    N = N1
    if tuple(mask.shape) != (B, N, N):
        raise ValueError(f"mask shape mismatch: got {tuple(mask.shape)}, expected {(B, N, N)}")

    # Extract weights (no compute)
    norm_w = weights["norm.weight"]
    norm_b = weights["norm.bias"]
    w_lp = weights["left_proj.weight"]
    w_rp = weights["right_proj.weight"]
    w_lg = weights["left_gate.weight"]
    w_rg = weights["right_gate.weight"]
    w_og = weights["out_gate.weight"]
    out_norm_w = weights["to_out_norm.weight"]
    out_norm_b = weights["to_out_norm.bias"]
    w_to_out = weights["to_out.weight"]

    # Validate weight shapes
    if norm_w.shape != (C,) or norm_b.shape != (C,):
        raise ValueError("norm.weight/norm.bias must be shape (dim,)")
    H = w_lp.shape[0]
    if (w_lp.shape != (H, C) or w_rp.shape != (H, C) or w_lg.shape != (H, C) or
            w_rg.shape != (H, C) or w_og.shape != (H, C)):
        raise ValueError("All projection/gate weights must have shape (hidden_dim, dim)")
    if out_norm_w.shape != (H,) or out_norm_b.shape != (H,):
        raise ValueError("to_out_norm.weight/to_out_norm.bias must be shape (hidden_dim,)")
    if w_to_out.shape != (C, H):
        raise ValueError("to_out.weight must have shape (dim, hidden_dim)")

    if any((not t.is_cuda) or (t.dtype != x.dtype)
           for t in (norm_w, norm_b, w_lp, w_rp, w_lg, w_rg, w_og, out_norm_w, out_norm_b, w_to_out)):
        raise ValueError("All weights must be CUDA tensors with the same dtype as x.")

    # Flatten pair dims
    M = B * N * N
    x2 = x.view(M, C)
    mask1 = mask.view(M)

    # Intermediates / outputs
    x_ln = torch.empty_like(x)
    x_ln2 = x_ln.view(M, C)

    left = torch.empty((B, N, N, H), device=x.device, dtype=x.dtype)
    right = torch.empty((B, N, N, H), device=x.device, dtype=x.dtype)
    out_gate = torch.empty((B, N, N, H), device=x.device, dtype=x.dtype)

    left2 = left.view(M, H)
    right2 = right.view(M, H)
    out_gate2 = out_gate.view(M, H)

    out_hidden = torch.empty((B, N, N, H), device=x.device, dtype=x.dtype)
    out_hidden2 = out_hidden.view(M, H)

    y = torch.empty((B, N, N, C), device=x.device, dtype=x.dtype)
    y2 = y.view(M, C)

    # ---- Kernel 1: LayerNorm(x) over C ----
    BLOCK_N_LN = 1024  # covers max C=768; masks handle smaller.
    _layernorm_2d_kernel[(M,)](
        x2, norm_w, norm_b, x_ln2,
        M, C,
        x2.stride(0), x2.stride(1),
        x_ln2.stride(0), x_ln2.stride(1),
        eps=1e-5,
        BLOCK_N=BLOCK_N_LN,
        num_warps=4,
    )

    # ---- Kernel 2: Fused projections + gates + mask ----
    BLOCK_M_P = 16
    BLOCK_N_P = 64
    BLOCK_K_P = 32
    grid_proj = (triton.cdiv(M, BLOCK_M_P), triton.cdiv(H, BLOCK_N_P))
    _proj_gate_kernel[grid_proj](
        x_ln2, mask1,
        w_lp, w_rp, w_lg, w_rg, w_og,
        left2, right2, out_gate2,
        M, C, H,
        x_ln2.stride(0), x_ln2.stride(1),
        w_lp.stride(0), w_lp.stride(1),
        left2.stride(0), left2.stride(1),
        BLOCK_M=BLOCK_M_P, BLOCK_N=BLOCK_N_P, BLOCK_K=BLOCK_K_P,
        num_warps=4,
        num_stages=2,
    )

    # ---- Kernel 3: Triangle multiplicative update (outgoing) ----
    BLOCK_M_T = 32
    BLOCK_N_T = 32
    BLOCK_K_T = 32
    grid_tri = (triton.cdiv(N, BLOCK_M_T), triton.cdiv(N, BLOCK_N_T), B * H)
    _trimul_out_kernel[grid_tri](
        left, right, out_hidden,
        B, N, H,
        left.stride(0), left.stride(1), left.stride(2), left.stride(3),
        right.stride(0), right.stride(1), right.stride(2), right.stride(3),
        out_hidden.stride(0), out_hidden.stride(1), out_hidden.stride(2), out_hidden.stride(3),
        BLOCK_M=BLOCK_M_T, BLOCK_N=BLOCK_N_T, BLOCK_K=BLOCK_K_T,
        num_warps=4,
        num_stages=2,
    )

    # ---- Kernel 4: Fused output LayerNorm(H) + out_gate + final linear to C ----
    BLOCK_M_O = 16
    BLOCK_N_O = 128
    grid_out = (triton.cdiv(M, BLOCK_M_O), triton.cdiv(C, BLOCK_N_O))
    _out_ln_gate_linear_kernel[grid_out](
        out_hidden2, out_gate2,
        out_norm_w, out_norm_b,
        w_to_out,
        y2,
        M, C,
        out_hidden2.stride(0), out_hidden2.stride(1),
        out_gate2.stride(0), out_gate2.stride(1),
        w_to_out.stride(0), w_to_out.stride(1),
        y2.stride(0), y2.stride(1),
        eps=1e-5,
        H=H,
        BLOCK_M=BLOCK_M_O, BLOCK_N=BLOCK_N_O,
        num_warps=4,
        num_stages=2,
    )

    return y

Test

python -m Fuser.auto_agent --problem /home/jackkhuu/KernelAgent/problem.py --ka-model gpt-5-2 --router-model gpt-5-2 --extract-model gpt-5-2 --dispatch-model gpt-5-2 --compose-model gpt-5-2 --verify --ignore-router-config --no-router-cache --test-timeout-s 120

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Feb 7, 2026
@Jack-Khuu Jack-Khuu merged commit d0fb93b into main Feb 7, 2026
6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants