Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion aie_kernels/aie2p/softmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <aie_api/aie.hpp>
#include <stdint.h>
#include <math.h>

#define SM_VEC_LEN 64 // 32
#define log2e 1.4453125 // 1.44269504089
Expand Down Expand Up @@ -30,7 +31,7 @@ void softmax_simple_bf16(bfloat16 *restrict input_vector, bfloat16 *restrict out
aie::vector<bfloat16, SM_VEC_LEN> in_elems, exp_val, input_bf16, log2e_vec, max_val_vec;
aie::accum<accfloat, SM_VEC_LEN> out_vals, exp_val_accum, scaled_accum, exp_in_accum;

float max_val = 0;
float max_val = -INFINITY;
float accum_exp_val = 0;
float running_max = 0;
bfloat16 col_sum_inv;
Expand Down
1 change: 1 addition & 0 deletions iron/common/compilation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@ def compile(self, graph):
str(self.aiecc_path),
"-v",
"-j1",
"--dynamic-objFifos",
"--no-compile-host",
"--no-xchesscc",
"--no-xbridge",
Expand Down
15 changes: 6 additions & 9 deletions iron/common/fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import ml_dtypes
import pyxrt
import ctypes
import time
from . import compilation as comp
from .base import AIEOperatorBase, MLIROperator
from .utils import XRTSubBuffer
Expand Down Expand Up @@ -42,8 +43,7 @@ def get_kernel_artifacts(self):
"""Collect all kernel artifacts from child operators.

Returns:
List of KernelObjectArtifact instances from all unique child operators,
with filenames and symbol prefixes disambiguated per operator index.
List of KernelObjectArtifact instances from all unique child operators.
"""
kernel_artifacts = []
seen: dict[int, object] = {}
Expand All @@ -52,9 +52,6 @@ def get_kernel_artifacts(self):
]
for idx, op in enumerate(unique_operators):
objs = op.get_kernel_artifacts()
for obj in objs:
obj.filename = f"op{idx}_{obj.filename}"
obj.prefix_symbols = f"op{idx}_"
kernel_artifacts.extend(objs)
return kernel_artifacts

Expand Down Expand Up @@ -82,8 +79,6 @@ def get_mlir_artifact(self):
]
for idx, op in enumerate(unique_operators):
mlir_artifact = op.get_mlir_artifact()
if len(op.get_kernel_artifacts()) > 0:
mlir_artifact.generator.kwargs["func_prefix"] = f"op{idx}_"
op_name = f"op{idx}_{op.__class__.__name__}"
op_names[id(op)] = op_name
operator_mlir_map[op_name] = mlir_artifact
Expand Down Expand Up @@ -290,8 +285,10 @@ def __call__(self, *args):
for i, arg in enumerate(args):
assert isinstance(arg, pyxrt.bo), f"Argument {i} is not a pyxrt.bo"
run.set_arg(i, arg)
t0 = time.perf_counter()
run.start()
ret_code = run.wait()
self.last_elapsed = time.perf_counter() - t0
if ret_code != pyxrt.ert_cmd_state.ERT_CMD_STATE_COMPLETED:
raise RuntimeError(f"Kernel execution failed with return code {ret_code}")

Expand Down Expand Up @@ -371,10 +368,10 @@ def get_buffer(self, buffer_name):
return sub_buffer

def __call__(self):
self.input_buffer.to("npu")
self.input_buffer._sync_to_device()
super().__call__(
self.input_buffer.buffer_object(),
self.output_buffer.buffer_object(),
self.scratch_buffer.buffer_object(),
)
self.output_buffer.to("cpu")
self.output_buffer._sync_from_device()
6 changes: 4 additions & 2 deletions iron/operators/gemm/design.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def my_matmul(
gemm_object,
[C_l1_ty_internal],
)
matmul_func_name = f"matmul{scalar_suffix}_{dtype_in_str}_f32"
matmul_func_name = f"{func_prefix}matmul{scalar_suffix}_{dtype_in_str}_f32"
matmul_kernel = Kernel(
matmul_func_name,
gemm_object,
Expand All @@ -314,7 +314,9 @@ def my_matmul(
gemm_object,
[C_l1_ty],
)
matmul_func_name = f"matmul{scalar_suffix}_{dtype_in_str}_{dtype_out_str}"
matmul_func_name = (
f"{func_prefix}matmul{scalar_suffix}_{dtype_in_str}_{dtype_out_str}"
)
matmul_kernel = Kernel(
matmul_func_name,
gemm_object,
Expand Down
2 changes: 2 additions & 0 deletions iron/operators/mha_prefill_lxl_sd/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
Loading
Loading