From c69a3ad6281e680694642baf0182029b69afeefd Mon Sep 17 00:00:00 2001 From: andrej Date: Mon, 6 Apr 2026 16:46:36 -0600 Subject: [PATCH 1/8] add single-dispatch layer-by-layer MHA --- aie_kernels/aie2p/softmax.cc | 3 +- iron/operators/gemm/design.py | 4 +- iron/operators/mha_prefill_lxl_sd/__init__.py | 2 + iron/operators/mha_prefill_lxl_sd/op.py | 245 ++++++++++++++++++ .../operators/mha_prefill_lxl_sd/reference.py | 206 +++++++++++++++ iron/operators/mha_prefill_lxl_sd/test.py | 186 +++++++++++++ 6 files changed, 643 insertions(+), 3 deletions(-) create mode 100644 iron/operators/mha_prefill_lxl_sd/__init__.py create mode 100644 iron/operators/mha_prefill_lxl_sd/op.py create mode 100644 iron/operators/mha_prefill_lxl_sd/reference.py create mode 100644 iron/operators/mha_prefill_lxl_sd/test.py diff --git a/aie_kernels/aie2p/softmax.cc b/aie_kernels/aie2p/softmax.cc index 64cca202..5778682a 100644 --- a/aie_kernels/aie2p/softmax.cc +++ b/aie_kernels/aie2p/softmax.cc @@ -3,6 +3,7 @@ #include #include +#include #define SM_VEC_LEN 64 // 32 #define log2e 1.4453125 // 1.44269504089 @@ -30,7 +31,7 @@ void softmax_simple_bf16(bfloat16 *restrict input_vector, bfloat16 *restrict out aie::vector in_elems, exp_val, input_bf16, log2e_vec, max_val_vec; aie::accum out_vals, exp_val_accum, scaled_accum, exp_in_accum; - float max_val = 0; + float max_val = -INFINITY; float accum_exp_val = 0; float running_max = 0; bfloat16 col_sum_inv; diff --git a/iron/operators/gemm/design.py b/iron/operators/gemm/design.py index a8ed8ad3..32222dd4 100644 --- a/iron/operators/gemm/design.py +++ b/iron/operators/gemm/design.py @@ -299,7 +299,7 @@ def my_matmul( gemm_object, [C_l1_ty_internal], ) - matmul_func_name = f"matmul{scalar_suffix}_{dtype_in_str}_f32" + matmul_func_name = f"{func_prefix}matmul{scalar_suffix}_{dtype_in_str}_f32" matmul_kernel = Kernel( matmul_func_name, gemm_object, @@ -314,7 +314,7 @@ def my_matmul( gemm_object, [C_l1_ty], ) - matmul_func_name = f"matmul{scalar_suffix}_{dtype_in_str}_{dtype_out_str}" + matmul_func_name = f"{func_prefix}matmul{scalar_suffix}_{dtype_in_str}_{dtype_out_str}" matmul_kernel = Kernel( matmul_func_name, gemm_object, diff --git a/iron/operators/mha_prefill_lxl_sd/__init__.py b/iron/operators/mha_prefill_lxl_sd/__init__.py new file mode 100644 index 00000000..82f09a67 --- /dev/null +++ b/iron/operators/mha_prefill_lxl_sd/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/iron/operators/mha_prefill_lxl_sd/op.py b/iron/operators/mha_prefill_lxl_sd/op.py new file mode 100644 index 00000000..fe951829 --- /dev/null +++ b/iron/operators/mha_prefill_lxl_sd/op.py @@ -0,0 +1,245 @@ +# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +A layer-by-layer (LxL) single-dispatch (SD) implementation of multi-head attention (MHA). +""" + +from iron.common.context import AIEContext +from iron.common.fusion import FusedMLIROperator +from iron.operators.gemm.op import AIEGEMM +from iron.operators.rope.op import AIERope +from iron.operators.strided_copy.op import AIEStridedCopy +from iron.operators.repeat.op import AIERepeat +from iron.operators.softmax.op import AIESoftmax +from iron.operators.transpose.op import AIETranspose +from iron.operators.elementwise_mul.op import AIEElementwiseMul +from iron.operators.elementwise_add.op import AIEElementwiseAdd + + +def _pick_tile_n(N, num_cols, max_tile_n=64): + tile_n = N // num_cols + while tile_n > max_tile_n: + tile_n //= 2 + assert N % (tile_n * num_cols) == 0 + return tile_n + + +def _build_core_ops(H, G, d, E, S, elf_ctx): + """Build core attention sub-ops and runlist (no projections/RoPE/GQA). + + Expects pre-processed inputs: + queries: (H, S, d) deinterleaved, contiguous per head + keys: (H, d, S) transposed and GQA-repeated + values: (H, S, d) GQA-repeated + """ + B = 2 # bytes per bf16 element + + gemm_scores = AIEGEMM( + M=S, K=d, N=S, num_aie_columns=8, tile_m=16, tile_k=64, + tile_n=_pick_tile_n(S, 8), context=elf_ctx, + ) + scale = AIEElementwiseMul( + size=H * S * S, tile_size=S * S // 8, + num_aie_columns=8, context=elf_ctx, + ) + mask = AIEElementwiseAdd( + size=H * S * S, tile_size=S * S // 8, + num_aie_columns=8, context=elf_ctx, + ) + softmax = AIESoftmax( + rows=H * S, cols=S, num_aie_columns=1, num_channels=1, + rtp_vector_size=S, context=elf_ctx, + ) + gemm_context = AIEGEMM( + M=S, K=S, N=d, num_aie_columns=4, tile_m=16, tile_k=64, + tile_n=16, context=elf_ctx, prio_accuracy=True, + ) + reinterleave = AIEStridedCopy( + input_sizes=(H, S, d), input_strides=(S * d, d, 1), input_offset=0, + output_sizes=(H, S, d), output_strides=(d, H * d, 1), output_offset=0, + input_buffer_size=H * S * d, output_buffer_size=S * H * d, + transfer_size=S * d, num_aie_channels=1, context=elf_ctx, + ) + gemm_output = AIEGEMM( + M=S, K=H * d, N=E, num_aie_columns=8, tile_m=16, tile_k=64, + tile_n=_pick_tile_n(E, 8), context=elf_ctx, prio_accuracy=True, + ) + + qh = S * d * B + kdS = d * S * B + kSd = S * d * B + sh = S * S * B + ch = S * d * B + + runlist = [ + *[(gemm_scores, + f"queries[{h*qh}:{(h+1)*qh}]", + f"keys[{h*kdS}:{(h+1)*kdS}]", + f"attn_scores[{h*sh}:{(h+1)*sh}]") + for h in range(H)], + (scale, "attn_scores", "attn_scale_factor", "attn_scores"), + (mask, "attn_scores", "causal_mask", "attn_scores_masked"), + (softmax, "attn_scores_masked", "attn_weights"), + *[(gemm_context, + f"attn_weights[{h*sh}:{(h+1)*sh}]", + f"values[{h*kSd}:{(h+1)*kSd}]", + f"attn_context[{h*ch}:{(h+1)*ch}]") + for h in range(H)], + (reinterleave, "attn_context", "context_interleaved"), + (gemm_output, "context_interleaved", "W_output", "attn_output"), + ] + + buffer_sizes = { + "queries": H * S * d * B, + "keys": H * d * S * B, + "values": H * S * d * B, + "attn_scores": H * S * S * B, + "attn_scores_masked": H * S * S * B, + "attn_weights": H * S * S * B, + "attn_context": H * S * d * B, + "context_interleaved": S * H * d * B, + } + + return runlist, buffer_sizes + + +class AIEAttentionPrefillFused(FusedMLIROperator): + """Fused attention prefill (core, no projections/RoPE). + + Accepts pre-projected Q (S*H,d), K (S*G,d), V (S*G,d) in interleaved layout. + """ + + def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, + seq_len, context=None): + assert head_dim == 64 + assert num_heads % num_kv_groups == 0 + assert seq_len % 256 == 0 + assert (num_heads * seq_len) % 16 == 0 + + self.num_heads = num_heads + self.num_kv_groups = num_kv_groups + self.head_dim = head_dim + self.embedding_dim = embedding_dim + self.seq_len = seq_len + + elf_ctx = context or AIEContext() + runlist, buffer_sizes = _build_core_ops( + num_heads, num_kv_groups, head_dim, embedding_dim, seq_len, elf_ctx, + ) + + super().__init__( + name=f"attention_prefill_fused_{num_heads}h{num_kv_groups}g{head_dim}d{embedding_dim}e{seq_len}s", + runlist=runlist, + input_args=["queries", "keys", "values", + "W_output", "attn_scale_factor", "causal_mask"], + output_args=["attn_output"], + buffer_sizes=buffer_sizes, + context=elf_ctx, + ) + + +class AIEAttentionPrefillProjectedFused(FusedMLIROperator): + """Fused attention prefill with Q/K/V projections and RoPE. + + Accepts raw input (S, E) and rope_angles (S, d). + """ + + def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, + seq_len, context=None): + assert head_dim == 64 + assert num_heads % num_kv_groups == 0 + assert seq_len % 256 == 0 + assert (num_heads * seq_len) % 16 == 0 + + self.num_heads = num_heads + self.num_kv_groups = num_kv_groups + self.head_dim = head_dim + self.embedding_dim = embedding_dim + self.seq_len = seq_len + + H, G, d, E, S = num_heads, num_kv_groups, head_dim, embedding_dim, seq_len + group_size = H // G + B = 2 + + elf_ctx = context or AIEContext() + + # ---- Projection + RoPE ---- + gemm_query = AIEGEMM( + M=S, K=E, N=H * d, num_aie_columns=8, tile_m=16, tile_k=64, + tile_n=_pick_tile_n(H * d, 8), context=elf_ctx, + ) + gemm_kv = AIEGEMM( + M=S, K=E, N=G * d, num_aie_columns=8, tile_m=16, tile_k=64, + tile_n=_pick_tile_n(G * d, 8), context=elf_ctx, + ) + rope_queries = AIERope(rows=S * H, cols=d, angle_rows=S, context=elf_ctx) + rope_keys = AIERope(rows=S * G, cols=d, angle_rows=S, context=elf_ctx) + + # ---- Deinterleave ---- + deinterleave_q = AIEStridedCopy( + input_sizes=(H, S, d), input_strides=(d, H * d, 1), input_offset=0, + output_sizes=(H, S, d), output_strides=(S * d, d, 1), output_offset=0, + input_buffer_size=S * H * d, output_buffer_size=H * S * d, + transfer_size=S * d, num_aie_channels=1, context=elf_ctx, + ) + deinterleave_kv = AIEStridedCopy( + input_sizes=(G, S, d), input_strides=(d, G * d, 1), input_offset=0, + output_sizes=(G, S, d), output_strides=(S * d, d, 1), output_offset=0, + input_buffer_size=S * G * d, output_buffer_size=G * S * d, + transfer_size=S * d, num_aie_channels=1, context=elf_ctx, + ) + + # ---- Transpose keys + GQA repeat ---- + transpose_keys = AIETranspose( + M=S, N=d, num_aie_columns=2, num_channels=1, + m=256, n=32, s=8, context=elf_ctx, + ) + repeat_kv = AIERepeat( + rows=G, cols=d * S, repeat=group_size, + transfer_size=d, context=elf_ctx, + ) + + kSd = S * d * B + kdS = d * S * B + + prefix_runlist = [ + (gemm_query, "input", "W_query", "queries_projected"), + (gemm_kv, "input", "W_key", "keys_projected"), + (gemm_kv, "input", "W_value", "values_projected"), + (rope_queries, "queries_projected", "rope_angles", "queries_roped"), + (rope_keys, "keys_projected", "rope_angles", "keys_roped"), + (deinterleave_q, "queries_roped", "queries"), + (deinterleave_kv, "keys_roped", "keys_deint"), + (deinterleave_kv, "values_projected", "values_deint"), + *[(transpose_keys, + f"keys_deint[{g*kSd}:{(g+1)*kSd}]", + f"keys_transposed[{g*kdS}:{(g+1)*kdS}]") + for g in range(G)], + (repeat_kv, "keys_transposed", "keys"), + (repeat_kv, "values_deint", "values"), + ] + prefix_buffer_sizes = { + "queries_projected": S * H * d * B, + "keys_projected": S * G * d * B, + "values_projected": S * G * d * B, + "queries_roped": S * H * d * B, + "keys_roped": S * G * d * B, + "keys_deint": G * S * d * B, + "values_deint": G * S * d * B, + "keys_transposed": G * d * S * B, + } + + core_runlist, core_buffer_sizes = _build_core_ops( + H, G, d, E, S, elf_ctx, + ) + + super().__init__( + name=f"attention_prefill_projected_fused_{H}h{G}g{d}d{E}e{S}s", + runlist=prefix_runlist + core_runlist, + input_args=["input", "rope_angles", "W_query", "W_key", "W_value", + "W_output", "attn_scale_factor", "causal_mask"], + output_args=["attn_output"], + buffer_sizes={**prefix_buffer_sizes, **core_buffer_sizes}, + context=elf_ctx, + ) diff --git a/iron/operators/mha_prefill_lxl_sd/reference.py b/iron/operators/mha_prefill_lxl_sd/reference.py new file mode 100644 index 00000000..484f4d7f --- /dev/null +++ b/iron/operators/mha_prefill_lxl_sd/reference.py @@ -0,0 +1,206 @@ +# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import torch +import numpy as np +from ml_dtypes import bfloat16 + + +def apply_rope(x, lut): + """Apply Rotary Position Embedding using pre-computed cos/sin LUT. + + x: (rows, cols) — rows are (positions * heads) interleaved + lut: (angle_rows, cols) — interleaved [cos_0, sin_0, cos_1, sin_1, ...] + If angle_rows < rows, each angle row is reused for + (rows // angle_rows) consecutive input rows (block repetition). + Returns: (rows, cols) with RoPE applied (two-halves method) + """ + rows, cols = x.shape + angle_rows = lut.shape[0] + half = cols // 2 + + cos = lut[:, ::2] # (angle_rows, half) + sin = lut[:, 1::2] # (angle_rows, half) + + if angle_rows < rows: + # Block repetition: each angle row repeats for consecutive input rows + repeats = rows // angle_rows + cos = cos.repeat_interleave(repeats, dim=0) # (rows, half) + sin = sin.repeat_interleave(repeats, dim=0) # (rows, half) + + x1 = x[:, :half] + x2 = x[:, half:] + out = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1) + return out + + +def generate_golden_reference( + num_heads, + num_kv_groups, + head_dim, + embedding_dim, + seq_len, + seed=42, +): + """Generate golden reference for fused attention prefill. + + Parameters: + num_heads (H): number of query attention heads + num_kv_groups (G): number of KV heads (G=H for MHA, G i with -inf + causal_mask = torch.zeros(H * S, S, dtype=torch.bfloat16) + for h in range(H): + for i in range(S): + for j in range(S): + if j > i: + causal_mask[h * S + i, j] = torch.tensor(float("-inf")).to( + torch.bfloat16 + ) + + # ---- Step 1-3: Q/K/V projections ---- + queries_raw = x.float() @ W_query.float() # (S, H*d) + queries_raw = queries_raw.to(torch.bfloat16) + keys_raw = x.float() @ W_key.float() # (S, G*d) + keys_raw = keys_raw.to(torch.bfloat16) + values_raw = x.float() @ W_value.float() # (S, G*d) + values_raw = values_raw.to(torch.bfloat16) + + # ---- Step 4-5: RoPE ---- + # Q proj output is (S, H*d), viewed as (S*H, d) with heads interleaved: + # row layout: [pos0_head0, pos0_head1, ..., pos0_headH-1, pos1_head0, ...] + # RoPE angle_rows=S: row i uses angle row (i % S) = position index + queries_for_rope = queries_raw.reshape(S * H, d) + queries_roped = apply_rope(queries_for_rope, rope_angles) # (S*H, d) + + keys_for_rope = keys_raw.reshape(S * G, d) + keys_roped = apply_rope(keys_for_rope, rope_angles) # (S*G, d) + + # ---- Step 6: Deinterleave Q: (S*H, d) → (H, S, d) ---- + # Current layout: [pos0_h0, pos0_h1, ..., pos0_{H-1}, pos1_h0, ...] + # = (S, H, d) in memory; reshape and transpose to (H, S, d) + queries_deinterleaved = queries_roped.reshape(S, H, d).transpose(0, 1).contiguous() # (H, S, d) + + # ---- Step 7: Deinterleave K: (S*G, d) → (G, S, d) then transpose to (G, d, S) ---- + keys_deinterleaved = keys_roped.reshape(S, G, d).transpose(0, 1).contiguous() # (G, S, d) + keys_transposed = keys_deinterleaved.transpose(1, 2).contiguous() # (G, d, S) + + # ---- Step 8: Deinterleave V: (S, G*d) → (G, S, d) ---- + values_deinterleaved = values_raw.reshape(S, G, d).transpose(0, 1).contiguous() # (G, S, d) + + # ---- Step 9: GQA repeat ---- + if group_size > 1: + # Repeat keys and values: (G, ...) → (H, ...) + # Flatten to (G, d*S) / (G, S*d), repeat, reshape + keys_for_scores = keys_transposed.reshape(G, d * S).repeat_interleave( + group_size, dim=0 + ).reshape(H, d, S) + values_for_context = values_deinterleaved.reshape(G, S * d).repeat_interleave( + group_size, dim=0 + ).reshape(H, S, d) + else: + keys_for_scores = keys_transposed # (H, d, S) + values_for_context = values_deinterleaved # (H, S, d) + + # ---- Step 10: Score GEMM per head ---- + # Q_head(S, d) @ K_head(d, S) → scores(S, S) + attn_scores = torch.zeros(H, S, S, dtype=torch.bfloat16) + for h in range(H): + attn_scores[h] = ( + queries_deinterleaved[h].float() @ keys_for_scores[h].float() + ).to(torch.bfloat16) + + # ---- Step 11: Scale ---- + attn_scores_scaled = (attn_scores.float() * scale).to(torch.bfloat16) + # ---- Step 12: Causal mask (add -inf) ---- + attn_scores_masked = attn_scores_scaled.reshape(H * S, S).float() + causal_mask.float() + attn_scores_masked = attn_scores_masked.to(torch.bfloat16) + + # ---- Step 13: Softmax ---- + attn_weights = torch.nn.functional.softmax( + attn_scores_masked.float().reshape(H, S, S), dim=-1 + ).to(torch.bfloat16) # (H, S, S) + + # ---- Step 14: Context GEMM per head ---- + # weights(S, S) @ values(S, d) → context(S, d) + attn_context = torch.zeros(H, S, d, dtype=torch.bfloat16) + for h in range(H): + attn_context[h] = ( + attn_weights[h].float() @ values_for_context[h].float() + ).to(torch.bfloat16) + + # ---- Step 15: Re-interleave context: (H, S, d) → (S, H*d) ---- + context_interleaved = attn_context.transpose(0, 1).contiguous().reshape(S, H * d) + + # ---- Step 16: Output projection ---- + attn_output = (context_interleaved.float() @ W_output.float()).to(torch.bfloat16) + + return { + "input": x, + "rope_angles": rope_angles, + "W_query": W_query, + "W_key": W_key, + "W_value": W_value, + "W_output": W_output, + "attn_scale_factor": attn_scale_factor, + "causal_mask": causal_mask, + "queries_raw": queries_raw, + "keys_raw": keys_raw, + "values_raw": values_raw, + "queries_roped": queries_roped, + "keys_roped": keys_roped, + "queries_deinterleaved": queries_deinterleaved, + "keys_deinterleaved": keys_deinterleaved, + "keys_transposed": keys_transposed, + "values_deinterleaved": values_deinterleaved, + "keys_for_scores": keys_for_scores, + "values_for_context": values_for_context, + "attn_scores": attn_scores, + "attn_scores_scaled": attn_scores_scaled, + "attn_scores_masked": attn_scores_masked, + "attn_weights": attn_weights, + "attn_context": attn_context, + "context_interleaved": context_interleaved, + "attn_output": attn_output, + } diff --git a/iron/operators/mha_prefill_lxl_sd/test.py b/iron/operators/mha_prefill_lxl_sd/test.py new file mode 100644 index 00000000..ccf768d9 --- /dev/null +++ b/iron/operators/mha_prefill_lxl_sd/test.py @@ -0,0 +1,186 @@ +# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import time + +import numpy as np +import pytest +import torch +from ml_dtypes import bfloat16 + +from iron.common.test_utils import verify_buffer +from iron.common.utils import torch_to_numpy + +from iron.operators.mha_prefill_lxl_sd.op import ( + AIEAttentionPrefillFused, + AIEAttentionPrefillProjectedFused, +) +from iron.operators.mha_prefill_lxl_sd.reference import generate_golden_reference + +REL_TOL = 0.08 +ABS_TOL = 2.0 +MAX_ERROR_RATE = 0.03 + + +def get_params(): + return [ + pytest.param(2, 2, 64, 256, 256, id="H2"), + pytest.param(32, 8, 64, 2048, 256, id="H32"), + ] + + +def _load_input(fc, name, tensor): + """Load a tensor into a named sub-buffer of the fused callable.""" + fc.get_buffer(name).view_as_np()[:] = torch_to_numpy(tensor).flatten() + + +def _get_scratch_tensor(fc, name, shape): + """Read a named buffer from the fused callable's scratch space.""" + fc.scratch_buffer.on = "npu" + fc.scratch_buffer.to("cpu") + sub = fc.get_buffer(name) + return np.frombuffer( + sub.memory_view, dtype=bfloat16, count=int(np.prod(shape)) + ).reshape(shape).astype(np.float32) + + +def _verify_output(fc, golden, H, d, S, E): + """Chain-consistent output verification shared by both test variants.""" + npu_context = torch.from_numpy( + _get_scratch_tensor(fc, "context_interleaved", (S, H * d)) + ).bfloat16() + chain_ref = (npu_context.float() @ golden["W_output"].float()).to(torch.bfloat16) + + fc.output_buffer.on = "npu" + fc.output_buffer.to("cpu") + output_np = fc.get_buffer("attn_output").view_as_np() + output = torch.from_numpy(output_np.reshape(S, E).astype(np.float32)).bfloat16() + + errors = verify_buffer( + output, "attn_output", chain_ref.reshape(S, E), + rel_tol=REL_TOL, abs_tol=ABS_TOL, max_error_rate=MAX_ERROR_RATE, + ) + assert not errors, f"Output verification failed with {len(errors)} errors" + + +def _core_gemm_flops(H, G, d, E, S): + """Count GEMM FLOPs for the core attention operator.""" + score_flops = H * 2 * S * d * S # H x (S,d)@(d,S) + context_flops = H * 2 * S * S * d # H x (S,S)@(S,d) + output_flops = 2 * S * (H * d) * E # (S,H*d)@(H*d,E) + return score_flops + context_flops + output_flops + + +def _projected_gemm_flops(H, G, d, E, S): + """Count GEMM FLOPs for the projected attention operator.""" + query_proj = 2 * S * E * (H * d) # (S,E)@(E,H*d) + kv_proj = 2 * (2 * S * E * (G * d)) # key + value: (S,E)@(E,G*d) each + return query_proj + kv_proj + _core_gemm_flops(H, G, d, E, S) + + +def _print_metrics(label, elapsed_s, flops): + """Print latency and throughput metrics.""" + gflops = flops / elapsed_s / 1e9 + print(f" {label}: {elapsed_s*1e3:.2f} ms, {gflops:.2f} GFLOPS") + + +# --------------------------------------------------------------------------- +# Core attention tests (pre-projected Q, K, V) +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("H,G,d,E,S", get_params()) +def test_mha_pefill_lxl_sd(H, G, d, E, S): + """Core attention: score GEMM -> scale -> mask -> softmax -> context GEMM -> output.""" + golden = generate_golden_reference(H, G, d, E, S) + + op = AIEAttentionPrefillFused(H, G, d, E, S) + op.compile() + fc = op.get_callable() + + _load_input(fc, "queries", golden["queries_deinterleaved"]) + _load_input(fc, "keys", golden["keys_for_scores"]) + _load_input(fc, "values", golden["values_for_context"]) + _load_input(fc, "W_output", golden["W_output"]) + _load_input(fc, "attn_scale_factor", golden["attn_scale_factor"]) + _load_input(fc, "causal_mask", golden["causal_mask"]) + + t0 = time.perf_counter() + fc() + elapsed = time.perf_counter() - t0 + _print_metrics("core", elapsed, _core_gemm_flops(H, G, d, E, S)) + + _verify_output(fc, golden, H, d, S, E) + + +# --------------------------------------------------------------------------- +# Projected attention tests (with Q/K/V projections + RoPE) +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("H,G,d,E,S", get_params()) +def test_attention_prefill_projected_fused(H, G, d, E, S): + """Projected attention: Q/K/V proj -> RoPE -> GQA -> attention -> output proj.""" + golden = generate_golden_reference(H, G, d, E, S) + + op = AIEAttentionPrefillProjectedFused(H, G, d, E, S) + op.compile() + fc = op.get_callable() + + _load_input(fc, "input", golden["input"]) + _load_input(fc, "rope_angles", golden["rope_angles"]) + _load_input(fc, "W_query", golden["W_query"]) + _load_input(fc, "W_key", golden["W_key"]) + _load_input(fc, "W_value", golden["W_value"]) + _load_input(fc, "W_output", golden["W_output"]) + _load_input(fc, "attn_scale_factor", golden["attn_scale_factor"]) + _load_input(fc, "causal_mask", golden["causal_mask"]) + + t0 = time.perf_counter() + fc() + elapsed = time.perf_counter() - t0 + _print_metrics("projected", elapsed, _projected_gemm_flops(H, G, d, E, S)) + + _verify_output(fc, golden, H, d, S, E) + + +# --------------------------------------------------------------------------- +# Intermediate checks (extensive, not run by default) +# --------------------------------------------------------------------------- + +INTERMEDIATE_CHECKS = [ + ("attn_scores", "attn_scores", lambda H, G, S, d: (H, S, S)), + ("attn_scores_masked", "attn_scores_masked", lambda H, G, S, d: (H, S, S)), + ("attn_weights", "attn_weights", lambda H, G, S, d: (H, S, S)), + ("attn_context", "attn_context", lambda H, G, S, d: (H, S, d)), + ("context_interleaved", "context_interleaved", lambda H, G, S, d: (S, H * d)), +] + + +@pytest.mark.extensive +@pytest.mark.parametrize("H,G,d,E,S", get_params()) +def test_mha_pefill_lxl_sd_intermediates(H, G, d, E, S): + """Check intermediate buffers of core attention (for debugging).""" + golden = generate_golden_reference(H, G, d, E, S) + + op = AIEAttentionPrefillFused(H, G, d, E, S) + op.compile() + fc = op.get_callable() + + _load_input(fc, "queries", golden["queries_deinterleaved"]) + _load_input(fc, "keys", golden["keys_for_scores"]) + _load_input(fc, "values", golden["values_for_context"]) + _load_input(fc, "W_output", golden["W_output"]) + _load_input(fc, "attn_scale_factor", golden["attn_scale_factor"]) + _load_input(fc, "causal_mask", golden["causal_mask"]) + + fc() + + for buf_name, golden_key, shape_fn in INTERMEDIATE_CHECKS: + shape = shape_fn(H, G, S, d) + actual = _get_scratch_tensor(fc, buf_name, shape) + expected = golden[golden_key].float().numpy().reshape(shape) + diff = np.abs(actual - expected) + print( + f" [{buf_name}] shape={shape} " + f"nan={int(np.isnan(actual).sum())} " + f"max_abs_err={diff.max():.4f} mean_abs_err={diff.mean():.6f}" + ) From 1bc9bc38bc9c132af33494613696293a1d6fdc6b Mon Sep 17 00:00:00 2001 From: andrej Date: Mon, 6 Apr 2026 17:21:37 -0600 Subject: [PATCH 2/8] add GPT-2 sizes as test cases, make causal mask an option --- iron/common/fusion.py | 3 + iron/operators/mha_prefill_lxl_sd/op.py | 60 ++++++++++++----- iron/operators/mha_prefill_lxl_sd/test.py | 81 ++++++++++++++++++----- 3 files changed, 112 insertions(+), 32 deletions(-) diff --git a/iron/common/fusion.py b/iron/common/fusion.py index 99219848..f8d1b9cd 100644 --- a/iron/common/fusion.py +++ b/iron/common/fusion.py @@ -5,6 +5,7 @@ import ml_dtypes import pyxrt import ctypes +import time from . import compilation as comp from .base import AIEOperatorBase, MLIROperator from .utils import XRTSubBuffer @@ -290,8 +291,10 @@ def __call__(self, *args): for i, arg in enumerate(args): assert isinstance(arg, pyxrt.bo), f"Argument {i} is not a pyxrt.bo" run.set_arg(i, arg) + t0 = time.perf_counter() run.start() ret_code = run.wait() + self.last_elapsed = time.perf_counter() - t0 if ret_code != pyxrt.ert_cmd_state.ERT_CMD_STATE_COMPLETED: raise RuntimeError(f"Kernel execution failed with return code {ret_code}") diff --git a/iron/operators/mha_prefill_lxl_sd/op.py b/iron/operators/mha_prefill_lxl_sd/op.py index fe951829..b939e7e2 100644 --- a/iron/operators/mha_prefill_lxl_sd/op.py +++ b/iron/operators/mha_prefill_lxl_sd/op.py @@ -25,13 +25,15 @@ def _pick_tile_n(N, num_cols, max_tile_n=64): return tile_n -def _build_core_ops(H, G, d, E, S, elf_ctx): +def _build_core_ops(H, G, d, E, S, elf_ctx, causal_mask=True): """Build core attention sub-ops and runlist (no projections/RoPE/GQA). Expects pre-processed inputs: queries: (H, S, d) deinterleaved, contiguous per head keys: (H, d, S) transposed and GQA-repeated values: (H, S, d) GQA-repeated + + If causal_mask=False, the elementwise-add masking step is omitted. """ B = 2 # bytes per bf16 element @@ -43,10 +45,11 @@ def _build_core_ops(H, G, d, E, S, elf_ctx): size=H * S * S, tile_size=S * S // 8, num_aie_columns=8, context=elf_ctx, ) - mask = AIEElementwiseAdd( - size=H * S * S, tile_size=S * S // 8, - num_aie_columns=8, context=elf_ctx, - ) + if causal_mask: + mask = AIEElementwiseAdd( + size=H * S * S, tile_size=S * S // 8, + num_aie_columns=8, context=elf_ctx, + ) softmax = AIESoftmax( rows=H * S, cols=S, num_aie_columns=1, num_channels=1, rtp_vector_size=S, context=elf_ctx, @@ -79,8 +82,19 @@ def _build_core_ops(H, G, d, E, S, elf_ctx): f"attn_scores[{h*sh}:{(h+1)*sh}]") for h in range(H)], (scale, "attn_scores", "attn_scale_factor", "attn_scores"), - (mask, "attn_scores", "causal_mask", "attn_scores_masked"), - (softmax, "attn_scores_masked", "attn_weights"), + ] + + if causal_mask: + runlist += [ + (mask, "attn_scores", "causal_mask", "attn_scores_masked"), + (softmax, "attn_scores_masked", "attn_weights"), + ] + else: + runlist += [ + (softmax, "attn_scores", "attn_weights"), + ] + + runlist += [ *[(gemm_context, f"attn_weights[{h*sh}:{(h+1)*sh}]", f"values[{h*kSd}:{(h+1)*kSd}]", @@ -95,11 +109,12 @@ def _build_core_ops(H, G, d, E, S, elf_ctx): "keys": H * d * S * B, "values": H * S * d * B, "attn_scores": H * S * S * B, - "attn_scores_masked": H * S * S * B, "attn_weights": H * S * S * B, "attn_context": H * S * d * B, "context_interleaved": S * H * d * B, } + if causal_mask: + buffer_sizes["attn_scores_masked"] = H * S * S * B return runlist, buffer_sizes @@ -111,7 +126,7 @@ class AIEAttentionPrefillFused(FusedMLIROperator): """ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, - seq_len, context=None): + seq_len, causal_mask=True, context=None): assert head_dim == 64 assert num_heads % num_kv_groups == 0 assert seq_len % 256 == 0 @@ -126,13 +141,19 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, elf_ctx = context or AIEContext() runlist, buffer_sizes = _build_core_ops( num_heads, num_kv_groups, head_dim, embedding_dim, seq_len, elf_ctx, + causal_mask=causal_mask, ) + mask_suffix = "_causal" if causal_mask else "_nomask" + input_args = ["queries", "keys", "values", + "W_output", "attn_scale_factor"] + if causal_mask: + input_args.append("causal_mask") + super().__init__( - name=f"attention_prefill_fused_{num_heads}h{num_kv_groups}g{head_dim}d{embedding_dim}e{seq_len}s", + name=f"attention_prefill_fused_{num_heads}h{num_kv_groups}g{head_dim}d{embedding_dim}e{seq_len}s{mask_suffix}", runlist=runlist, - input_args=["queries", "keys", "values", - "W_output", "attn_scale_factor", "causal_mask"], + input_args=input_args, output_args=["attn_output"], buffer_sizes=buffer_sizes, context=elf_ctx, @@ -146,7 +167,7 @@ class AIEAttentionPrefillProjectedFused(FusedMLIROperator): """ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, - seq_len, context=None): + seq_len, causal_mask=True, context=None): assert head_dim == 64 assert num_heads % num_kv_groups == 0 assert seq_len % 256 == 0 @@ -231,14 +252,19 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, } core_runlist, core_buffer_sizes = _build_core_ops( - H, G, d, E, S, elf_ctx, + H, G, d, E, S, elf_ctx, causal_mask=causal_mask, ) + mask_suffix = "_causal" if causal_mask else "_nomask" + input_args = ["input", "rope_angles", "W_query", "W_key", "W_value", + "W_output", "attn_scale_factor"] + if causal_mask: + input_args.append("causal_mask") + super().__init__( - name=f"attention_prefill_projected_fused_{H}h{G}g{d}d{E}e{S}s", + name=f"attention_prefill_projected_fused_{H}h{G}g{d}d{E}e{S}s{mask_suffix}", runlist=prefix_runlist + core_runlist, - input_args=["input", "rope_angles", "W_query", "W_key", "W_value", - "W_output", "attn_scale_factor", "causal_mask"], + input_args=input_args, output_args=["attn_output"], buffer_sizes={**prefix_buffer_sizes, **core_buffer_sizes}, context=elf_ctx, diff --git a/iron/operators/mha_prefill_lxl_sd/test.py b/iron/operators/mha_prefill_lxl_sd/test.py index ccf768d9..1bfc0953 100644 --- a/iron/operators/mha_prefill_lxl_sd/test.py +++ b/iron/operators/mha_prefill_lxl_sd/test.py @@ -1,8 +1,6 @@ # SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -import time - import numpy as np import pytest import torch @@ -25,10 +23,23 @@ def get_params(): return [ pytest.param(2, 2, 64, 256, 256, id="H2"), - pytest.param(32, 8, 64, 2048, 256, id="H32"), + pytest.param(32, 8, 64, 2048, 256, id="Llama3.2-256seq"), + pytest.param(12, 12, 64, 768, 256, id="GPT2-Small-256seq"), ] +def get_benchmark_params(): + """GPT-2 Small across sequence lengths 256..32768, with/without causal mask.""" + params = [] + S = 256 + while S <= 32768: + for mask in [True, False]: + tag = "causal" if mask else "nomask" + params.append(pytest.param(12, 12, 64, 768, S, mask, id=f"GPT2-S{S}-{tag}")) + S *= 2 + return params + + def _load_input(fc, name, tensor): """Load a tensor into a named sub-buffer of the fused callable.""" fc.get_buffer(name).view_as_np()[:] = torch_to_numpy(tensor).flatten() @@ -78,16 +89,14 @@ def _projected_gemm_flops(H, G, d, E, S): return query_proj + kv_proj + _core_gemm_flops(H, G, d, E, S) -def _print_metrics(label, elapsed_s, flops): - """Print latency and throughput metrics.""" - gflops = flops / elapsed_s / 1e9 - print(f" {label}: {elapsed_s*1e3:.2f} ms, {gflops:.2f} GFLOPS") - - # --------------------------------------------------------------------------- # Core attention tests (pre-projected Q, K, V) # --------------------------------------------------------------------------- +@pytest.mark.metrics( + Latency=r"Latency \(us\): (?P[\d\.]+)", + Throughput=r"Throughput: (?P[\d\.e\+-]+) GFLOP/s", +) @pytest.mark.parametrize("H,G,d,E,S", get_params()) def test_mha_pefill_lxl_sd(H, G, d, E, S): """Core attention: score GEMM -> scale -> mask -> softmax -> context GEMM -> output.""" @@ -104,10 +113,12 @@ def test_mha_pefill_lxl_sd(H, G, d, E, S): _load_input(fc, "attn_scale_factor", golden["attn_scale_factor"]) _load_input(fc, "causal_mask", golden["causal_mask"]) - t0 = time.perf_counter() fc() - elapsed = time.perf_counter() - t0 - _print_metrics("core", elapsed, _core_gemm_flops(H, G, d, E, S)) + + latency_us = fc.last_elapsed * 1e6 + gflops = _core_gemm_flops(H, G, d, E, S) / (fc.last_elapsed) / 1e9 + print(f"\nLatency (us): {latency_us:.1f}") + print(f"Throughput: {gflops:.6e} GFLOP/s") _verify_output(fc, golden, H, d, S, E) @@ -116,6 +127,10 @@ def test_mha_pefill_lxl_sd(H, G, d, E, S): # Projected attention tests (with Q/K/V projections + RoPE) # --------------------------------------------------------------------------- +@pytest.mark.metrics( + Latency=r"Latency \(us\): (?P[\d\.]+)", + Throughput=r"Throughput: (?P[\d\.e\+-]+) GFLOP/s", +) @pytest.mark.parametrize("H,G,d,E,S", get_params()) def test_attention_prefill_projected_fused(H, G, d, E, S): """Projected attention: Q/K/V proj -> RoPE -> GQA -> attention -> output proj.""" @@ -134,14 +149,50 @@ def test_attention_prefill_projected_fused(H, G, d, E, S): _load_input(fc, "attn_scale_factor", golden["attn_scale_factor"]) _load_input(fc, "causal_mask", golden["causal_mask"]) - t0 = time.perf_counter() fc() - elapsed = time.perf_counter() - t0 - _print_metrics("projected", elapsed, _projected_gemm_flops(H, G, d, E, S)) + + latency_us = fc.last_elapsed * 1e6 + gflops = _projected_gemm_flops(H, G, d, E, S) / (fc.last_elapsed) / 1e9 + print(f"\nLatency (us): {latency_us:.1f}") + print(f"Throughput: {gflops:.6e} GFLOP/s") _verify_output(fc, golden, H, d, S, E) +# --------------------------------------------------------------------------- +# Benchmark: GPT-2 Small core MHA across sequence lengths, +/- causal mask +# --------------------------------------------------------------------------- + +@pytest.mark.benchmark +@pytest.mark.metrics( + Latency=r"Latency \(us\): (?P[\d\.]+)", + Throughput=r"Throughput: (?P[\d\.e\+-]+) GFLOP/s", +) +@pytest.mark.parametrize("H,G,d,E,S,causal", get_benchmark_params()) +def test_mha_prefill_benchmark(H, G, d, E, S, causal): + """Benchmark core MHA for GPT-2 Small across sequence lengths.""" + golden = generate_golden_reference(H, G, d, E, S) + + op = AIEAttentionPrefillFused(H, G, d, E, S, causal_mask=causal) + op.compile() + fc = op.get_callable() + + _load_input(fc, "queries", golden["queries_deinterleaved"]) + _load_input(fc, "keys", golden["keys_for_scores"]) + _load_input(fc, "values", golden["values_for_context"]) + _load_input(fc, "W_output", golden["W_output"]) + _load_input(fc, "attn_scale_factor", golden["attn_scale_factor"]) + if causal: + _load_input(fc, "causal_mask", golden["causal_mask"]) + + fc() + + latency_us = fc.last_elapsed * 1e6 + gflops = _core_gemm_flops(H, G, d, E, S) / (fc.last_elapsed) / 1e9 + print(f"\nLatency (us): {latency_us:.1f}") + print(f"Throughput: {gflops:.6e} GFLOP/s") + + # --------------------------------------------------------------------------- # Intermediate checks (extensive, not run by default) # --------------------------------------------------------------------------- From b698e02a46da4e26fdd4f6fb5f3abdae14105db9 Mon Sep 17 00:00:00 2001 From: andrej Date: Mon, 6 Apr 2026 17:27:09 -0600 Subject: [PATCH 3/8] as benchmarked --- .../operators/mha_prefill_lxl_sd/reference.py | 148 +++++++----------- pytest.ini | 1 + 2 files changed, 54 insertions(+), 95 deletions(-) diff --git a/iron/operators/mha_prefill_lxl_sd/reference.py b/iron/operators/mha_prefill_lxl_sd/reference.py index 484f4d7f..3d1c21ed 100644 --- a/iron/operators/mha_prefill_lxl_sd/reference.py +++ b/iron/operators/mha_prefill_lxl_sd/reference.py @@ -2,36 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 import torch -import numpy as np -from ml_dtypes import bfloat16 +from iron.operators.rope.rope_utils import apply_rope as _apply_rope_4d -def apply_rope(x, lut): - """Apply Rotary Position Embedding using pre-computed cos/sin LUT. - x: (rows, cols) — rows are (positions * heads) interleaved - lut: (angle_rows, cols) — interleaved [cos_0, sin_0, cos_1, sin_1, ...] - If angle_rows < rows, each angle row is reused for - (rows // angle_rows) consecutive input rows (block repetition). - Returns: (rows, cols) with RoPE applied (two-halves method) - """ - rows, cols = x.shape - angle_rows = lut.shape[0] - half = cols // 2 - - cos = lut[:, ::2] # (angle_rows, half) - sin = lut[:, 1::2] # (angle_rows, half) - - if angle_rows < rows: - # Block repetition: each angle row repeats for consecutive input rows - repeats = rows // angle_rows - cos = cos.repeat_interleave(repeats, dim=0) # (rows, half) - sin = sin.repeat_interleave(repeats, dim=0) # (rows, half) - - x1 = x[:, :half] - x2 = x[:, half:] - out = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1) - return out +def _bf16_matmul(a, b): + """(float32 matmul) → bfloat16, matching NPU accumulation.""" + return (a.float() @ b.float()).to(torch.bfloat16) def generate_golden_reference( @@ -75,64 +52,49 @@ def generate_golden_reference( rope_angles = rope_angles.to(torch.bfloat16) # Weight matrices (transposed for GEMM: input @ W → output) - # Q proj: (S, E) @ (E, H*d) → (S, H*d) W_query = torch.randn(E, H * d, dtype=torch.bfloat16) * val_range - # K proj: (S, E) @ (E, G*d) → (S, G*d) W_key = torch.randn(E, G * d, dtype=torch.bfloat16) * val_range - # V proj: (S, E) @ (E, G*d) → (S, G*d) W_value = torch.randn(E, G * d, dtype=torch.bfloat16) * val_range - # Output proj: (S, H*d) @ (H*d, E) → (S, E) W_output = torch.randn(H * d, E, dtype=torch.bfloat16) * val_range - # Scale factor: 1/sqrt(d), broadcast to (H*S, S) + # Scale factor: 1/sqrt(d), broadcast to (H*S*S,) scale = 1.0 / (d ** 0.5) attn_scale_factor = torch.full((H * S * S,), scale, dtype=torch.bfloat16) - # Causal mask: (H*S, S) — 0 for valid positions, -inf for future positions - # Row (h*S + i) attends to positions 0..i, so mask col j > i with -inf + # Causal mask: (H*S, S) — 0 for valid positions, -inf for future causal_mask = torch.zeros(H * S, S, dtype=torch.bfloat16) for h in range(H): for i in range(S): - for j in range(S): - if j > i: - causal_mask[h * S + i, j] = torch.tensor(float("-inf")).to( - torch.bfloat16 - ) - - # ---- Step 1-3: Q/K/V projections ---- - queries_raw = x.float() @ W_query.float() # (S, H*d) - queries_raw = queries_raw.to(torch.bfloat16) - keys_raw = x.float() @ W_key.float() # (S, G*d) - keys_raw = keys_raw.to(torch.bfloat16) - values_raw = x.float() @ W_value.float() # (S, G*d) - values_raw = values_raw.to(torch.bfloat16) - - # ---- Step 4-5: RoPE ---- - # Q proj output is (S, H*d), viewed as (S*H, d) with heads interleaved: - # row layout: [pos0_head0, pos0_head1, ..., pos0_headH-1, pos1_head0, ...] - # RoPE angle_rows=S: row i uses angle row (i % S) = position index - queries_for_rope = queries_raw.reshape(S * H, d) - queries_roped = apply_rope(queries_for_rope, rope_angles) # (S*H, d) - - keys_for_rope = keys_raw.reshape(S * G, d) - keys_roped = apply_rope(keys_for_rope, rope_angles) # (S*G, d) - - # ---- Step 6: Deinterleave Q: (S*H, d) → (H, S, d) ---- - # Current layout: [pos0_h0, pos0_h1, ..., pos0_{H-1}, pos1_h0, ...] - # = (S, H, d) in memory; reshape and transpose to (H, S, d) + for j in range(i + 1, S): + causal_mask[h * S + i, j] = torch.tensor(float("-inf")).to( + torch.bfloat16 + ) + + # ---- Q/K/V projections ---- + queries_raw = _bf16_matmul(x, W_query) # (S, H*d) + keys_raw = _bf16_matmul(x, W_key) # (S, G*d) + values_raw = _bf16_matmul(x, W_value) # (S, G*d) + + # ---- RoPE (reuses rope_utils.apply_rope with 4D interface) ---- + # Reshape interleaved (S, N*d) → (1, N, S, d) for rope_utils + queries_roped = _apply_rope_4d( + queries_raw.reshape(S, H, d).permute(1, 0, 2).unsqueeze(0), # (1, H, S, d) + rope_angles, + ).squeeze(0).permute(1, 0, 2).contiguous().reshape(S * H, d) # (S*H, d) + + keys_roped = _apply_rope_4d( + keys_raw.reshape(S, G, d).permute(1, 0, 2).unsqueeze(0), # (1, G, S, d) + rope_angles, + ).squeeze(0).permute(1, 0, 2).contiguous().reshape(S * G, d) # (S*G, d) + + # ---- Deinterleave Q/K/V ---- queries_deinterleaved = queries_roped.reshape(S, H, d).transpose(0, 1).contiguous() # (H, S, d) + keys_deinterleaved = keys_roped.reshape(S, G, d).transpose(0, 1).contiguous() # (G, S, d) + keys_transposed = keys_deinterleaved.transpose(1, 2).contiguous() # (G, d, S) + values_deinterleaved = values_raw.reshape(S, G, d).transpose(0, 1).contiguous() # (G, S, d) - # ---- Step 7: Deinterleave K: (S*G, d) → (G, S, d) then transpose to (G, d, S) ---- - keys_deinterleaved = keys_roped.reshape(S, G, d).transpose(0, 1).contiguous() # (G, S, d) - keys_transposed = keys_deinterleaved.transpose(1, 2).contiguous() # (G, d, S) - - # ---- Step 8: Deinterleave V: (S, G*d) → (G, S, d) ---- - values_deinterleaved = values_raw.reshape(S, G, d).transpose(0, 1).contiguous() # (G, S, d) - - # ---- Step 9: GQA repeat ---- + # ---- GQA repeat ---- if group_size > 1: - # Repeat keys and values: (G, ...) → (H, ...) - # Flatten to (G, d*S) / (G, S*d), repeat, reshape keys_for_scores = keys_transposed.reshape(G, d * S).repeat_interleave( group_size, dim=0 ).reshape(H, d, S) @@ -140,41 +102,37 @@ def generate_golden_reference( group_size, dim=0 ).reshape(H, S, d) else: - keys_for_scores = keys_transposed # (H, d, S) - values_for_context = values_deinterleaved # (H, S, d) + keys_for_scores = keys_transposed # (H, d, S) + values_for_context = values_deinterleaved # (H, S, d) - # ---- Step 10: Score GEMM per head ---- - # Q_head(S, d) @ K_head(d, S) → scores(S, S) - attn_scores = torch.zeros(H, S, S, dtype=torch.bfloat16) - for h in range(H): - attn_scores[h] = ( - queries_deinterleaved[h].float() @ keys_for_scores[h].float() - ).to(torch.bfloat16) + # ---- Score GEMM per head ---- + attn_scores = torch.stack( + [_bf16_matmul(queries_deinterleaved[h], keys_for_scores[h]) for h in range(H)] + ) # (H, S, S) - # ---- Step 11: Scale ---- + # ---- Scale ---- attn_scores_scaled = (attn_scores.float() * scale).to(torch.bfloat16) - # ---- Step 12: Causal mask (add -inf) ---- - attn_scores_masked = attn_scores_scaled.reshape(H * S, S).float() + causal_mask.float() - attn_scores_masked = attn_scores_masked.to(torch.bfloat16) - # ---- Step 13: Softmax ---- + # ---- Causal mask ---- + attn_scores_masked = ( + attn_scores_scaled.reshape(H * S, S).float() + causal_mask.float() + ).to(torch.bfloat16) + + # ---- Softmax ---- attn_weights = torch.nn.functional.softmax( attn_scores_masked.float().reshape(H, S, S), dim=-1 ).to(torch.bfloat16) # (H, S, S) - # ---- Step 14: Context GEMM per head ---- - # weights(S, S) @ values(S, d) → context(S, d) - attn_context = torch.zeros(H, S, d, dtype=torch.bfloat16) - for h in range(H): - attn_context[h] = ( - attn_weights[h].float() @ values_for_context[h].float() - ).to(torch.bfloat16) + # ---- Context GEMM per head ---- + attn_context = torch.stack( + [_bf16_matmul(attn_weights[h], values_for_context[h]) for h in range(H)] + ) # (H, S, d) - # ---- Step 15: Re-interleave context: (H, S, d) → (S, H*d) ---- + # ---- Re-interleave context: (H, S, d) → (S, H*d) ---- context_interleaved = attn_context.transpose(0, 1).contiguous().reshape(S, H * d) - # ---- Step 16: Output projection ---- - attn_output = (context_interleaved.float() @ W_output.float()).to(torch.bfloat16) + # ---- Output projection ---- + attn_output = _bf16_matmul(context_interleaved, W_output) return { "input": x, diff --git a/pytest.ini b/pytest.ini index 44f08847..a3566ee2 100644 --- a/pytest.ini +++ b/pytest.ini @@ -9,4 +9,5 @@ python_functions = test_* markers = extensive: extensive test suite (deselect with '-m "not extensive"') supported_devices(*devices): mark test as only supported on the given devices (e.g. "npu1", "npu2"). All devices supported by default. + benchmark: benchmark-only tests (select with '-m benchmark') addopts = -v --tb=short --import-mode=importlib From 1da4d569d68e7dd68c5f799badd4a5221eec6b1f Mon Sep 17 00:00:00 2001 From: andrej Date: Mon, 6 Apr 2026 17:57:04 -0600 Subject: [PATCH 4/8] fix DMA dimension overflow --- iron/operators/mha_prefill_lxl_sd/op.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/iron/operators/mha_prefill_lxl_sd/op.py b/iron/operators/mha_prefill_lxl_sd/op.py index b939e7e2..d2af422c 100644 --- a/iron/operators/mha_prefill_lxl_sd/op.py +++ b/iron/operators/mha_prefill_lxl_sd/op.py @@ -59,8 +59,10 @@ def _build_core_ops(H, G, d, E, S, elf_ctx, causal_mask=True): tile_n=16, context=elf_ctx, prio_accuracy=True, ) reinterleave = AIEStridedCopy( - input_sizes=(H, S, d), input_strides=(S * d, d, 1), input_offset=0, - output_sizes=(H, S, d), output_strides=(d, H * d, 1), output_offset=0, + #input_sizes=(H, S, d), input_strides=(S * d, d, 1), input_offset=0, + input_sizes=(1, 1, 1, H * S * d), input_strides=(0, 0, 0, 1), input_offset=0, + #output_sizes=(H, S, d), output_strides=(d, H * d, 1), output_offset=0, + output_sizes=(H, 256, S // 256, d), output_strides=(d, 256 * H * d, H * d, 1), output_offset=0, input_buffer_size=H * S * d, output_buffer_size=S * H * d, transfer_size=S * d, num_aie_channels=1, context=elf_ctx, ) From b1d1e5860160df8e0e5a4094df7fc106da83d907 Mon Sep 17 00:00:00 2001 From: andrej Date: Mon, 6 Apr 2026 18:23:12 -0600 Subject: [PATCH 5/8] create separate attn_scores_scaled buffer --- iron/operators/mha_prefill_lxl_sd/op.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/iron/operators/mha_prefill_lxl_sd/op.py b/iron/operators/mha_prefill_lxl_sd/op.py index d2af422c..f0e7e3fb 100644 --- a/iron/operators/mha_prefill_lxl_sd/op.py +++ b/iron/operators/mha_prefill_lxl_sd/op.py @@ -83,17 +83,17 @@ def _build_core_ops(H, G, d, E, S, elf_ctx, causal_mask=True): f"keys[{h*kdS}:{(h+1)*kdS}]", f"attn_scores[{h*sh}:{(h+1)*sh}]") for h in range(H)], - (scale, "attn_scores", "attn_scale_factor", "attn_scores"), + (scale, "attn_scores", "attn_scale_factor", "attn_scores_scaled"), ] if causal_mask: runlist += [ - (mask, "attn_scores", "causal_mask", "attn_scores_masked"), + (mask, "attn_scores_scaled", "causal_mask", "attn_scores_masked"), (softmax, "attn_scores_masked", "attn_weights"), ] else: runlist += [ - (softmax, "attn_scores", "attn_weights"), + (softmax, "attn_scores_scaled", "attn_weights"), ] runlist += [ @@ -111,6 +111,7 @@ def _build_core_ops(H, G, d, E, S, elf_ctx, causal_mask=True): "keys": H * d * S * B, "values": H * S * d * B, "attn_scores": H * S * S * B, + "attn_scores_scaled": H * S * S * B, "attn_weights": H * S * S * B, "attn_context": H * S * d * B, "context_interleaved": S * H * d * B, From 2b45766f26fe4f364f72b81e273e3f79d78fe185 Mon Sep 17 00:00:00 2001 From: andrej Date: Tue, 7 Apr 2026 13:57:16 -0600 Subject: [PATCH 6/8] move output GEMM out of core MHA --- iron/operators/mha_prefill_lxl_sd/op.py | 53 +++++++++++++---------- iron/operators/mha_prefill_lxl_sd/test.py | 46 +++++++++++++------- 2 files changed, 61 insertions(+), 38 deletions(-) diff --git a/iron/operators/mha_prefill_lxl_sd/op.py b/iron/operators/mha_prefill_lxl_sd/op.py index f0e7e3fb..578a63bc 100644 --- a/iron/operators/mha_prefill_lxl_sd/op.py +++ b/iron/operators/mha_prefill_lxl_sd/op.py @@ -25,7 +25,7 @@ def _pick_tile_n(N, num_cols, max_tile_n=64): return tile_n -def _build_core_ops(H, G, d, E, S, elf_ctx, causal_mask=True): +def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True): """Build core attention sub-ops and runlist (no projections/RoPE/GQA). Expects pre-processed inputs: @@ -33,6 +33,9 @@ def _build_core_ops(H, G, d, E, S, elf_ctx, causal_mask=True): keys: (H, d, S) transposed and GQA-repeated values: (H, S, d) GQA-repeated + Produces: + attn_context: (H, S, d) — per-head context vectors + If causal_mask=False, the elementwise-add masking step is omitted. """ B = 2 # bytes per bf16 element @@ -58,18 +61,6 @@ def _build_core_ops(H, G, d, E, S, elf_ctx, causal_mask=True): M=S, K=S, N=d, num_aie_columns=4, tile_m=16, tile_k=64, tile_n=16, context=elf_ctx, prio_accuracy=True, ) - reinterleave = AIEStridedCopy( - #input_sizes=(H, S, d), input_strides=(S * d, d, 1), input_offset=0, - input_sizes=(1, 1, 1, H * S * d), input_strides=(0, 0, 0, 1), input_offset=0, - #output_sizes=(H, S, d), output_strides=(d, H * d, 1), output_offset=0, - output_sizes=(H, 256, S // 256, d), output_strides=(d, 256 * H * d, H * d, 1), output_offset=0, - input_buffer_size=H * S * d, output_buffer_size=S * H * d, - transfer_size=S * d, num_aie_channels=1, context=elf_ctx, - ) - gemm_output = AIEGEMM( - M=S, K=H * d, N=E, num_aie_columns=8, tile_m=16, tile_k=64, - tile_n=_pick_tile_n(E, 8), context=elf_ctx, prio_accuracy=True, - ) qh = S * d * B kdS = d * S * B @@ -102,8 +93,6 @@ def _build_core_ops(H, G, d, E, S, elf_ctx, causal_mask=True): f"values[{h*kSd}:{(h+1)*kSd}]", f"attn_context[{h*ch}:{(h+1)*ch}]") for h in range(H)], - (reinterleave, "attn_context", "context_interleaved"), - (gemm_output, "context_interleaved", "W_output", "attn_output"), ] buffer_sizes = { @@ -114,7 +103,6 @@ def _build_core_ops(H, G, d, E, S, elf_ctx, causal_mask=True): "attn_scores_scaled": H * S * S * B, "attn_weights": H * S * S * B, "attn_context": H * S * d * B, - "context_interleaved": S * H * d * B, } if causal_mask: buffer_sizes["attn_scores_masked"] = H * S * S * B @@ -143,13 +131,12 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, elf_ctx = context or AIEContext() runlist, buffer_sizes = _build_core_ops( - num_heads, num_kv_groups, head_dim, embedding_dim, seq_len, elf_ctx, + num_heads, num_kv_groups, head_dim, seq_len, elf_ctx, causal_mask=causal_mask, ) mask_suffix = "_causal" if causal_mask else "_nomask" - input_args = ["queries", "keys", "values", - "W_output", "attn_scale_factor"] + input_args = ["queries", "keys", "values", "attn_scale_factor"] if causal_mask: input_args.append("causal_mask") @@ -157,7 +144,7 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, name=f"attention_prefill_fused_{num_heads}h{num_kv_groups}g{head_dim}d{embedding_dim}e{seq_len}s{mask_suffix}", runlist=runlist, input_args=input_args, - output_args=["attn_output"], + output_args=["attn_context"], buffer_sizes=buffer_sizes, context=elf_ctx, ) @@ -255,9 +242,29 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, } core_runlist, core_buffer_sizes = _build_core_ops( - H, G, d, E, S, elf_ctx, causal_mask=causal_mask, + H, G, d, S, elf_ctx, causal_mask=causal_mask, + ) + + # ---- Reinterleave + output projection ---- + reinterleave = AIEStridedCopy( + input_sizes=(1, 1, 1, H * S * d), input_strides=(0, 0, 0, 1), input_offset=0, + output_sizes=(H, 256, S // 256, d), output_strides=(d, 256 * H * d, H * d, 1), output_offset=0, + input_buffer_size=H * S * d, output_buffer_size=S * H * d, + transfer_size=S * d, num_aie_channels=1, context=elf_ctx, + ) + gemm_output = AIEGEMM( + M=S, K=H * d, N=E, num_aie_columns=8, tile_m=16, tile_k=64, + tile_n=_pick_tile_n(E, 8), context=elf_ctx, prio_accuracy=True, ) + suffix_runlist = [ + (reinterleave, "attn_context", "context_interleaved"), + (gemm_output, "context_interleaved", "W_output", "attn_output"), + ] + suffix_buffer_sizes = { + "context_interleaved": S * H * d * B, + } + mask_suffix = "_causal" if causal_mask else "_nomask" input_args = ["input", "rope_angles", "W_query", "W_key", "W_value", "W_output", "attn_scale_factor"] @@ -266,9 +273,9 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, super().__init__( name=f"attention_prefill_projected_fused_{H}h{G}g{d}d{E}e{S}s{mask_suffix}", - runlist=prefix_runlist + core_runlist, + runlist=prefix_runlist + core_runlist + suffix_runlist, input_args=input_args, output_args=["attn_output"], - buffer_sizes={**prefix_buffer_sizes, **core_buffer_sizes}, + buffer_sizes={**prefix_buffer_sizes, **core_buffer_sizes, **suffix_buffer_sizes}, context=elf_ctx, ) diff --git a/iron/operators/mha_prefill_lxl_sd/test.py b/iron/operators/mha_prefill_lxl_sd/test.py index 1bfc0953..94549d7f 100644 --- a/iron/operators/mha_prefill_lxl_sd/test.py +++ b/iron/operators/mha_prefill_lxl_sd/test.py @@ -55,6 +55,16 @@ def _get_scratch_tensor(fc, name, shape): ).reshape(shape).astype(np.float32) +def _get_output_tensor(fc, name, shape): + """Read a named buffer from the fused callable's output space.""" + fc.output_buffer.on = "npu" + fc.output_buffer.to("cpu") + sub = fc.get_buffer(name) + return np.frombuffer( + sub.memory_view, dtype=bfloat16, count=int(np.prod(shape)) + ).reshape(shape).astype(np.float32) + + def _verify_output(fc, golden, H, d, S, E): """Chain-consistent output verification shared by both test variants.""" npu_context = torch.from_numpy( @@ -78,15 +88,15 @@ def _core_gemm_flops(H, G, d, E, S): """Count GEMM FLOPs for the core attention operator.""" score_flops = H * 2 * S * d * S # H x (S,d)@(d,S) context_flops = H * 2 * S * S * d # H x (S,S)@(S,d) - output_flops = 2 * S * (H * d) * E # (S,H*d)@(H*d,E) - return score_flops + context_flops + output_flops + return score_flops + context_flops def _projected_gemm_flops(H, G, d, E, S): """Count GEMM FLOPs for the projected attention operator.""" query_proj = 2 * S * E * (H * d) # (S,E)@(E,H*d) kv_proj = 2 * (2 * S * E * (G * d)) # key + value: (S,E)@(E,G*d) each - return query_proj + kv_proj + _core_gemm_flops(H, G, d, E, S) + output_proj = 2 * S * (H * d) * E # (S,H*d)@(H*d,E) + return query_proj + kv_proj + _core_gemm_flops(H, G, d, E, S) + output_proj # --------------------------------------------------------------------------- @@ -99,7 +109,7 @@ def _projected_gemm_flops(H, G, d, E, S): ) @pytest.mark.parametrize("H,G,d,E,S", get_params()) def test_mha_pefill_lxl_sd(H, G, d, E, S): - """Core attention: score GEMM -> scale -> mask -> softmax -> context GEMM -> output.""" + """Core attention: score GEMM -> scale -> mask -> softmax -> context GEMM.""" golden = generate_golden_reference(H, G, d, E, S) op = AIEAttentionPrefillFused(H, G, d, E, S) @@ -109,7 +119,6 @@ def test_mha_pefill_lxl_sd(H, G, d, E, S): _load_input(fc, "queries", golden["queries_deinterleaved"]) _load_input(fc, "keys", golden["keys_for_scores"]) _load_input(fc, "values", golden["values_for_context"]) - _load_input(fc, "W_output", golden["W_output"]) _load_input(fc, "attn_scale_factor", golden["attn_scale_factor"]) _load_input(fc, "causal_mask", golden["causal_mask"]) @@ -120,7 +129,14 @@ def test_mha_pefill_lxl_sd(H, G, d, E, S): print(f"\nLatency (us): {latency_us:.1f}") print(f"Throughput: {gflops:.6e} GFLOP/s") - _verify_output(fc, golden, H, d, S, E) + actual = _get_output_tensor(fc, "attn_context", (H, S, d)) + expected = golden["attn_context"].float().numpy().reshape(H, S, d) + errors = verify_buffer( + torch.from_numpy(actual).bfloat16(), "attn_context", + torch.from_numpy(expected).bfloat16().reshape(H, S, d), + rel_tol=REL_TOL, abs_tol=ABS_TOL, max_error_rate=MAX_ERROR_RATE, + ) + assert not errors, f"Output verification failed with {len(errors)} errors" # --------------------------------------------------------------------------- @@ -180,7 +196,6 @@ def test_mha_prefill_benchmark(H, G, d, E, S, causal): _load_input(fc, "queries", golden["queries_deinterleaved"]) _load_input(fc, "keys", golden["keys_for_scores"]) _load_input(fc, "values", golden["values_for_context"]) - _load_input(fc, "W_output", golden["W_output"]) _load_input(fc, "attn_scale_factor", golden["attn_scale_factor"]) if causal: _load_input(fc, "causal_mask", golden["causal_mask"]) @@ -198,11 +213,10 @@ def test_mha_prefill_benchmark(H, G, d, E, S, causal): # --------------------------------------------------------------------------- INTERMEDIATE_CHECKS = [ - ("attn_scores", "attn_scores", lambda H, G, S, d: (H, S, S)), - ("attn_scores_masked", "attn_scores_masked", lambda H, G, S, d: (H, S, S)), - ("attn_weights", "attn_weights", lambda H, G, S, d: (H, S, S)), - ("attn_context", "attn_context", lambda H, G, S, d: (H, S, d)), - ("context_interleaved", "context_interleaved", lambda H, G, S, d: (S, H * d)), + ("attn_scores", "attn_scores", lambda H, G, S, d: (H, S, S), "scratch"), + ("attn_scores_masked", "attn_scores_masked", lambda H, G, S, d: (H, S, S), "scratch"), + ("attn_weights", "attn_weights", lambda H, G, S, d: (H, S, S), "scratch"), + ("attn_context", "attn_context", lambda H, G, S, d: (H, S, d), "output"), ] @@ -219,15 +233,17 @@ def test_mha_pefill_lxl_sd_intermediates(H, G, d, E, S): _load_input(fc, "queries", golden["queries_deinterleaved"]) _load_input(fc, "keys", golden["keys_for_scores"]) _load_input(fc, "values", golden["values_for_context"]) - _load_input(fc, "W_output", golden["W_output"]) _load_input(fc, "attn_scale_factor", golden["attn_scale_factor"]) _load_input(fc, "causal_mask", golden["causal_mask"]) fc() - for buf_name, golden_key, shape_fn in INTERMEDIATE_CHECKS: + for buf_name, golden_key, shape_fn, buf_type in INTERMEDIATE_CHECKS: shape = shape_fn(H, G, S, d) - actual = _get_scratch_tensor(fc, buf_name, shape) + if buf_type == "output": + actual = _get_output_tensor(fc, buf_name, shape) + else: + actual = _get_scratch_tensor(fc, buf_name, shape) expected = golden[golden_key].float().numpy().reshape(shape) diff = np.abs(actual - expected) print( From 6e4aed4f2acda450f19acb8b1aa70e768ce30669 Mon Sep 17 00:00:00 2001 From: andrej Date: Tue, 7 Apr 2026 14:56:33 -0600 Subject: [PATCH 7/8] remove symbol renaming after rebase to use link_with, other fixes --- iron/common/compilation/base.py | 1 + iron/common/fusion.py | 12 ++--- iron/operators/mha_prefill_lxl_sd/op.py | 50 +++++++++---------- .../operators/mha_prefill_lxl_sd/reference.py | 14 +++++- iron/operators/mha_prefill_lxl_sd/test.py | 35 ++++++------- 5 files changed, 56 insertions(+), 56 deletions(-) diff --git a/iron/common/compilation/base.py b/iron/common/compilation/base.py index 53b1fd3c..4a1d7e44 100644 --- a/iron/common/compilation/base.py +++ b/iron/common/compilation/base.py @@ -496,6 +496,7 @@ def compile(self, graph): str(self.aiecc_path), "-v", "-j1", + "--dynamic-objFifos", "--no-compile-host", "--no-xchesscc", "--no-xbridge", diff --git a/iron/common/fusion.py b/iron/common/fusion.py index f8d1b9cd..292b26e9 100644 --- a/iron/common/fusion.py +++ b/iron/common/fusion.py @@ -43,8 +43,7 @@ def get_kernel_artifacts(self): """Collect all kernel artifacts from child operators. Returns: - List of KernelObjectArtifact instances from all unique child operators, - with filenames and symbol prefixes disambiguated per operator index. + List of KernelObjectArtifact instances from all unique child operators. """ kernel_artifacts = [] seen: dict[int, object] = {} @@ -53,9 +52,6 @@ def get_kernel_artifacts(self): ] for idx, op in enumerate(unique_operators): objs = op.get_kernel_artifacts() - for obj in objs: - obj.filename = f"op{idx}_{obj.filename}" - obj.prefix_symbols = f"op{idx}_" kernel_artifacts.extend(objs) return kernel_artifacts @@ -83,8 +79,6 @@ def get_mlir_artifact(self): ] for idx, op in enumerate(unique_operators): mlir_artifact = op.get_mlir_artifact() - if len(op.get_kernel_artifacts()) > 0: - mlir_artifact.generator.kwargs["func_prefix"] = f"op{idx}_" op_name = f"op{idx}_{op.__class__.__name__}" op_names[id(op)] = op_name operator_mlir_map[op_name] = mlir_artifact @@ -374,10 +368,10 @@ def get_buffer(self, buffer_name): return sub_buffer def __call__(self): - self.input_buffer.to("npu") + self.input_buffer._sync_to_device() super().__call__( self.input_buffer.buffer_object(), self.output_buffer.buffer_object(), self.scratch_buffer.buffer_object(), ) - self.output_buffer.to("cpu") + self.output_buffer._sync_from_device() diff --git a/iron/operators/mha_prefill_lxl_sd/op.py b/iron/operators/mha_prefill_lxl_sd/op.py index 578a63bc..b65a7171 100644 --- a/iron/operators/mha_prefill_lxl_sd/op.py +++ b/iron/operators/mha_prefill_lxl_sd/op.py @@ -7,14 +7,14 @@ from iron.common.context import AIEContext from iron.common.fusion import FusedMLIROperator -from iron.operators.gemm.op import AIEGEMM -from iron.operators.rope.op import AIERope -from iron.operators.strided_copy.op import AIEStridedCopy -from iron.operators.repeat.op import AIERepeat -from iron.operators.softmax.op import AIESoftmax -from iron.operators.transpose.op import AIETranspose -from iron.operators.elementwise_mul.op import AIEElementwiseMul -from iron.operators.elementwise_add.op import AIEElementwiseAdd +from iron.operators.gemm.op import GEMM +from iron.operators.rope.op import RoPE +from iron.operators.strided_copy.op import StridedCopy +from iron.operators.repeat.op import Repeat +from iron.operators.softmax.op import Softmax +from iron.operators.transpose.op import Transpose +from iron.operators.elementwise_mul.op import ElementwiseMul +from iron.operators.elementwise_add.op import ElementwiseAdd def _pick_tile_n(N, num_cols, max_tile_n=64): @@ -40,24 +40,24 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True): """ B = 2 # bytes per bf16 element - gemm_scores = AIEGEMM( + gemm_scores = GEMM( M=S, K=d, N=S, num_aie_columns=8, tile_m=16, tile_k=64, tile_n=_pick_tile_n(S, 8), context=elf_ctx, ) - scale = AIEElementwiseMul( + scale = ElementwiseMul( size=H * S * S, tile_size=S * S // 8, num_aie_columns=8, context=elf_ctx, ) if causal_mask: - mask = AIEElementwiseAdd( + mask = ElementwiseAdd( size=H * S * S, tile_size=S * S // 8, num_aie_columns=8, context=elf_ctx, ) - softmax = AIESoftmax( + softmax = Softmax( rows=H * S, cols=S, num_aie_columns=1, num_channels=1, rtp_vector_size=S, context=elf_ctx, ) - gemm_context = AIEGEMM( + gemm_context = GEMM( M=S, K=S, N=d, num_aie_columns=4, tile_m=16, tile_k=64, tile_n=16, context=elf_ctx, prio_accuracy=True, ) @@ -110,7 +110,7 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True): return runlist, buffer_sizes -class AIEAttentionPrefillFused(FusedMLIROperator): +class AttentionPrefillFused(FusedMLIROperator): """Fused attention prefill (core, no projections/RoPE). Accepts pre-projected Q (S*H,d), K (S*G,d), V (S*G,d) in interleaved layout. @@ -150,7 +150,7 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, ) -class AIEAttentionPrefillProjectedFused(FusedMLIROperator): +class AttentionPrefillProjectedFused(FusedMLIROperator): """Fused attention prefill with Q/K/V projections and RoPE. Accepts raw input (S, E) and rope_angles (S, d). @@ -176,25 +176,25 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, elf_ctx = context or AIEContext() # ---- Projection + RoPE ---- - gemm_query = AIEGEMM( + gemm_query = GEMM( M=S, K=E, N=H * d, num_aie_columns=8, tile_m=16, tile_k=64, tile_n=_pick_tile_n(H * d, 8), context=elf_ctx, ) - gemm_kv = AIEGEMM( + gemm_kv = GEMM( M=S, K=E, N=G * d, num_aie_columns=8, tile_m=16, tile_k=64, tile_n=_pick_tile_n(G * d, 8), context=elf_ctx, ) - rope_queries = AIERope(rows=S * H, cols=d, angle_rows=S, context=elf_ctx) - rope_keys = AIERope(rows=S * G, cols=d, angle_rows=S, context=elf_ctx) + rope_queries = RoPE(rows=S * H, cols=d, angle_rows=S, context=elf_ctx) + rope_keys = RoPE(rows=S * G, cols=d, angle_rows=S, context=elf_ctx) # ---- Deinterleave ---- - deinterleave_q = AIEStridedCopy( + deinterleave_q = StridedCopy( input_sizes=(H, S, d), input_strides=(d, H * d, 1), input_offset=0, output_sizes=(H, S, d), output_strides=(S * d, d, 1), output_offset=0, input_buffer_size=S * H * d, output_buffer_size=H * S * d, transfer_size=S * d, num_aie_channels=1, context=elf_ctx, ) - deinterleave_kv = AIEStridedCopy( + deinterleave_kv = StridedCopy( input_sizes=(G, S, d), input_strides=(d, G * d, 1), input_offset=0, output_sizes=(G, S, d), output_strides=(S * d, d, 1), output_offset=0, input_buffer_size=S * G * d, output_buffer_size=G * S * d, @@ -202,11 +202,11 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, ) # ---- Transpose keys + GQA repeat ---- - transpose_keys = AIETranspose( + transpose_keys = Transpose( M=S, N=d, num_aie_columns=2, num_channels=1, m=256, n=32, s=8, context=elf_ctx, ) - repeat_kv = AIERepeat( + repeat_kv = Repeat( rows=G, cols=d * S, repeat=group_size, transfer_size=d, context=elf_ctx, ) @@ -246,13 +246,13 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, ) # ---- Reinterleave + output projection ---- - reinterleave = AIEStridedCopy( + reinterleave = StridedCopy( input_sizes=(1, 1, 1, H * S * d), input_strides=(0, 0, 0, 1), input_offset=0, output_sizes=(H, 256, S // 256, d), output_strides=(d, 256 * H * d, H * d, 1), output_offset=0, input_buffer_size=H * S * d, output_buffer_size=S * H * d, transfer_size=S * d, num_aie_channels=1, context=elf_ctx, ) - gemm_output = AIEGEMM( + gemm_output = GEMM( M=S, K=H * d, N=E, num_aie_columns=8, tile_m=16, tile_k=64, tile_n=_pick_tile_n(E, 8), context=elf_ctx, prio_accuracy=True, ) diff --git a/iron/operators/mha_prefill_lxl_sd/reference.py b/iron/operators/mha_prefill_lxl_sd/reference.py index 3d1c21ed..92cba961 100644 --- a/iron/operators/mha_prefill_lxl_sd/reference.py +++ b/iron/operators/mha_prefill_lxl_sd/reference.py @@ -3,7 +3,19 @@ import torch -from iron.operators.rope.rope_utils import apply_rope as _apply_rope_4d + +def _apply_rope_4d(x, angles): + """Apply RoPE to a 4D tensor using interleaved cos/sin angles. + + x: (batch, heads, seq_len, head_dim) + angles: (seq_len, head_dim) with interleaved [cos_0, sin_0, cos_1, sin_1, ...] + Returns: same shape as x with RoPE applied (two-halves method). + """ + half = x.shape[-1] // 2 + cos = angles[:, ::2].unsqueeze(0).unsqueeze(0) # (1, 1, S, half) + sin = angles[:, 1::2].unsqueeze(0).unsqueeze(0) # (1, 1, S, half) + x1, x2 = x[..., :half], x[..., half:] + return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1) def _bf16_matmul(a, b): diff --git a/iron/operators/mha_prefill_lxl_sd/test.py b/iron/operators/mha_prefill_lxl_sd/test.py index 94549d7f..318c44fc 100644 --- a/iron/operators/mha_prefill_lxl_sd/test.py +++ b/iron/operators/mha_prefill_lxl_sd/test.py @@ -7,11 +7,10 @@ from ml_dtypes import bfloat16 from iron.common.test_utils import verify_buffer -from iron.common.utils import torch_to_numpy from iron.operators.mha_prefill_lxl_sd.op import ( - AIEAttentionPrefillFused, - AIEAttentionPrefillProjectedFused, + AttentionPrefillFused, + AttentionPrefillProjectedFused, ) from iron.operators.mha_prefill_lxl_sd.reference import generate_golden_reference @@ -42,27 +41,22 @@ def get_benchmark_params(): def _load_input(fc, name, tensor): """Load a tensor into a named sub-buffer of the fused callable.""" - fc.get_buffer(name).view_as_np()[:] = torch_to_numpy(tensor).flatten() + np_buf = tensor.contiguous().view(torch.uint16).numpy().view(bfloat16) + fc.get_buffer(name).data[:] = np_buf.flatten() def _get_scratch_tensor(fc, name, shape): """Read a named buffer from the fused callable's scratch space.""" - fc.scratch_buffer.on = "npu" - fc.scratch_buffer.to("cpu") + fc.scratch_buffer._sync_from_device() sub = fc.get_buffer(name) - return np.frombuffer( - sub.memory_view, dtype=bfloat16, count=int(np.prod(shape)) - ).reshape(shape).astype(np.float32) + return sub.data[:int(np.prod(shape))].reshape(shape).astype(np.float32) def _get_output_tensor(fc, name, shape): """Read a named buffer from the fused callable's output space.""" - fc.output_buffer.on = "npu" - fc.output_buffer.to("cpu") + fc.output_buffer._sync_from_device() sub = fc.get_buffer(name) - return np.frombuffer( - sub.memory_view, dtype=bfloat16, count=int(np.prod(shape)) - ).reshape(shape).astype(np.float32) + return sub.data[:int(np.prod(shape))].reshape(shape).astype(np.float32) def _verify_output(fc, golden, H, d, S, E): @@ -72,9 +66,8 @@ def _verify_output(fc, golden, H, d, S, E): ).bfloat16() chain_ref = (npu_context.float() @ golden["W_output"].float()).to(torch.bfloat16) - fc.output_buffer.on = "npu" - fc.output_buffer.to("cpu") - output_np = fc.get_buffer("attn_output").view_as_np() + fc.output_buffer._sync_from_device() + output_np = fc.get_buffer("attn_output").data output = torch.from_numpy(output_np.reshape(S, E).astype(np.float32)).bfloat16() errors = verify_buffer( @@ -112,7 +105,7 @@ def test_mha_pefill_lxl_sd(H, G, d, E, S): """Core attention: score GEMM -> scale -> mask -> softmax -> context GEMM.""" golden = generate_golden_reference(H, G, d, E, S) - op = AIEAttentionPrefillFused(H, G, d, E, S) + op = AttentionPrefillFused(H, G, d, E, S) op.compile() fc = op.get_callable() @@ -152,7 +145,7 @@ def test_attention_prefill_projected_fused(H, G, d, E, S): """Projected attention: Q/K/V proj -> RoPE -> GQA -> attention -> output proj.""" golden = generate_golden_reference(H, G, d, E, S) - op = AIEAttentionPrefillProjectedFused(H, G, d, E, S) + op = AttentionPrefillProjectedFused(H, G, d, E, S) op.compile() fc = op.get_callable() @@ -189,7 +182,7 @@ def test_mha_prefill_benchmark(H, G, d, E, S, causal): """Benchmark core MHA for GPT-2 Small across sequence lengths.""" golden = generate_golden_reference(H, G, d, E, S) - op = AIEAttentionPrefillFused(H, G, d, E, S, causal_mask=causal) + op = AttentionPrefillFused(H, G, d, E, S, causal_mask=causal) op.compile() fc = op.get_callable() @@ -226,7 +219,7 @@ def test_mha_pefill_lxl_sd_intermediates(H, G, d, E, S): """Check intermediate buffers of core attention (for debugging).""" golden = generate_golden_reference(H, G, d, E, S) - op = AIEAttentionPrefillFused(H, G, d, E, S) + op = AttentionPrefillFused(H, G, d, E, S) op.compile() fc = op.get_callable() From 675c21216ce316d4a9c488fd67a00f12fe3bada8 Mon Sep 17 00:00:00 2001 From: andrej Date: Tue, 7 Apr 2026 14:57:00 -0600 Subject: [PATCH 8/8] format --- iron/operators/gemm/design.py | 4 +- iron/operators/mha_prefill_lxl_sd/op.py | 228 +++++++++++++----- .../operators/mha_prefill_lxl_sd/reference.py | 78 +++--- iron/operators/mha_prefill_lxl_sd/test.py | 39 ++- 4 files changed, 254 insertions(+), 95 deletions(-) diff --git a/iron/operators/gemm/design.py b/iron/operators/gemm/design.py index 32222dd4..8b717dcf 100644 --- a/iron/operators/gemm/design.py +++ b/iron/operators/gemm/design.py @@ -314,7 +314,9 @@ def my_matmul( gemm_object, [C_l1_ty], ) - matmul_func_name = f"{func_prefix}matmul{scalar_suffix}_{dtype_in_str}_{dtype_out_str}" + matmul_func_name = ( + f"{func_prefix}matmul{scalar_suffix}_{dtype_in_str}_{dtype_out_str}" + ) matmul_kernel = Kernel( matmul_func_name, gemm_object, diff --git a/iron/operators/mha_prefill_lxl_sd/op.py b/iron/operators/mha_prefill_lxl_sd/op.py index b65a7171..aba8a115 100644 --- a/iron/operators/mha_prefill_lxl_sd/op.py +++ b/iron/operators/mha_prefill_lxl_sd/op.py @@ -41,25 +41,46 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True): B = 2 # bytes per bf16 element gemm_scores = GEMM( - M=S, K=d, N=S, num_aie_columns=8, tile_m=16, tile_k=64, - tile_n=_pick_tile_n(S, 8), context=elf_ctx, + M=S, + K=d, + N=S, + num_aie_columns=8, + tile_m=16, + tile_k=64, + tile_n=_pick_tile_n(S, 8), + context=elf_ctx, ) scale = ElementwiseMul( - size=H * S * S, tile_size=S * S // 8, - num_aie_columns=8, context=elf_ctx, + size=H * S * S, + tile_size=S * S // 8, + num_aie_columns=8, + context=elf_ctx, ) if causal_mask: mask = ElementwiseAdd( - size=H * S * S, tile_size=S * S // 8, - num_aie_columns=8, context=elf_ctx, + size=H * S * S, + tile_size=S * S // 8, + num_aie_columns=8, + context=elf_ctx, ) softmax = Softmax( - rows=H * S, cols=S, num_aie_columns=1, num_channels=1, - rtp_vector_size=S, context=elf_ctx, + rows=H * S, + cols=S, + num_aie_columns=1, + num_channels=1, + rtp_vector_size=S, + context=elf_ctx, ) gemm_context = GEMM( - M=S, K=S, N=d, num_aie_columns=4, tile_m=16, tile_k=64, - tile_n=16, context=elf_ctx, prio_accuracy=True, + M=S, + K=S, + N=d, + num_aie_columns=4, + tile_m=16, + tile_k=64, + tile_n=16, + context=elf_ctx, + prio_accuracy=True, ) qh = S * d * B @@ -69,11 +90,15 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True): ch = S * d * B runlist = [ - *[(gemm_scores, - f"queries[{h*qh}:{(h+1)*qh}]", - f"keys[{h*kdS}:{(h+1)*kdS}]", - f"attn_scores[{h*sh}:{(h+1)*sh}]") - for h in range(H)], + *[ + ( + gemm_scores, + f"queries[{h*qh}:{(h+1)*qh}]", + f"keys[{h*kdS}:{(h+1)*kdS}]", + f"attn_scores[{h*sh}:{(h+1)*sh}]", + ) + for h in range(H) + ], (scale, "attn_scores", "attn_scale_factor", "attn_scores_scaled"), ] @@ -88,11 +113,15 @@ def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True): ] runlist += [ - *[(gemm_context, - f"attn_weights[{h*sh}:{(h+1)*sh}]", - f"values[{h*kSd}:{(h+1)*kSd}]", - f"attn_context[{h*ch}:{(h+1)*ch}]") - for h in range(H)], + *[ + ( + gemm_context, + f"attn_weights[{h*sh}:{(h+1)*sh}]", + f"values[{h*kSd}:{(h+1)*kSd}]", + f"attn_context[{h*ch}:{(h+1)*ch}]", + ) + for h in range(H) + ], ] buffer_sizes = { @@ -116,8 +145,16 @@ class AttentionPrefillFused(FusedMLIROperator): Accepts pre-projected Q (S*H,d), K (S*G,d), V (S*G,d) in interleaved layout. """ - def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, - seq_len, causal_mask=True, context=None): + def __init__( + self, + num_heads, + num_kv_groups, + head_dim, + embedding_dim, + seq_len, + causal_mask=True, + context=None, + ): assert head_dim == 64 assert num_heads % num_kv_groups == 0 assert seq_len % 256 == 0 @@ -131,7 +168,11 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, elf_ctx = context or AIEContext() runlist, buffer_sizes = _build_core_ops( - num_heads, num_kv_groups, head_dim, seq_len, elf_ctx, + num_heads, + num_kv_groups, + head_dim, + seq_len, + elf_ctx, causal_mask=causal_mask, ) @@ -156,8 +197,16 @@ class AttentionPrefillProjectedFused(FusedMLIROperator): Accepts raw input (S, E) and rope_angles (S, d). """ - def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, - seq_len, causal_mask=True, context=None): + def __init__( + self, + num_heads, + num_kv_groups, + head_dim, + embedding_dim, + seq_len, + causal_mask=True, + context=None, + ): assert head_dim == 64 assert num_heads % num_kv_groups == 0 assert seq_len % 256 == 0 @@ -177,38 +226,73 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, # ---- Projection + RoPE ---- gemm_query = GEMM( - M=S, K=E, N=H * d, num_aie_columns=8, tile_m=16, tile_k=64, - tile_n=_pick_tile_n(H * d, 8), context=elf_ctx, + M=S, + K=E, + N=H * d, + num_aie_columns=8, + tile_m=16, + tile_k=64, + tile_n=_pick_tile_n(H * d, 8), + context=elf_ctx, ) gemm_kv = GEMM( - M=S, K=E, N=G * d, num_aie_columns=8, tile_m=16, tile_k=64, - tile_n=_pick_tile_n(G * d, 8), context=elf_ctx, + M=S, + K=E, + N=G * d, + num_aie_columns=8, + tile_m=16, + tile_k=64, + tile_n=_pick_tile_n(G * d, 8), + context=elf_ctx, ) rope_queries = RoPE(rows=S * H, cols=d, angle_rows=S, context=elf_ctx) rope_keys = RoPE(rows=S * G, cols=d, angle_rows=S, context=elf_ctx) # ---- Deinterleave ---- deinterleave_q = StridedCopy( - input_sizes=(H, S, d), input_strides=(d, H * d, 1), input_offset=0, - output_sizes=(H, S, d), output_strides=(S * d, d, 1), output_offset=0, - input_buffer_size=S * H * d, output_buffer_size=H * S * d, - transfer_size=S * d, num_aie_channels=1, context=elf_ctx, + input_sizes=(H, S, d), + input_strides=(d, H * d, 1), + input_offset=0, + output_sizes=(H, S, d), + output_strides=(S * d, d, 1), + output_offset=0, + input_buffer_size=S * H * d, + output_buffer_size=H * S * d, + transfer_size=S * d, + num_aie_channels=1, + context=elf_ctx, ) deinterleave_kv = StridedCopy( - input_sizes=(G, S, d), input_strides=(d, G * d, 1), input_offset=0, - output_sizes=(G, S, d), output_strides=(S * d, d, 1), output_offset=0, - input_buffer_size=S * G * d, output_buffer_size=G * S * d, - transfer_size=S * d, num_aie_channels=1, context=elf_ctx, + input_sizes=(G, S, d), + input_strides=(d, G * d, 1), + input_offset=0, + output_sizes=(G, S, d), + output_strides=(S * d, d, 1), + output_offset=0, + input_buffer_size=S * G * d, + output_buffer_size=G * S * d, + transfer_size=S * d, + num_aie_channels=1, + context=elf_ctx, ) # ---- Transpose keys + GQA repeat ---- transpose_keys = Transpose( - M=S, N=d, num_aie_columns=2, num_channels=1, - m=256, n=32, s=8, context=elf_ctx, + M=S, + N=d, + num_aie_columns=2, + num_channels=1, + m=256, + n=32, + s=8, + context=elf_ctx, ) repeat_kv = Repeat( - rows=G, cols=d * S, repeat=group_size, - transfer_size=d, context=elf_ctx, + rows=G, + cols=d * S, + repeat=group_size, + transfer_size=d, + context=elf_ctx, ) kSd = S * d * B @@ -223,10 +307,14 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, (deinterleave_q, "queries_roped", "queries"), (deinterleave_kv, "keys_roped", "keys_deint"), (deinterleave_kv, "values_projected", "values_deint"), - *[(transpose_keys, - f"keys_deint[{g*kSd}:{(g+1)*kSd}]", - f"keys_transposed[{g*kdS}:{(g+1)*kdS}]") - for g in range(G)], + *[ + ( + transpose_keys, + f"keys_deint[{g*kSd}:{(g+1)*kSd}]", + f"keys_transposed[{g*kdS}:{(g+1)*kdS}]", + ) + for g in range(G) + ], (repeat_kv, "keys_transposed", "keys"), (repeat_kv, "values_deint", "values"), ] @@ -242,19 +330,38 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, } core_runlist, core_buffer_sizes = _build_core_ops( - H, G, d, S, elf_ctx, causal_mask=causal_mask, + H, + G, + d, + S, + elf_ctx, + causal_mask=causal_mask, ) # ---- Reinterleave + output projection ---- reinterleave = StridedCopy( - input_sizes=(1, 1, 1, H * S * d), input_strides=(0, 0, 0, 1), input_offset=0, - output_sizes=(H, 256, S // 256, d), output_strides=(d, 256 * H * d, H * d, 1), output_offset=0, - input_buffer_size=H * S * d, output_buffer_size=S * H * d, - transfer_size=S * d, num_aie_channels=1, context=elf_ctx, + input_sizes=(1, 1, 1, H * S * d), + input_strides=(0, 0, 0, 1), + input_offset=0, + output_sizes=(H, 256, S // 256, d), + output_strides=(d, 256 * H * d, H * d, 1), + output_offset=0, + input_buffer_size=H * S * d, + output_buffer_size=S * H * d, + transfer_size=S * d, + num_aie_channels=1, + context=elf_ctx, ) gemm_output = GEMM( - M=S, K=H * d, N=E, num_aie_columns=8, tile_m=16, tile_k=64, - tile_n=_pick_tile_n(E, 8), context=elf_ctx, prio_accuracy=True, + M=S, + K=H * d, + N=E, + num_aie_columns=8, + tile_m=16, + tile_k=64, + tile_n=_pick_tile_n(E, 8), + context=elf_ctx, + prio_accuracy=True, ) suffix_runlist = [ @@ -266,8 +373,15 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, } mask_suffix = "_causal" if causal_mask else "_nomask" - input_args = ["input", "rope_angles", "W_query", "W_key", "W_value", - "W_output", "attn_scale_factor"] + input_args = [ + "input", + "rope_angles", + "W_query", + "W_key", + "W_value", + "W_output", + "attn_scale_factor", + ] if causal_mask: input_args.append("causal_mask") @@ -276,6 +390,10 @@ def __init__(self, num_heads, num_kv_groups, head_dim, embedding_dim, runlist=prefix_runlist + core_runlist + suffix_runlist, input_args=input_args, output_args=["attn_output"], - buffer_sizes={**prefix_buffer_sizes, **core_buffer_sizes, **suffix_buffer_sizes}, + buffer_sizes={ + **prefix_buffer_sizes, + **core_buffer_sizes, + **suffix_buffer_sizes, + }, context=elf_ctx, ) diff --git a/iron/operators/mha_prefill_lxl_sd/reference.py b/iron/operators/mha_prefill_lxl_sd/reference.py index 92cba961..3343fa8d 100644 --- a/iron/operators/mha_prefill_lxl_sd/reference.py +++ b/iron/operators/mha_prefill_lxl_sd/reference.py @@ -12,7 +12,7 @@ def _apply_rope_4d(x, angles): Returns: same shape as x with RoPE applied (two-halves method). """ half = x.shape[-1] // 2 - cos = angles[:, ::2].unsqueeze(0).unsqueeze(0) # (1, 1, S, half) + cos = angles[:, ::2].unsqueeze(0).unsqueeze(0) # (1, 1, S, half) sin = angles[:, 1::2].unsqueeze(0).unsqueeze(0) # (1, 1, S, half) x1, x2 = x[..., :half], x[..., half:] return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1) @@ -70,7 +70,7 @@ def generate_golden_reference( W_output = torch.randn(H * d, E, dtype=torch.bfloat16) * val_range # Scale factor: 1/sqrt(d), broadcast to (H*S*S,) - scale = 1.0 / (d ** 0.5) + scale = 1.0 / (d**0.5) attn_scale_factor = torch.full((H * S * S,), scale, dtype=torch.bfloat16) # Causal mask: (H*S, S) — 0 for valid positions, -inf for future @@ -83,39 +83,61 @@ def generate_golden_reference( ) # ---- Q/K/V projections ---- - queries_raw = _bf16_matmul(x, W_query) # (S, H*d) - keys_raw = _bf16_matmul(x, W_key) # (S, G*d) - values_raw = _bf16_matmul(x, W_value) # (S, G*d) + queries_raw = _bf16_matmul(x, W_query) # (S, H*d) + keys_raw = _bf16_matmul(x, W_key) # (S, G*d) + values_raw = _bf16_matmul(x, W_value) # (S, G*d) # ---- RoPE (reuses rope_utils.apply_rope with 4D interface) ---- # Reshape interleaved (S, N*d) → (1, N, S, d) for rope_utils - queries_roped = _apply_rope_4d( - queries_raw.reshape(S, H, d).permute(1, 0, 2).unsqueeze(0), # (1, H, S, d) - rope_angles, - ).squeeze(0).permute(1, 0, 2).contiguous().reshape(S * H, d) # (S*H, d) - - keys_roped = _apply_rope_4d( - keys_raw.reshape(S, G, d).permute(1, 0, 2).unsqueeze(0), # (1, G, S, d) - rope_angles, - ).squeeze(0).permute(1, 0, 2).contiguous().reshape(S * G, d) # (S*G, d) + queries_roped = ( + _apply_rope_4d( + queries_raw.reshape(S, H, d).permute(1, 0, 2).unsqueeze(0), # (1, H, S, d) + rope_angles, + ) + .squeeze(0) + .permute(1, 0, 2) + .contiguous() + .reshape(S * H, d) + ) # (S*H, d) + + keys_roped = ( + _apply_rope_4d( + keys_raw.reshape(S, G, d).permute(1, 0, 2).unsqueeze(0), # (1, G, S, d) + rope_angles, + ) + .squeeze(0) + .permute(1, 0, 2) + .contiguous() + .reshape(S * G, d) + ) # (S*G, d) # ---- Deinterleave Q/K/V ---- - queries_deinterleaved = queries_roped.reshape(S, H, d).transpose(0, 1).contiguous() # (H, S, d) - keys_deinterleaved = keys_roped.reshape(S, G, d).transpose(0, 1).contiguous() # (G, S, d) - keys_transposed = keys_deinterleaved.transpose(1, 2).contiguous() # (G, d, S) - values_deinterleaved = values_raw.reshape(S, G, d).transpose(0, 1).contiguous() # (G, S, d) + queries_deinterleaved = ( + queries_roped.reshape(S, H, d).transpose(0, 1).contiguous() + ) # (H, S, d) + keys_deinterleaved = ( + keys_roped.reshape(S, G, d).transpose(0, 1).contiguous() + ) # (G, S, d) + keys_transposed = keys_deinterleaved.transpose(1, 2).contiguous() # (G, d, S) + values_deinterleaved = ( + values_raw.reshape(S, G, d).transpose(0, 1).contiguous() + ) # (G, S, d) # ---- GQA repeat ---- if group_size > 1: - keys_for_scores = keys_transposed.reshape(G, d * S).repeat_interleave( - group_size, dim=0 - ).reshape(H, d, S) - values_for_context = values_deinterleaved.reshape(G, S * d).repeat_interleave( - group_size, dim=0 - ).reshape(H, S, d) + keys_for_scores = ( + keys_transposed.reshape(G, d * S) + .repeat_interleave(group_size, dim=0) + .reshape(H, d, S) + ) + values_for_context = ( + values_deinterleaved.reshape(G, S * d) + .repeat_interleave(group_size, dim=0) + .reshape(H, S, d) + ) else: - keys_for_scores = keys_transposed # (H, d, S) - values_for_context = values_deinterleaved # (H, S, d) + keys_for_scores = keys_transposed # (H, d, S) + values_for_context = values_deinterleaved # (H, S, d) # ---- Score GEMM per head ---- attn_scores = torch.stack( @@ -133,7 +155,9 @@ def generate_golden_reference( # ---- Softmax ---- attn_weights = torch.nn.functional.softmax( attn_scores_masked.float().reshape(H, S, S), dim=-1 - ).to(torch.bfloat16) # (H, S, S) + ).to( + torch.bfloat16 + ) # (H, S, S) # ---- Context GEMM per head ---- attn_context = torch.stack( diff --git a/iron/operators/mha_prefill_lxl_sd/test.py b/iron/operators/mha_prefill_lxl_sd/test.py index 318c44fc..7e90b748 100644 --- a/iron/operators/mha_prefill_lxl_sd/test.py +++ b/iron/operators/mha_prefill_lxl_sd/test.py @@ -49,14 +49,14 @@ def _get_scratch_tensor(fc, name, shape): """Read a named buffer from the fused callable's scratch space.""" fc.scratch_buffer._sync_from_device() sub = fc.get_buffer(name) - return sub.data[:int(np.prod(shape))].reshape(shape).astype(np.float32) + return sub.data[: int(np.prod(shape))].reshape(shape).astype(np.float32) def _get_output_tensor(fc, name, shape): """Read a named buffer from the fused callable's output space.""" fc.output_buffer._sync_from_device() sub = fc.get_buffer(name) - return sub.data[:int(np.prod(shape))].reshape(shape).astype(np.float32) + return sub.data[: int(np.prod(shape))].reshape(shape).astype(np.float32) def _verify_output(fc, golden, H, d, S, E): @@ -71,24 +71,28 @@ def _verify_output(fc, golden, H, d, S, E): output = torch.from_numpy(output_np.reshape(S, E).astype(np.float32)).bfloat16() errors = verify_buffer( - output, "attn_output", chain_ref.reshape(S, E), - rel_tol=REL_TOL, abs_tol=ABS_TOL, max_error_rate=MAX_ERROR_RATE, + output, + "attn_output", + chain_ref.reshape(S, E), + rel_tol=REL_TOL, + abs_tol=ABS_TOL, + max_error_rate=MAX_ERROR_RATE, ) assert not errors, f"Output verification failed with {len(errors)} errors" def _core_gemm_flops(H, G, d, E, S): """Count GEMM FLOPs for the core attention operator.""" - score_flops = H * 2 * S * d * S # H x (S,d)@(d,S) - context_flops = H * 2 * S * S * d # H x (S,S)@(S,d) + score_flops = H * 2 * S * d * S # H x (S,d)@(d,S) + context_flops = H * 2 * S * S * d # H x (S,S)@(S,d) return score_flops + context_flops def _projected_gemm_flops(H, G, d, E, S): """Count GEMM FLOPs for the projected attention operator.""" - query_proj = 2 * S * E * (H * d) # (S,E)@(E,H*d) - kv_proj = 2 * (2 * S * E * (G * d)) # key + value: (S,E)@(E,G*d) each - output_proj = 2 * S * (H * d) * E # (S,H*d)@(H*d,E) + query_proj = 2 * S * E * (H * d) # (S,E)@(E,H*d) + kv_proj = 2 * (2 * S * E * (G * d)) # key + value: (S,E)@(E,G*d) each + output_proj = 2 * S * (H * d) * E # (S,H*d)@(H*d,E) return query_proj + kv_proj + _core_gemm_flops(H, G, d, E, S) + output_proj @@ -96,6 +100,7 @@ def _projected_gemm_flops(H, G, d, E, S): # Core attention tests (pre-projected Q, K, V) # --------------------------------------------------------------------------- + @pytest.mark.metrics( Latency=r"Latency \(us\): (?P[\d\.]+)", Throughput=r"Throughput: (?P[\d\.e\+-]+) GFLOP/s", @@ -125,9 +130,12 @@ def test_mha_pefill_lxl_sd(H, G, d, E, S): actual = _get_output_tensor(fc, "attn_context", (H, S, d)) expected = golden["attn_context"].float().numpy().reshape(H, S, d) errors = verify_buffer( - torch.from_numpy(actual).bfloat16(), "attn_context", + torch.from_numpy(actual).bfloat16(), + "attn_context", torch.from_numpy(expected).bfloat16().reshape(H, S, d), - rel_tol=REL_TOL, abs_tol=ABS_TOL, max_error_rate=MAX_ERROR_RATE, + rel_tol=REL_TOL, + abs_tol=ABS_TOL, + max_error_rate=MAX_ERROR_RATE, ) assert not errors, f"Output verification failed with {len(errors)} errors" @@ -136,6 +144,7 @@ def test_mha_pefill_lxl_sd(H, G, d, E, S): # Projected attention tests (with Q/K/V projections + RoPE) # --------------------------------------------------------------------------- + @pytest.mark.metrics( Latency=r"Latency \(us\): (?P[\d\.]+)", Throughput=r"Throughput: (?P[\d\.e\+-]+) GFLOP/s", @@ -172,6 +181,7 @@ def test_attention_prefill_projected_fused(H, G, d, E, S): # Benchmark: GPT-2 Small core MHA across sequence lengths, +/- causal mask # --------------------------------------------------------------------------- + @pytest.mark.benchmark @pytest.mark.metrics( Latency=r"Latency \(us\): (?P[\d\.]+)", @@ -207,7 +217,12 @@ def test_mha_prefill_benchmark(H, G, d, E, S, causal): INTERMEDIATE_CHECKS = [ ("attn_scores", "attn_scores", lambda H, G, S, d: (H, S, S), "scratch"), - ("attn_scores_masked", "attn_scores_masked", lambda H, G, S, d: (H, S, S), "scratch"), + ( + "attn_scores_masked", + "attn_scores_masked", + lambda H, G, S, d: (H, S, S), + "scratch", + ), ("attn_weights", "attn_weights", lambda H, G, S, d: (H, S, S), "scratch"), ("attn_context", "attn_context", lambda H, G, S, d: (H, S, d), "output"), ]