From 4a3314df778917989d570249cb61f93d2d2408b3 Mon Sep 17 00:00:00 2001 From: albiol2004 Date: Thu, 9 Apr 2026 19:04:15 +0200 Subject: [PATCH 1/5] first working version of INT 4 GEMV --- aie_kernels/generic/fused_dequant_gemv.cc | 94 +++++++++++++ iron/operators/gemv_int4/__init__.py | 2 + iron/operators/gemv_int4/design.py | 160 ++++++++++++++++++++++ iron/operators/gemv_int4/op.py | 110 +++++++++++++++ iron/operators/gemv_int4/reference.py | 141 +++++++++++++++++++ iron/operators/gemv_int4/test.py | 91 ++++++++++++ 6 files changed, 598 insertions(+) create mode 100644 aie_kernels/generic/fused_dequant_gemv.cc create mode 100644 iron/operators/gemv_int4/__init__.py create mode 100644 iron/operators/gemv_int4/design.py create mode 100644 iron/operators/gemv_int4/op.py create mode 100644 iron/operators/gemv_int4/reference.py create mode 100644 iron/operators/gemv_int4/test.py diff --git a/aie_kernels/generic/fused_dequant_gemv.cc b/aie_kernels/generic/fused_dequant_gemv.cc new file mode 100644 index 00000000..8b0b33db --- /dev/null +++ b/aie_kernels/generic/fused_dequant_gemv.cc @@ -0,0 +1,94 @@ +// SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Fused INT4 dequantization + GEMV kernel for AIE2+. +// +// Loads INT4-packed weights, dequantizes in-register, and performs +// matrix-vector multiplication in a single pass. 
+// +// Weight layout per tile (m rows x K cols, group_size G): +// [m * K / 2 bytes of packed uint4 weights] +// [m * (K / G) bf16 scale factors, stored as (m * K / G * 2) bytes] +// +// Dequantization: w_bf16 = scale * unpack_uint4_to_bf16(w_uint4) +// +// The unpack chain matches the existing dequant kernel (expand.cc): +// uint4 -> uint8 (aie::unpack) -> uint16 (aie::unpack) -> bf16 (aie::to_float) + +#define NOCPP + +#include "../aie_kernel_utils.h" + +#include +#include +#include + +template +void fused_dequant_matvec(uint32_t m, + uint32_t k, + const uint8_t *__restrict a_in, + const bfloat16 *__restrict b_in, + bfloat16 *__restrict c_out, + uint32_t group_size) +{ + static_assert(block_size == 32, "block_size must be 32 to match dequant vector width"); + + ::aie::set_rounding(aie::rounding_mode::conv_even); + + const uint4 *weights_packed = reinterpret_cast(a_in); + const uint8_t *scale_bytes = a_in + m * k / 2; + const bfloat16 *scales = reinterpret_cast(scale_bytes); + + const uint32_t groups_per_row = k / group_size; + const uint32_t blocks_per_group = group_size / block_size; + + event0(); + for (uint32_t row = 0; row < m; row++) { + const uint4 *row_weights = weights_packed + row * k / 2; + const bfloat16 *row_scales = scales + row * groups_per_row; + const bfloat16 *b_ptr = b_in; + + aie::accum acc = aie::zeros(); + + for (uint32_t g = 0; g < groups_per_row; g++) { + bfloat16 sf = row_scales[g]; + aie::vector sf_broadcast = aie::broadcast(sf); + + for (uint32_t blk = 0; blk < blocks_per_group; blk++) { + aie::vector I0 = aie::load_v(row_weights); + row_weights += block_size / 2; + + aie::vector as_int8 = aie::unpack(I0); + aie::vector as_int16 = aie::unpack(as_int8); + aie::vector as_bf16 = aie::to_float(as_int16, 0); + + aie::vector w_dequant = aie::mul(as_bf16, sf_broadcast).template to_vector(); + + aie::vector b_vec = aie::load_v(b_ptr); + b_ptr += block_size; + + acc = aie::mac(acc, w_dequant, b_vec); + } + } + + *c_out = 
static_cast(aie::reduce_add(acc.template to_vector())); + c_out++; + } + event1(); +} + +extern "C" { + +void fused_dequant_matvec_bf16(uint32_t m, + uint32_t k, + uint32_t row_offset, + const uint8_t *__restrict a_in, + const bfloat16 *__restrict b_in, + bfloat16 *__restrict c_out, + uint32_t group_size) +{ + c_out += row_offset; + fused_dequant_matvec<32>(m, k, a_in, b_in, c_out, group_size); +} + +} // extern "C" diff --git a/iron/operators/gemv_int4/__init__.py b/iron/operators/gemv_int4/__init__.py new file mode 100644 index 00000000..c8ac4702 --- /dev/null +++ b/iron/operators/gemv_int4/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/iron/operators/gemv_int4/design.py b/iron/operators/gemv_int4/design.py new file mode 100644 index 00000000..ba622865 --- /dev/null +++ b/iron/operators/gemv_int4/design.py @@ -0,0 +1,160 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Fused INT4 dequantization matrix-vector design. + +Performs a fused dequantize-GEMV where the weight matrix is stored in packed +INT4 format (two 4-bit values per uint8 byte) with per-group bfloat16 scale +factors. The activation vector and output are bfloat16. + +Each AIE column processes a contiguous block of output rows. Within a column, +the worker iterates over tiles of packed weight rows, acquires the full +activation vector once per outer iteration, and calls the fused dequant-matvec +kernel which unpacks, dequantizes, and accumulates in a single pass. 
+ +Buffer layout for A (packed weights, uint8): + For each tile of m_input rows: [m_input * K / 2 bytes of packed weights] + [m_input * (K / group_size) * 2 bytes of scales] +""" + +import numpy as np +from ml_dtypes import bfloat16 + +import aie.dialects.index as index +from aie.dialects.aie import T +from aie.helpers.dialects.scf import _for as range_ +from aie.helpers.taplib import TensorAccessPattern +from aie.iron import Kernel, ObjectFifo, Program, Runtime, Worker +from aie.iron.placers import SequentialPlacer + + +def my_fused_dequant_matvec( + dev, cols, M, K, m_input, m_output=None, group_size=32 +): + if m_output is None: + m_output = m_input + + # --- Assertions --- + assert ( + m_output % m_input == 0 and m_output >= m_input + ), "m_output must be a multiple of m_input" + assert m_output <= M // cols, "m_output must be less than or equal to M/cols" + assert (M // cols) % m_output == 0, "m_output must evenly divide M/cols" + assert m_input <= M // cols, "m_input must be less than or equal to M/cols" + assert (M // cols) % m_input == 0, "m_input must evenly divide M/cols" + assert K % group_size == 0, "K must be divisible by group_size" + assert group_size % 32 == 0, "group_size must be a multiple of 32" + assert M % cols == 0, "M must be divisible by cols" + + # --- Data types --- + dtype_in = np.dtype[np.uint8] + dtype_vec = np.dtype[bfloat16] + dtype_out = np.dtype[bfloat16] + + # --- Per-tile sizes (in uint8 bytes) --- + num_groups_per_row = K // group_size + packed_tile_bytes = m_input * K // 2 + m_input * num_groups_per_row * 2 + rows_per_col = M // cols + tiles_per_col = rows_per_col // m_input + bytes_per_col = tiles_per_col * packed_tile_bytes + packed_total_bytes = cols * bytes_per_col + + # --- L1 (on-chip) tensor types --- + L1_A_ty = np.ndarray[(packed_tile_bytes,), dtype_in] + L1_B_ty = np.ndarray[(K,), dtype_vec] + L1_C_ty = np.ndarray[(m_output,), dtype_out] + + # --- L3 (DDR) tensor types --- + L3_A_ty = np.ndarray[(packed_total_bytes,), 
dtype_in] + L3_B_ty = np.ndarray[(K,), dtype_vec] + L3_C_ty = np.ndarray[(M,), dtype_out] + + # --- Kernel declaration --- + fused_matvec = Kernel( + "fused_dequant_matvec_bf16", + "fused_dequant_gemv.o", + [np.int32, np.int32, np.int32, L1_A_ty, L1_B_ty, L1_C_ty, np.int32], + ) + + # --- ObjectFIFOs --- + A_L3L1_fifos = [ + ObjectFifo(L1_A_ty, name=f"A_L3L1_{i}", depth=2) for i in range(cols) + ] + B_L3L1_fifos = [ + ObjectFifo(L1_B_ty, name=f"B_L3L1_{i}", depth=1) for i in range(cols) + ] + C_L1L3_fifos = [ + ObjectFifo(L1_C_ty, name=f"C_L1L3_{i}", depth=2) for i in range(cols) + ] + + # --- Worker core body --- + N_div_n = tiles_per_col // (m_output // m_input) + + def core_body(A_L3L1_fifo, B_L3L1_fifo, C_L1L3_fifo, fused_matvec_fn): + for _ in range_(0xFFFFFFFF): + b = B_L3L1_fifo.acquire(1) + for i_idx in range_(N_div_n): + c = C_L1L3_fifo.acquire(1) + for j_idx in range_(m_output // m_input): + j_i32 = index.casts(T.i32(), j_idx) + output_row_offset = j_i32 * m_input + a = A_L3L1_fifo.acquire(1) + fused_matvec_fn( + m_input, K, output_row_offset, a, b, c, group_size + ) + A_L3L1_fifo.release(1) + C_L1L3_fifo.release(1) + B_L3L1_fifo.release(1) + + workers = [ + Worker( + core_body, + [ + A_L3L1_fifos[i].cons(), + B_L3L1_fifos[i].cons(), + C_L1L3_fifos[i].prod(), + fused_matvec, + ], + ) + for i in range(cols) + ] + + # --- TensorAccessPatterns --- + # A: each column gets a contiguous chunk of bytes_per_col packed bytes + A_taps = [ + TensorAccessPattern( + tensor_dims=L3_A_ty.__args__[0], + offset=col * bytes_per_col, + sizes=[1, 1, 1, bytes_per_col], + strides=[0, 0, 0, 1], + ) + for col in range(cols) + ] + + # C: each column writes contiguous rows_per_col bfloat16 values + C_taps = [ + TensorAccessPattern( + tensor_dims=L3_C_ty.__args__[0], + offset=col * rows_per_col, + sizes=[1, 1, 1, rows_per_col], + strides=[0, 0, 0, 1], + ) + for col in range(cols) + ] + + # --- Runtime sequence --- + rt = Runtime() + with rt.sequence(L3_A_ty, L3_B_ty, L3_C_ty) as 
(A, B, C): + rt.start(*workers) + tg = rt.task_group() + for i in range(cols): + rt.fill(A_L3L1_fifos[i].prod(), A, A_taps[i], task_group=tg) + rt.fill(B_L3L1_fifos[i].prod(), B, task_group=tg) + for i in range(cols): + rt.drain( + C_L1L3_fifos[i].cons(), C, C_taps[i], task_group=tg, wait=True + ) + rt.finish_task_group(tg) + + return Program(dev, rt).resolve_program(SequentialPlacer()) diff --git a/iron/operators/gemv_int4/op.py b/iron/operators/gemv_int4/op.py new file mode 100644 index 00000000..40daff39 --- /dev/null +++ b/iron/operators/gemv_int4/op.py @@ -0,0 +1,110 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass, field +from typing import ClassVar, Dict + +import numpy as np +from ml_dtypes import bfloat16 + +from iron.common import ( + MLIROperator, + AIERuntimeArgSpec, + KernelObjectArtifact, + SourceArtifact, + PythonGeneratedMLIRArtifact, + DesignGenerator, +) +import aie.utils as aie_utils + + +@dataclass +class GEMVInt4(MLIROperator): + """AIE-accelerated fused INT4 dequantization and GEMV operator""" + + M: int + K: int + num_aie_columns: int = 4 + tile_size_input: int = 1 + tile_size_output: int | None = None + group_size: int = field(default=32, repr=False) + context: object = field(default=None, repr=False) + + _name_aliases: ClassVar[Dict[str, str]] = { + **MLIROperator._name_aliases, + "num_aie_columns": "col", + "tile_size_input": "tsi", + "tile_size_output": "tso", + "group_size": "g", + } + + def __post_init__(self): + if self.tile_size_output is None: + self.tile_size_output = self.M // self.num_aie_columns + + if not ( + self.tile_size_output % self.tile_size_input == 0 + and self.tile_size_output >= self.tile_size_input + ): + raise ValueError("tile_size_output must be a multiple of tile_size_input") + if not (self.K % self.group_size == 0): + raise ValueError("K must be a multiple of group_size") + if not 
(self.group_size % 32 == 0): + raise ValueError("group_size must be a multiple of 32") + if not (self.M % self.num_aie_columns == 0): + raise ValueError("M must be a multiple of num_aie_columns") + + MLIROperator.__init__(self, context=self.context) + + @property + def _packed_buffer_size(self): + num_groups_per_row = self.K // self.group_size + packed_tile_bytes = ( + self.tile_size_input * self.K // 2 + + self.tile_size_input * num_groups_per_row * 2 + ) + rows_per_col = self.M // self.num_aie_columns + tiles_per_col = rows_per_col // self.tile_size_input + return self.num_aie_columns * tiles_per_col * packed_tile_bytes + + def get_mlir_artifact(self): + return PythonGeneratedMLIRArtifact( + f"{self.name}.mlir", + DesignGenerator( + self.operator_dir / "design.py", + "my_fused_dequant_matvec", + ( + aie_utils.get_current_device(), + self.num_aie_columns, + self.M, + self.K, + self.tile_size_input, + self.tile_size_output, + self.group_size, + ), + ), + ) + + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact( + "fused_dequant_gemv.o", + dependencies=[ + SourceArtifact( + self.context.base_dir + / "aie_kernels" + / "generic" + / "fused_dequant_gemv.cc" + ) + ], + ), + ] + + def get_arg_spec(self): + return [ + AIERuntimeArgSpec( + "in", (self._packed_buffer_size,), dtype=np.uint8 + ), # packed INT4 weights + AIERuntimeArgSpec("in", (self.K,)), # bf16 activation vector + AIERuntimeArgSpec("out", (self.M,)), # bf16 output vector + ] diff --git a/iron/operators/gemv_int4/reference.py b/iron/operators/gemv_int4/reference.py new file mode 100644 index 00000000..e9e27506 --- /dev/null +++ b/iron/operators/gemv_int4/reference.py @@ -0,0 +1,141 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import torch +import numpy as np +from ml_dtypes import bfloat16 + + +def quantize_and_pack(M, K, group_size=32, m_input=1, cols=4): + """Generate quantized INT4 weights and pack for the fused dequant-GEMV kernel. + + Uses the same quantization scheme as the existing dequant operator + (iron/operators/dequant/reference.py): unsigned INT4 values with per-group + bf16 scale factors, zero-point fixed at 0. + + The DDR buffer is laid out per-tile, where each tile corresponds to + ``m_input`` matrix rows. Tiles for column 0 come first, then column 1, + etc. Within each tile the layout is: + + [m_input * K / 2 bytes] packed uint4 weights (2 values per byte, + low nibble first in little-endian order) + [m_input * (K / group_size) * 2 bytes] bf16 scale factors + + Args: + M: Number of rows in the weight matrix. + K: Number of columns in the weight matrix. + group_size: Number of elements per quantization group (default 32). + m_input: Number of rows per kernel tile invocation. + cols: Number of AIE columns the work is split across. + + Returns: + packed: numpy uint8 array with the complete packed DDR buffer. + W_dequant: torch.bfloat16 (M, K) tensor of dequantized weights. 
+ """ + assert K % group_size == 0, "K must be a multiple of group_size" + assert M % cols == 0, "M must be a multiple of cols" + rows_per_col = M // cols + assert rows_per_col % m_input == 0, "rows_per_col must be a multiple of m_input" + + num_groups_per_row = K // group_size + val_range = 3.75 + r1, r2 = 1 / val_range, 1.0 + + # Generate per-group scale factors in [r1, r2) + total_groups = M * num_groups_per_row + scales_flat = r1 + (r2 - r1) * torch.rand(total_groups, dtype=torch.bfloat16) + zero_points = torch.zeros(total_groups, dtype=torch.bfloat16) + + # Generate random data in [0, val_range) shaped for per-group quantization + W_grouped = torch.rand(total_groups, group_size, dtype=torch.bfloat16) * val_range + + # Quantize with PyTorch per-channel (per-group) quantization + A_quant = torch.quantize_per_channel( + W_grouped.to(torch.float32), + scales=scales_flat.to(torch.float32), + zero_points=zero_points.to(torch.float32), + axis=0, + dtype=torch.quint8, + ) + W_dequant = torch.dequantize(A_quant).to(torch.bfloat16).reshape(M, K) + A_int = A_quant.int_repr() # (total_groups, group_size) with values in [0,15] + + # Now pack into the tile-based DDR layout. + # Tile order: column 0 tiles first, then column 1, etc. 
+ packed_bytes_per_tile = m_input * K // 2 + m_input * num_groups_per_row * 2 + tiles_per_col = rows_per_col // m_input + total_tiles = cols * tiles_per_col + total_bytes = total_tiles * packed_bytes_per_tile + + packed = np.zeros(total_bytes, dtype=np.uint8) + + for col in range(cols): + for tile_idx in range(tiles_per_col): + # Global row range for this tile + row_start = col * rows_per_col + tile_idx * m_input + # Offset into the packed buffer + flat_tile = col * tiles_per_col + tile_idx + tile_offset = flat_tile * packed_bytes_per_tile + + # 1) Pack uint4 weights for m_input rows + for r in range(m_input): + global_row = row_start + r + for grp in range(num_groups_per_row): + flat_grp = global_row * num_groups_per_row + grp + for k in range(group_size // 2): + val_lo = int(A_int[flat_grp, 2 * k].item()) & 0x0F + val_hi = int(A_int[flat_grp, 2 * k + 1].item()) & 0x0F + byte_idx = ( + tile_offset + r * (K // 2) + grp * (group_size // 2) + k + ) + packed[byte_idx] = val_lo | (val_hi << 4) + + # 2) Pack bf16 scale factors for m_input rows + scale_region_start = tile_offset + m_input * K // 2 + for r in range(m_input): + global_row = row_start + r + for grp in range(num_groups_per_row): + flat_grp = global_row * num_groups_per_row + grp + sf_val = scales_flat[flat_grp] + sf_uint16 = sf_val.view(torch.uint16).item() + sf_offset = scale_region_start + (r * num_groups_per_row + grp) * 2 + packed[sf_offset] = sf_uint16 & 0xFF + packed[sf_offset + 1] = (sf_uint16 >> 8) & 0xFF + + return packed, W_dequant + + +def generate_golden_reference( + M=2048, K=2048, group_size=32, m_input=1, cols=4, seed=42 +): + """Generate golden reference for fused dequant-GEMV. + + Args: + M: Number of rows in the weight matrix. + K: Number of columns (== input vector length). + group_size: Quantization group size. + m_input: Number of rows per kernel tile invocation. + cols: Number of AIE columns. + seed: Random seed for reproducibility. 
+ + Returns: + dict with packed_weights, x, output, W_dequant. + """ + torch.manual_seed(seed) + + # Generate random input vector + val_range = 4 + x = torch.randn(K, dtype=torch.bfloat16) * val_range + + # Generate quantized + packed weights + packed_weights, W_dequant = quantize_and_pack(M, K, group_size, m_input, cols) + + # Reference output: dequantized_weights @ x + output = W_dequant @ x + + return { + "packed_weights": packed_weights, + "x": x, + "output": output, + "W_dequant": W_dequant, + } diff --git a/iron/operators/gemv_int4/test.py b/iron/operators/gemv_int4/test.py new file mode 100644 index 00000000..603487ae --- /dev/null +++ b/iron/operators/gemv_int4/test.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import aie.utils as aie_utils + +from iron.operators.gemv_int4.op import GEMVInt4 +from iron.operators.gemv_int4.reference import generate_golden_reference +from iron.common.test_utils import run_test + + +def get_params(): + max_aie_columns = aie_utils.get_current_device().cols + + params_list = [ + # (M, K, num_aie_columns, tile_size_input, tile_size_output, group_size) + (2048, 2048, 4, 1, 512, 32), # Basic + (8192, 2048, 4, 1, 2048, 32), # Llama down_proj + (2048, 8192, 4, 1, 512, 32), # Llama up_proj + ] + + params = [] + for p in params_list: + M, K, num_aie_columns, tile_size_input, tile_size_output, group_size = p + # Skip tests that require more columns than available on the device + if num_aie_columns > max_aie_columns: + continue + params.append( + pytest.param( + *p, + id=f"gemv_int4_{M}x{K}_{tile_size_input}tsi_{tile_size_output}tso_{num_aie_columns}col_g{group_size}", + ) + ) + return params + + +@pytest.mark.metrics( + Latency=r"Latency \(us\): (?P[\d\.]+)", + Bandwidth=r"Effective Bandwidth: (?P[\d\.e\+-]+) GB/s", + Throughput=r"Throughput: (?P[\d\.e\+-]+) GFLOP/s", +) 
+@pytest.mark.parametrize( + "M,K,num_aie_columns,tile_size_input,tile_size_output,group_size", get_params() +) +def test_gemv_int4( + M, K, num_aie_columns, tile_size_input, tile_size_output, group_size, aie_context +): + golden_ref = generate_golden_reference( + M=M, + K=K, + group_size=group_size, + m_input=tile_size_input, + cols=num_aie_columns, + ) + + operator = GEMVInt4( + M=M, + K=K, + num_aie_columns=num_aie_columns, + tile_size_input=tile_size_input, + tile_size_output=tile_size_output, + group_size=group_size, + context=aie_context, + ) + + input_buffers = { + "packed_weights": torch.from_numpy(golden_ref["packed_weights"]), + "vector": golden_ref["x"], + } + output_buffers = {"output": golden_ref["output"]} + + errors, latency_us, bandwidth_gbps = run_test( + operator, input_buffers, output_buffers, rel_tol=0.07, abs_tol=0.7 + ) + + print(f"\nLatency (us): {latency_us:.1f}") + + gflops = (2.0 * M * K) / (latency_us * 1e-6) / 1e9 + print(f"Throughput: {gflops:.2e} GFLOP/s") + + # INT4 weights: M*K/2 bytes + scales (bf16): M*(K//group_size)*2 bytes + weight_bytes = M * K / 2 + M * (K // group_size) * 2 + vector_bytes = K * 2 # bf16 + output_bytes = M * 2 # bf16 + total_bytes = weight_bytes + vector_bytes + output_bytes + bandwidth = total_bytes / (latency_us * 1e-6) / 1e9 + print(f"Effective Bandwidth: {bandwidth:.2e} GB/s\n") + + assert not errors, f"Test failed with errors: {errors}" From 9cd050a22e24db5ce233afbad35d226b3bf65d63 Mon Sep 17 00:00:00 2001 From: albiol2004 Date: Thu, 9 Apr 2026 19:15:30 +0200 Subject: [PATCH 2/5] 1.6x speedup, GROUP_SIZE at compile time --- aie_kernels/generic/fused_dequant_gemv.cc | 26 +++++++++++++++-------- iron/operators/gemv_int4/design.py | 7 +++--- iron/operators/gemv_int4/op.py | 5 ++++- iron/operators/gemv_int4/test.py | 10 ++++++--- 4 files changed, 32 insertions(+), 16 deletions(-) diff --git a/aie_kernels/generic/fused_dequant_gemv.cc b/aie_kernels/generic/fused_dequant_gemv.cc index 8b0b33db..1b9da296 100644 
--- a/aie_kernels/generic/fused_dequant_gemv.cc +++ b/aie_kernels/generic/fused_dequant_gemv.cc @@ -23,15 +23,18 @@ #include #include -template +// block_size: dequant vector width (must be 32 for aie::unpack) +// G: group size (compile-time for pipelining, must be multiple of block_size) +template void fused_dequant_matvec(uint32_t m, uint32_t k, const uint8_t *__restrict a_in, const bfloat16 *__restrict b_in, - bfloat16 *__restrict c_out, - uint32_t group_size) + bfloat16 *__restrict c_out) { static_assert(block_size == 32, "block_size must be 32 to match dequant vector width"); + static_assert(G % block_size == 0, "group_size must be a multiple of block_size"); + constexpr uint32_t blocks_per_group = G / block_size; ::aie::set_rounding(aie::rounding_mode::conv_even); @@ -39,8 +42,7 @@ void fused_dequant_matvec(uint32_t m, const uint8_t *scale_bytes = a_in + m * k / 2; const bfloat16 *scales = reinterpret_cast(scale_bytes); - const uint32_t groups_per_row = k / group_size; - const uint32_t blocks_per_group = group_size / block_size; + const uint32_t groups_per_row = k / G; event0(); for (uint32_t row = 0; row < m; row++) { @@ -50,10 +52,13 @@ void fused_dequant_matvec(uint32_t m, aie::accum acc = aie::zeros(); - for (uint32_t g = 0; g < groups_per_row; g++) { + for (uint32_t g = 0; g < groups_per_row; g++) + chess_prepare_for_pipelining chess_loop_range(1, ) + { bfloat16 sf = row_scales[g]; aie::vector sf_broadcast = aie::broadcast(sf); + AIE_LOOP_MIN_ITERATION_COUNT(1) for (uint32_t blk = 0; blk < blocks_per_group; blk++) { aie::vector I0 = aie::load_v(row_weights); row_weights += block_size / 2; @@ -77,6 +82,10 @@ void fused_dequant_matvec(uint32_t m, event1(); } +#ifndef GROUP_SIZE +#define GROUP_SIZE 32 +#endif + extern "C" { void fused_dequant_matvec_bf16(uint32_t m, @@ -84,11 +93,10 @@ void fused_dequant_matvec_bf16(uint32_t m, uint32_t row_offset, const uint8_t *__restrict a_in, const bfloat16 *__restrict b_in, - bfloat16 *__restrict c_out, - uint32_t 
group_size) + bfloat16 *__restrict c_out) { c_out += row_offset; - fused_dequant_matvec<32>(m, k, a_in, b_in, c_out, group_size); + fused_dequant_matvec<32, GROUP_SIZE>(m, k, a_in, b_in, c_out); } } // extern "C" diff --git a/iron/operators/gemv_int4/design.py b/iron/operators/gemv_int4/design.py index ba622865..cb6829ab 100644 --- a/iron/operators/gemv_int4/design.py +++ b/iron/operators/gemv_int4/design.py @@ -71,10 +71,11 @@ def my_fused_dequant_matvec( L3_C_ty = np.ndarray[(M,), dtype_out] # --- Kernel declaration --- + # group_size is compile-time via -DGROUP_SIZE, not a runtime parameter. fused_matvec = Kernel( "fused_dequant_matvec_bf16", - "fused_dequant_gemv.o", - [np.int32, np.int32, np.int32, L1_A_ty, L1_B_ty, L1_C_ty, np.int32], + f"fused_dequant_gemv_g{group_size}.o", + [np.int32, np.int32, np.int32, L1_A_ty, L1_B_ty, L1_C_ty], ) # --- ObjectFIFOs --- @@ -101,7 +102,7 @@ def core_body(A_L3L1_fifo, B_L3L1_fifo, C_L1L3_fifo, fused_matvec_fn): output_row_offset = j_i32 * m_input a = A_L3L1_fifo.acquire(1) fused_matvec_fn( - m_input, K, output_row_offset, a, b, c, group_size + m_input, K, output_row_offset, a, b, c ) A_L3L1_fifo.release(1) C_L1L3_fifo.release(1) diff --git a/iron/operators/gemv_int4/op.py b/iron/operators/gemv_int4/op.py index 40daff39..8eb88741 100644 --- a/iron/operators/gemv_int4/op.py +++ b/iron/operators/gemv_int4/op.py @@ -88,7 +88,7 @@ def get_mlir_artifact(self): def get_kernel_artifacts(self): return [ KernelObjectArtifact( - "fused_dequant_gemv.o", + f"fused_dequant_gemv_g{self.group_size}.o", dependencies=[ SourceArtifact( self.context.base_dir @@ -97,6 +97,9 @@ def get_kernel_artifacts(self): / "fused_dequant_gemv.cc" ) ], + extra_flags=[ + f"-DGROUP_SIZE={self.group_size}", + ], ), ] diff --git a/iron/operators/gemv_int4/test.py b/iron/operators/gemv_int4/test.py index 603487ae..b798e5d2 100644 --- a/iron/operators/gemv_int4/test.py +++ b/iron/operators/gemv_int4/test.py @@ -16,9 +16,13 @@ def get_params(): params_list = [ # 
(M, K, num_aie_columns, tile_size_input, tile_size_output, group_size) - (2048, 2048, 4, 1, 512, 32), # Basic - (8192, 2048, 4, 1, 2048, 32), # Llama down_proj - (2048, 8192, 4, 1, 512, 32), # Llama up_proj + (2048, 2048, 4, 1, 512, 32), # Basic, 4 cols + (8192, 2048, 4, 1, 2048, 32), # Llama down_proj, 4 cols + (2048, 8192, 4, 1, 512, 32), # Llama up_proj, 4 cols + (2048, 8192, 8, 1, 256, 32), # Llama up_proj, 8 cols + (8192, 2048, 8, 1, 1024, 32), # Llama down_proj, 8 cols + (2048, 8192, 4, 4, 512, 32), # tsi=4 for better amortization + (8192, 2048, 4, 4, 2048, 32), # tsi=4 ] params = [] From 6c483213ff12618b8aab274a102474f771eaf985 Mon Sep 17 00:00:00 2001 From: albiol2004 Date: Thu, 9 Apr 2026 19:22:45 +0200 Subject: [PATCH 3/5] double-pump kernel + K at compile time --- aie_kernels/generic/fused_dequant_gemv.cc | 130 ++++++++++++++++++---- iron/operators/gemv_int4/design.py | 8 +- iron/operators/gemv_int4/op.py | 3 +- 3 files changed, 113 insertions(+), 28 deletions(-) diff --git a/aie_kernels/generic/fused_dequant_gemv.cc b/aie_kernels/generic/fused_dequant_gemv.cc index 1b9da296..9eeebb2e 100644 --- a/aie_kernels/generic/fused_dequant_gemv.cc +++ b/aie_kernels/generic/fused_dequant_gemv.cc @@ -14,6 +14,10 @@ // // The unpack chain matches the existing dequant kernel (expand.cc): // uint4 -> uint8 (aie::unpack) -> uint16 (aie::unpack) -> bf16 (aie::to_float) +// +// Optimization: double-pump — process 2 groups (64 elements) per iteration +// so the compiler can interleave the two independent unpack chains, hiding +// the dequant latency behind computation. 
#define NOCPP @@ -23,11 +27,33 @@ #include #include +// Dequant one 32-element block: load uint4 → unpack → scale → return bf16 +template +inline __attribute__((always_inline)) aie::vector +dequant_block(const uint4 *&weights, const bfloat16 *&b_ptr, + aie::vector sf_broadcast, + aie::accum &acc) { + aie::vector I0 = aie::load_v(weights); + weights += block_size / 2; + + aie::vector as_int8 = aie::unpack(I0); + aie::vector as_int16 = aie::unpack(as_int8); + aie::vector as_bf16 = aie::to_float(as_int16, 0); + aie::vector w_dequant = + aie::mul(as_bf16, sf_broadcast).template to_vector(); + + aie::vector b_vec = aie::load_v(b_ptr); + b_ptr += block_size; + + acc = aie::mac(acc, w_dequant, b_vec); + return w_dequant; +} + // block_size: dequant vector width (must be 32 for aie::unpack) // G: group size (compile-time for pipelining, must be multiple of block_size) -template +// DK: K dimension (compile-time for loop count optimization) +template void fused_dequant_matvec(uint32_t m, - uint32_t k, const uint8_t *__restrict a_in, const bfloat16 *__restrict b_in, bfloat16 *__restrict c_out) @@ -35,44 +61,99 @@ void fused_dequant_matvec(uint32_t m, static_assert(block_size == 32, "block_size must be 32 to match dequant vector width"); static_assert(G % block_size == 0, "group_size must be a multiple of block_size"); constexpr uint32_t blocks_per_group = G / block_size; + constexpr uint32_t groups_per_row = DK / G; + // For double-pump: process 2 groups per iteration when possible + constexpr bool can_double_pump = (groups_per_row >= 2) && (groups_per_row % 2 == 0); + constexpr uint32_t pump_groups = can_double_pump ? 
2 : 1; + constexpr uint32_t loop_iters = groups_per_row / pump_groups; ::aie::set_rounding(aie::rounding_mode::conv_even); const uint4 *weights_packed = reinterpret_cast(a_in); - const uint8_t *scale_bytes = a_in + m * k / 2; + const uint8_t *scale_bytes = a_in + m * DK / 2; const bfloat16 *scales = reinterpret_cast(scale_bytes); - const uint32_t groups_per_row = k / G; - event0(); for (uint32_t row = 0; row < m; row++) { - const uint4 *row_weights = weights_packed + row * k / 2; + const uint4 *row_weights = weights_packed + row * DK / 2; const bfloat16 *row_scales = scales + row * groups_per_row; const bfloat16 *b_ptr = b_in; aie::accum acc = aie::zeros(); - for (uint32_t g = 0; g < groups_per_row; g++) - chess_prepare_for_pipelining chess_loop_range(1, ) - { - bfloat16 sf = row_scales[g]; - aie::vector sf_broadcast = aie::broadcast(sf); - - AIE_LOOP_MIN_ITERATION_COUNT(1) - for (uint32_t blk = 0; blk < blocks_per_group; blk++) { - aie::vector I0 = aie::load_v(row_weights); + if constexpr (can_double_pump && blocks_per_group == 1) { + // Optimized path: 2 groups per iteration, 1 block per group + // Two independent unpack chains for the compiler to interleave. 
+ AIE_LOOP_MIN_ITERATION_COUNT(loop_iters) + for (uint32_t g = 0; g < groups_per_row; g += 2) + chess_prepare_for_pipelining + { + // --- Chain A: group g --- + bfloat16 sf_a = row_scales[g]; + aie::vector sf_a_bc = + aie::broadcast(sf_a); + + aie::vector I0_a = aie::load_v(row_weights); row_weights += block_size / 2; - aie::vector as_int8 = aie::unpack(I0); - aie::vector as_int16 = aie::unpack(as_int8); - aie::vector as_bf16 = aie::to_float(as_int16, 0); + // --- Chain B: group g+1 (interleaved) --- + bfloat16 sf_b = row_scales[g + 1]; + aie::vector sf_b_bc = + aie::broadcast(sf_b); - aie::vector w_dequant = aie::mul(as_bf16, sf_broadcast).template to_vector(); + aie::vector I0_b = aie::load_v(row_weights); + row_weights += block_size / 2; - aie::vector b_vec = aie::load_v(b_ptr); + // Unpack chain A + aie::vector a8_a = aie::unpack(I0_a); + aie::vector a16_a = aie::unpack(a8_a); + aie::vector abf_a = aie::to_float(a16_a, 0); + aie::vector w_a = + aie::mul(abf_a, sf_a_bc).template to_vector(); + + // Unpack chain B + aie::vector a8_b = aie::unpack(I0_b); + aie::vector a16_b = aie::unpack(a8_b); + aie::vector abf_b = aie::to_float(a16_b, 0); + aie::vector w_b = + aie::mul(abf_b, sf_b_bc).template to_vector(); + + // Load activation vectors and MAC + aie::vector b_a = aie::load_v(b_ptr); b_ptr += block_size; + acc = aie::mac(acc, w_a, b_a); - acc = aie::mac(acc, w_dequant, b_vec); + aie::vector b_b = aie::load_v(b_ptr); + b_ptr += block_size; + acc = aie::mac(acc, w_b, b_b); + } + } else { + // Generic path: 1 group per iteration + AIE_LOOP_MIN_ITERATION_COUNT(loop_iters) + for (uint32_t g = 0; g < groups_per_row; g++) + chess_prepare_for_pipelining + { + bfloat16 sf = row_scales[g]; + aie::vector sf_broadcast = + aie::broadcast(sf); + + AIE_LOOP_MIN_ITERATION_COUNT(blocks_per_group) + for (uint32_t blk = 0; blk < blocks_per_group; blk++) { + aie::vector I0 = aie::load_v(row_weights); + row_weights += block_size / 2; + + aie::vector as_int8 = aie::unpack(I0); + 
aie::vector as_int16 = aie::unpack(as_int8); + aie::vector as_bf16 = + aie::to_float(as_int16, 0); + aie::vector w_dequant = + aie::mul(as_bf16, sf_broadcast).template to_vector(); + + aie::vector b_vec = aie::load_v(b_ptr); + b_ptr += block_size; + + acc = aie::mac(acc, w_dequant, b_vec); + } } } @@ -86,17 +167,20 @@ void fused_dequant_matvec(uint32_t m, #define GROUP_SIZE 32 #endif +#ifndef DIM_K +#define DIM_K 2048 +#endif + extern "C" { void fused_dequant_matvec_bf16(uint32_t m, - uint32_t k, uint32_t row_offset, const uint8_t *__restrict a_in, const bfloat16 *__restrict b_in, bfloat16 *__restrict c_out) { c_out += row_offset; - fused_dequant_matvec<32, GROUP_SIZE>(m, k, a_in, b_in, c_out); + fused_dequant_matvec<32, GROUP_SIZE, DIM_K>(m, a_in, b_in, c_out); } } // extern "C" diff --git a/iron/operators/gemv_int4/design.py b/iron/operators/gemv_int4/design.py index cb6829ab..0fae355d 100644 --- a/iron/operators/gemv_int4/design.py +++ b/iron/operators/gemv_int4/design.py @@ -71,11 +71,11 @@ def my_fused_dequant_matvec( L3_C_ty = np.ndarray[(M,), dtype_out] # --- Kernel declaration --- - # group_size is compile-time via -DGROUP_SIZE, not a runtime parameter. + # K and group_size are compile-time via -DDIM_K/-DGROUP_SIZE. 
fused_matvec = Kernel( "fused_dequant_matvec_bf16", - f"fused_dequant_gemv_g{group_size}.o", - [np.int32, np.int32, np.int32, L1_A_ty, L1_B_ty, L1_C_ty], + f"fused_dequant_gemv_{K}k_g{group_size}.o", + [np.int32, np.int32, L1_A_ty, L1_B_ty, L1_C_ty], ) # --- ObjectFIFOs --- @@ -102,7 +102,7 @@ def core_body(A_L3L1_fifo, B_L3L1_fifo, C_L1L3_fifo, fused_matvec_fn): output_row_offset = j_i32 * m_input a = A_L3L1_fifo.acquire(1) fused_matvec_fn( - m_input, K, output_row_offset, a, b, c + m_input, output_row_offset, a, b, c ) A_L3L1_fifo.release(1) C_L1L3_fifo.release(1) diff --git a/iron/operators/gemv_int4/op.py b/iron/operators/gemv_int4/op.py index 8eb88741..a419b7f3 100644 --- a/iron/operators/gemv_int4/op.py +++ b/iron/operators/gemv_int4/op.py @@ -88,7 +88,7 @@ def get_mlir_artifact(self): def get_kernel_artifacts(self): return [ KernelObjectArtifact( - f"fused_dequant_gemv_g{self.group_size}.o", + f"fused_dequant_gemv_{self.K}k_g{self.group_size}.o", dependencies=[ SourceArtifact( self.context.base_dir @@ -98,6 +98,7 @@ def get_kernel_artifacts(self): ) ], extra_flags=[ + f"-DDIM_K={self.K}", f"-DGROUP_SIZE={self.group_size}", ], ), From 1db9c310c104a6797399076ffb65306dbb15c9a1 Mon Sep 17 00:00:00 2001 From: albiol2004 Date: Thu, 9 Apr 2026 19:43:46 +0200 Subject: [PATCH 4/5] Add fused INT4 dequant-GEMV operator Fused INT4 weight dequantization + matrix-vector multiplication in a single kernel pass. Loads packed uint4 weights from DDR, dequantizes in-register using the aie::unpack chain, and MACs with the bf16 activation vector, giving a 4x DDR bandwidth reduction vs bf16 GEMV.
Kernel optimizations: - Compile-time GROUP_SIZE and DIM_K for loop count optimization - Double-pump: processes 2 groups (64 elements) per iteration, giving the compiler two independent unpack chains to interleave - AIE_PREPARE_FOR_PIPELINING and AIE_LOOP_MIN_ITERATION_COUNT hints Tested on AMD Ryzen AI 9 HX 370 (NPU2, 8 columns): - 2048x8192 (Llama up_proj): 561 us, 16.9 GB/s effective bandwidth - 8192x2048 (Llama down_proj): 665 us, 14.2 GB/s effective bandwidth - Integration tested with real Llama 3.2 1B weights (cosine sim >0.999) --- aie_kernels/generic/fused_dequant_gemv.cc | 26 ++--------------------- iron/operators/gemv_int4/__init__.py | 2 -- iron/operators/gemv_int4/op.py | 2 +- iron/operators/gemv_int4/test.py | 2 ++ 4 files changed, 5 insertions(+), 27 deletions(-) delete mode 100644 iron/operators/gemv_int4/__init__.py diff --git a/aie_kernels/generic/fused_dequant_gemv.cc b/aie_kernels/generic/fused_dequant_gemv.cc index 9eeebb2e..fc911801 100644 --- a/aie_kernels/generic/fused_dequant_gemv.cc +++ b/aie_kernels/generic/fused_dequant_gemv.cc @@ -27,28 +27,6 @@ #include #include -// Dequant one 32-element block: load uint4 → unpack → scale → return bf16 -template -inline __attribute__((always_inline)) aie::vector -dequant_block(const uint4 *&weights, const bfloat16 *&b_ptr, - aie::vector sf_broadcast, - aie::accum &acc) { - aie::vector I0 = aie::load_v(weights); - weights += block_size / 2; - - aie::vector as_int8 = aie::unpack(I0); - aie::vector as_int16 = aie::unpack(as_int8); - aie::vector as_bf16 = aie::to_float(as_int16, 0); - aie::vector w_dequant = - aie::mul(as_bf16, sf_broadcast).template to_vector(); - - aie::vector b_vec = aie::load_v(b_ptr); - b_ptr += block_size; - - acc = aie::mac(acc, w_dequant, b_vec); - return w_dequant; -} - // block_size: dequant vector width (must be 32 for aie::unpack) // G: group size (compile-time for pipelining, must be multiple of block_size) // DK: K dimension (compile-time for loop count optimization) @@ -86,7 
+64,7 @@ void fused_dequant_matvec(uint32_t m, // Two independent unpack chains for the compiler to interleave. AIE_LOOP_MIN_ITERATION_COUNT(loop_iters) for (uint32_t g = 0; g < groups_per_row; g += 2) - chess_prepare_for_pipelining + AIE_PREPARE_FOR_PIPELINING { // --- Chain A: group g --- bfloat16 sf_a = row_scales[g]; @@ -131,7 +109,7 @@ void fused_dequant_matvec(uint32_t m, // Generic path: 1 group per iteration AIE_LOOP_MIN_ITERATION_COUNT(loop_iters) for (uint32_t g = 0; g < groups_per_row; g++) - chess_prepare_for_pipelining + AIE_PREPARE_FOR_PIPELINING { bfloat16 sf = row_scales[g]; aie::vector sf_broadcast = diff --git a/iron/operators/gemv_int4/__init__.py b/iron/operators/gemv_int4/__init__.py deleted file mode 100644 index c8ac4702..00000000 --- a/iron/operators/gemv_int4/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 diff --git a/iron/operators/gemv_int4/op.py b/iron/operators/gemv_int4/op.py index a419b7f3..89973427 100644 --- a/iron/operators/gemv_int4/op.py +++ b/iron/operators/gemv_int4/op.py @@ -27,7 +27,7 @@ class GEMVInt4(MLIROperator): num_aie_columns: int = 4 tile_size_input: int = 1 tile_size_output: int | None = None - group_size: int = field(default=32, repr=False) + group_size: int = 32 context: object = field(default=None, repr=False) _name_aliases: ClassVar[Dict[str, str]] = { diff --git a/iron/operators/gemv_int4/test.py b/iron/operators/gemv_int4/test.py index b798e5d2..6c8b32b2 100644 --- a/iron/operators/gemv_int4/test.py +++ b/iron/operators/gemv_int4/test.py @@ -75,6 +75,8 @@ def test_gemv_int4( } output_buffers = {"output": golden_ref["output"]} + # Tolerances are looser than bf16 GEMV (rel_tol=0.04, abs_tol=1e-3) because + # INT4 quantization introduces significant per-group rounding error. 
errors, latency_us, bandwidth_gbps = run_test( operator, input_buffers, output_buffers, rel_tol=0.07, abs_tol=0.7 ) From 32571c64049f1e2d91414e98f3396d390aa062da Mon Sep 17 00:00:00 2001 From: albiol2004 Date: Fri, 10 Apr 2026 17:46:36 +0200 Subject: [PATCH 5/5] ran clang formatter --- aie_kernels/generic/fused_dequant_gemv.cc | 118 ++++++++++------------ 1 file changed, 56 insertions(+), 62 deletions(-) diff --git a/aie_kernels/generic/fused_dequant_gemv.cc b/aie_kernels/generic/fused_dequant_gemv.cc index fc911801..be954cec 100644 --- a/aie_kernels/generic/fused_dequant_gemv.cc +++ b/aie_kernels/generic/fused_dequant_gemv.cc @@ -65,74 +65,68 @@ void fused_dequant_matvec(uint32_t m, AIE_LOOP_MIN_ITERATION_COUNT(loop_iters) for (uint32_t g = 0; g < groups_per_row; g += 2) AIE_PREPARE_FOR_PIPELINING - { - // --- Chain A: group g --- - bfloat16 sf_a = row_scales[g]; - aie::vector sf_a_bc = - aie::broadcast(sf_a); - - aie::vector I0_a = aie::load_v(row_weights); - row_weights += block_size / 2; - - // --- Chain B: group g+1 (interleaved) --- - bfloat16 sf_b = row_scales[g + 1]; - aie::vector sf_b_bc = - aie::broadcast(sf_b); - - aie::vector I0_b = aie::load_v(row_weights); - row_weights += block_size / 2; - - // Unpack chain A - aie::vector a8_a = aie::unpack(I0_a); - aie::vector a16_a = aie::unpack(a8_a); - aie::vector abf_a = aie::to_float(a16_a, 0); - aie::vector w_a = - aie::mul(abf_a, sf_a_bc).template to_vector(); - - // Unpack chain B - aie::vector a8_b = aie::unpack(I0_b); - aie::vector a16_b = aie::unpack(a8_b); - aie::vector abf_b = aie::to_float(a16_b, 0); - aie::vector w_b = - aie::mul(abf_b, sf_b_bc).template to_vector(); - - // Load activation vectors and MAC - aie::vector b_a = aie::load_v(b_ptr); - b_ptr += block_size; - acc = aie::mac(acc, w_a, b_a); - - aie::vector b_b = aie::load_v(b_ptr); - b_ptr += block_size; - acc = aie::mac(acc, w_b, b_b); - } - } else { - // Generic path: 1 group per iteration - AIE_LOOP_MIN_ITERATION_COUNT(loop_iters) - for 
(uint32_t g = 0; g < groups_per_row; g++) - AIE_PREPARE_FOR_PIPELINING - { - bfloat16 sf = row_scales[g]; - aie::vector sf_broadcast = - aie::broadcast(sf); - - AIE_LOOP_MIN_ITERATION_COUNT(blocks_per_group) - for (uint32_t blk = 0; blk < blocks_per_group; blk++) { - aie::vector I0 = aie::load_v(row_weights); + { + // --- Chain A: group g --- + bfloat16 sf_a = row_scales[g]; + aie::vector sf_a_bc = aie::broadcast(sf_a); + + aie::vector I0_a = aie::load_v(row_weights); + row_weights += block_size / 2; + + // --- Chain B: group g+1 (interleaved) --- + bfloat16 sf_b = row_scales[g + 1]; + aie::vector sf_b_bc = aie::broadcast(sf_b); + + aie::vector I0_b = aie::load_v(row_weights); row_weights += block_size / 2; - aie::vector as_int8 = aie::unpack(I0); - aie::vector as_int16 = aie::unpack(as_int8); - aie::vector as_bf16 = - aie::to_float(as_int16, 0); - aie::vector w_dequant = - aie::mul(as_bf16, sf_broadcast).template to_vector(); + // Unpack chain A + aie::vector a8_a = aie::unpack(I0_a); + aie::vector a16_a = aie::unpack(a8_a); + aie::vector abf_a = aie::to_float(a16_a, 0); + aie::vector w_a = aie::mul(abf_a, sf_a_bc).template to_vector(); - aie::vector b_vec = aie::load_v(b_ptr); + // Unpack chain B + aie::vector a8_b = aie::unpack(I0_b); + aie::vector a16_b = aie::unpack(a8_b); + aie::vector abf_b = aie::to_float(a16_b, 0); + aie::vector w_b = aie::mul(abf_b, sf_b_bc).template to_vector(); + + // Load activation vectors and MAC + aie::vector b_a = aie::load_v(b_ptr); b_ptr += block_size; + acc = aie::mac(acc, w_a, b_a); - acc = aie::mac(acc, w_dequant, b_vec); + aie::vector b_b = aie::load_v(b_ptr); + b_ptr += block_size; + acc = aie::mac(acc, w_b, b_b); + } + } else { + // Generic path: 1 group per iteration + AIE_LOOP_MIN_ITERATION_COUNT(loop_iters) + for (uint32_t g = 0; g < groups_per_row; g++) + AIE_PREPARE_FOR_PIPELINING + { + bfloat16 sf = row_scales[g]; + aie::vector sf_broadcast = aie::broadcast(sf); + + AIE_LOOP_MIN_ITERATION_COUNT(blocks_per_group) + for 
(uint32_t blk = 0; blk < blocks_per_group; blk++) { + aie::vector I0 = aie::load_v(row_weights); + row_weights += block_size / 2; + + aie::vector as_int8 = aie::unpack(I0); + aie::vector as_int16 = aie::unpack(as_int8); + aie::vector as_bf16 = aie::to_float(as_int16, 0); + aie::vector w_dequant = + aie::mul(as_bf16, sf_broadcast).template to_vector(); + + aie::vector b_vec = aie::load_v(b_ptr); + b_ptr += block_size; + + acc = aie::mac(acc, w_dequant, b_vec); + } } - } } *c_out = static_cast(aie::reduce_add(acc.template to_vector()));