From 513d449ca1a90a8cb249241ad16de4c331e1a756 Mon Sep 17 00:00:00 2001 From: XYenChi Date: Fri, 17 Apr 2026 18:47:31 +0800 Subject: [PATCH 1/4] [MLIR] Fix MLIR-JIT for RISC-V (#1) --- mlir/lib/ExecutionEngine/ExecutionEngine.cpp | 66 +++++++++++++++----- mlir/lib/ExecutionEngine/JitRunner.cpp | 6 ++ 2 files changed, 55 insertions(+), 17 deletions(-) diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp index 9b16e09124aa3..f9fa904a30ee8 100644 --- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp @@ -10,19 +10,21 @@ // JIT engine. // //===----------------------------------------------------------------------===// -#include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Support/FileUtilities.h" #include "mlir/Target/LLVMIR/Export.h" #include "llvm/ExecutionEngine/JITEventListener.h" +#include "llvm/ExecutionEngine/JITLink/JITLinkMemoryManager.h" #include "llvm/ExecutionEngine/ObjectCache.h" #include "llvm/ExecutionEngine/Orc/CompileUtils.h" #include "llvm/ExecutionEngine/Orc/ExecutionUtils.h" #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" #include "llvm/ExecutionEngine/Orc/IRTransformLayer.h" #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h" #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" #include "llvm/IR/IRBuilder.h" #include "llvm/MC/TargetRegistry.h" @@ -314,27 +316,57 @@ ExecutionEngine::create(Operation *m, const ExecutionEngineOptions &options, auto objectLinkingLayerCreator = [&](ExecutionSession &session) { // Needed to respect AArch64 ABI requirements on the distance between // TEXT and GOT sections. - bool reserveAlloc = llvmModule->getTargetTriple().isAArch64(); - auto objectLayer = std::make_unique( - session, [sectionMemoryMapper = options.sectionMemoryMapper, - reserveAlloc](const MemoryBuffer &) { - return std::make_unique(sectionMemoryMapper, - reserveAlloc); - }); - - // Register JIT event listeners if they are enabled. - if (engine->gdbListener) - objectLayer->registerJITEventListener(*engine->gdbListener); - if (engine->perfListener) - objectLayer->registerJITEventListener(*engine->perfListener); + + // Check if we should use ObjectLinkingLayer (JITLink) + // JITLink supports modern architectures like RISC-V, AArch64 + // RuntimeDyld is older and provides better compatibility with legacy + // platforms + + // Decide which layer to use + bool useJITLink = llvmModule->getTargetTriple().isAArch64() || + llvmModule->getTargetTriple().isRISCV(); + + std::unique_ptr objectLayer; + + if (useJITLink) { + // JITLink path + objectLayer = std::make_unique(session); + + LLVM_DEBUG(llvm::dbgs() << "Using ObjectLinkingLayer (JITLink)\n"); + + } else { + // RuntimeDyld path + auto rtDyldLayer = std::make_unique( + session, + [sectionMemoryMapper = + options.sectionMemoryMapper](const llvm::MemoryBuffer &) + -> std::unique_ptr { + return std::make_unique(sectionMemoryMapper); + }); + + // Only RTDyld supports listener + if (engine->gdbListener) + rtDyldLayer->registerJITEventListener(*engine->gdbListener); + + if (engine->perfListener) + rtDyldLayer->registerJITEventListener(*engine->perfListener); + + LLVM_DEBUG(llvm::dbgs() << "Using RTDyldObjectLinkingLayer\n"); + + // Upcast + objectLayer = std::move(rtDyldLayer); + } // COFF format binaries (Windows) need special handling to deal with // exported symbol visibility. // cf llvm/lib/ExecutionEngine/Orc/LLJIT.cpp LLJIT::createObjectLinkingLayer const llvm::Triple &targetTriple = llvmModule->getTargetTriple(); - if (targetTriple.isOSBinFormatCOFF()) { - objectLayer->setOverrideObjectFlagsWithResponsibilityFlags(true); - objectLayer->setAutoClaimResponsibilityForObjectSymbols(true); + if (!useJITLink && targetTriple.isOSBinFormatCOFF()) { + if (auto *rtDyldLayer = dyn_cast( + objectLayer.get())) { + rtDyldLayer->setOverrideObjectFlagsWithResponsibilityFlags(true); + rtDyldLayer->setAutoClaimResponsibilityForObjectSymbols(true); + } } // Resolve symbols from shared libraries. diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp index db0516533afcb..4b52669df9a6e 100644 --- a/mlir/lib/ExecutionEngine/JitRunner.cpp +++ b/mlir/lib/ExecutionEngine/JitRunner.cpp @@ -31,6 +31,7 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/LegacyPassNameParser.h" +#include "llvm/Support/CodeGen.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FileUtilities.h" @@ -361,6 +362,11 @@ int mlir::JitRunnerMain(int argc, char **argv, const DialectRegistry ®istry, tmBuilderOrError->getTargetTriple().setArchName(options.mArch); } + if (tmBuilderOrError->getTargetTriple().isRISCV()){ + tmBuilderOrError->setRelocationModel(llvm::Reloc::PIC_); + tmBuilderOrError->setCodeModel(llvm::CodeModel::Medium); + } + // Build TargetMachine auto tmOrError = tmBuilderOrError->createTargetMachine(); From 87438ea2821df6699fb7b68f52c1a931d4f4bca3 Mon Sep 17 00:00:00 2001 From: XYenChi Date: Sat, 18 Apr 2026 21:33:50 +0800 Subject: [PATCH 2/4] Ignore the sign of nan for RISC-V (#3) --- mlir/test/mlir-runner/test-expand-math-approx.mlir | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/test/mlir-runner/test-expand-math-approx.mlir b/mlir/test/mlir-runner/test-expand-math-approx.mlir index 06b3171a2349e..5ba0f7106e101 100644 --- a/mlir/test/mlir-runner/test-expand-math-approx.mlir +++ b/mlir/test/mlir-runner/test-expand-math-approx.mlir @@ -233,12 +233,12 @@ func.func @powf() { %g_p = arith.constant 23598.0 : f64 call @func_powff64(%g, %g_p) : (f64, f64) -> () - // CHECK-NEXT: -nan + // CHECK-NEXT: {{-?}}nan %h = arith.constant 1.0 : f64 %h_p = arith.constant 0xfff0000001000000 : f64 call @func_powff64(%h, %h_p) : (f64, f64) -> () - // CHECK-NEXT: -nan + // CHECK-NEXT: {{-?}}nan %i = arith.constant 1.0 : f32 %i_p = arith.constant 0xffffffff : f32 call @func_powff32(%i, %i_p) : (f32, f32) -> () From d5f380f5361d5a12539d31823b80a0ae32fa699a Mon Sep 17 00:00:00 2001 From: XYenChi Date: Sat, 18 Apr 2026 23:41:27 +0800 Subject: [PATCH 3/4] [MLIR] Fix use-after-free in block walk, Shape canonicalization, and opsrun test (#4) Fix three issues found on RISC-V: 1. Fix use-after-free in block walk (Visitors.h) When walking blocks with ReverseDominanceIterator, the Traversal object returned by Iterator::makeIterable(region) was passed as a temporary directly to llvm::make_early_inc_range(). The temporary was destroyed at the end of the full expression while iterators from make_early_inc_range still referenced it. Store the result in a local variable with auto&& to extend the lifetime of the temporary and bind lvalue references for ForwardIterator. 2. Fix Shape dialect CstrBroadcastableOp fold and cast canonicalization Extend getShapeVec to look through tensor.cast operations so that folds can resolve shapes behind casts inserted by earlier canonicalization passes (e.g., ShapeOfOpToConstShapeOp). Add a new fold check in CstrBroadcastableOp::fold that recognizes broadcasting is trivially valid when at most one operand is non-scalar (resolved via getShapeVec). Rewrite CanonicalizeCastExtentTensorOperandsPattern to use modifyOpInPlace instead of replaceOpWithNewOp to avoid issues with op builder template instantiation. 3. Fix opsrun.py test invocation pattern for multithreaded_tests.py Convert direct test_foo() calls to run(test_foo) so that copy_and_update in multithreaded_tests.py properly strips them during import, preventing JIT execution at module load time which crashes on RISC-V due to R_RISCV_HI20 relocation range limits. --- mlir/include/mlir/IR/Visitors.h | 14 ++++--- mlir/lib/Dialect/Shape/IR/Shape.cpp | 41 ++++++++++++++----- .../integration/dialects/linalg/opsrun.py | 21 ++++++---- 3 files changed, 51 insertions(+), 25 deletions(-) diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h index 907f470c0d248..d0c123c3bc600 100644 --- a/mlir/include/mlir/IR/Visitors.h +++ b/mlir/include/mlir/IR/Visitors.h @@ -120,9 +120,10 @@ void walk(Operation *op, function_ref callback, WalkOrder order) { for (auto ®ion : Iterator::makeIterable(*op)) { // Early increment here in the case where the block is erased. - // PostOrderTraversal keeps state outside of iterators, so store it here. - auto &&It = Iterator::makeIterable(region); - for (auto &block : llvm::make_early_inc_range(It)) { + // Store the block range to ensure the iteratable (e.g., + // PostOrderTraversal) outlives the iterators of make_early_inc_range. + auto &&blockRange = Iterator::makeIterable(region); + for (auto &block : llvm::make_early_inc_range(blockRange)) { if (order == WalkOrder::PreOrder) callback(&block); for (auto &nestedOp : Iterator::makeIterable(block)) @@ -196,9 +197,10 @@ WalkResult walk(Operation *op, function_ref callback, WalkOrder order) { for (auto ®ion : Iterator::makeIterable(*op)) { // Early increment here in the case where the block is erased. - // PostOrderTraversal keeps state outside of iterators, so store it here. - auto &&It = Iterator::makeIterable(region); - for (auto &block : llvm::make_early_inc_range(It)) { + // Store the block range to ensure the iteratable (e.g., + // PostOrderTraversal) outlives the iterators of make_early_inc_range. + auto &&blockRange = Iterator::makeIterable(region); + for (auto &block : llvm::make_early_inc_range(blockRange)) { if (order == WalkOrder::PreOrder) { WalkResult result = callback(&block); if (result.wasSkipped()) diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 43a6816c6d863..ae0982830e138 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -48,6 +48,9 @@ bool shape::isExtentTensorType(Type type) { LogicalResult shape::getShapeVec(Value input, SmallVectorImpl &shapeValues) { + // Look through tensor.cast operations to find the underlying shape. + if (auto castOp = input.getDefiningOp()) + return getShapeVec(castOp.getSource(), shapeValues); if (auto inputOp = input.getDefiningOp()) { auto type = llvm::cast(inputOp.getArg().getType()); if (!type.hasRank()) @@ -799,27 +802,27 @@ struct CanonicalizeCastExtentTensorOperandsPattern LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - // Canonicalize operands. + // Canonicalize operands by stripping information-losing tensor.cast ops. + SmallVector newOperands; bool anyChange = false; - auto canonicalizeOperand = [&](Value operand) -> Value { + for (Value operand : op.getShapes()) { if (auto castOp = operand.getDefiningOp()) { // Only eliminate the cast if it holds no shape information. - bool isInformationLoosingCast = - llvm::cast(castOp.getType()).isDynamicDim(0); - if (isInformationLoosingCast) { + if (llvm::cast(castOp.getType()).isDynamicDim(0)) { anyChange = true; - return castOp.getSource(); + newOperands.push_back(castOp.getSource()); + continue; } } - return operand; - }; - auto newOperands = - llvm::map_to_vector<8>(op.getOperands(), canonicalizeOperand); + newOperands.push_back(operand); + } // Rewrite op if any change required. if (!anyChange) return failure(); - rewriter.replaceOpWithNewOp(op, op->getResultTypes(), newOperands); + rewriter.modifyOpInPlace(op, [&]() { + op.getShapesMutable().assign(newOperands); + }); return success(); } }; @@ -1017,6 +1020,22 @@ OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) { }()) return BoolAttr::get(getContext(), true); + // No broadcasting is needed if all operands but one are scalar, using + // getShapeVec to look through tensor.cast and shape_of ops. + if ([&] { + bool nonScalarSeen = false; + for (auto shapeValue : getShapes()) { + SmallVector extents; + if (failed(getShapeVec(shapeValue, extents)) || !extents.empty()) { + if (nonScalarSeen) + return false; + nonScalarSeen = true; + } + } + return true; + }()) + return BoolAttr::get(getContext(), true); + // Because a failing witness result here represents an eventual assertion // failure, we do not replace it with a constant witness. return nullptr; diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py index 8eff573f98ad3..665879acb5cc9 100644 --- a/mlir/test/python/integration/dialects/linalg/opsrun.py +++ b/mlir/test/python/integration/dialects/linalg/opsrun.py @@ -19,6 +19,11 @@ def log(*args): sys.stderr.flush() +def run(f): + f() + return f + + fill_boiler = """ func.func @main() -> i32 attributes {llvm.emit_c_interface} { %O0 = memref.alloc() : memref @@ -177,7 +182,7 @@ def fill_2d_on_buffers(value, out): # CHECK: RESULT: 6 -test_fill_builtin() +run(test_fill_builtin) def test_fill_generic(): @@ -211,7 +216,7 @@ def fill_2d_on_buffers(value, out): # CHECK: RESULT: 6 -test_fill_generic() +run(test_fill_generic) def test_fill_rng_builtin(): @@ -238,7 +243,7 @@ def fill_rng_on_buffers(min, max, seed, out): # CHECK: RESULT: -480 -test_fill_rng_builtin() +run(test_fill_rng_builtin) def test_fill_rng_generic(): @@ -265,7 +270,7 @@ def fill_rng_on_buffers(min, max, seed, out): # CHECK: RESULT: -480 -test_fill_rng_generic() +run(test_fill_rng_generic) def test_max_pooling_builtin(): @@ -299,7 +304,7 @@ def pooling_on_buffers(input, shape, output): # CHECK: RESULT: 42 -test_max_pooling_builtin() +run(test_max_pooling_builtin) def test_max_pooling_generic(): @@ -338,7 +343,7 @@ def pooling_on_buffers(input, shape, output): # CHECK: RESULT: 42 -test_max_pooling_generic() +run(test_max_pooling_generic) def test_min_pooling_builtin(): @@ -370,7 +375,7 @@ def pooling_on_buffers(input, shape, output): # CHECK: RESULT: -13 -test_min_pooling_builtin() +run(test_min_pooling_builtin) def test_min_pooling_generic(): @@ -404,4 +409,4 @@ def pooling_on_buffers(input, shape, output): # CHECK: RESULT: -13 -test_min_pooling_generic() +run(test_min_pooling_generic) From 5faa117ca7ed37389a4e5ff52fb09aed2d71c031 Mon Sep 17 00:00:00 2001 From: Mingzhu Yan Date: Mon, 20 Apr 2026 11:56:33 +0000 Subject: [PATCH 4/4] Add BuddyExt --- llvm/include/llvm/IR/IntrinsicsRISCV.td | 1 + .../llvm/IR/IntrinsicsRISCVBuddyExt.td | 407 +++++ .../RISCV/Disassembler/RISCVDisassembler.cpp | 14 + llvm/lib/Target/RISCV/RISCVBuddyExt.td | 132 ++ llvm/lib/Target/RISCV/RISCVInstrInfo.td | 1 + .../Target/RISCV/RISCVInstrInfoBuddyExt.td | 1492 +++++++++++++++++ 6 files changed, 2047 insertions(+) create mode 100644 llvm/include/llvm/IR/IntrinsicsRISCVBuddyExt.td create mode 100644 llvm/lib/Target/RISCV/RISCVBuddyExt.td create mode 100644 llvm/lib/Target/RISCV/RISCVInstrInfoBuddyExt.td diff --git a/llvm/include/llvm/IR/IntrinsicsRISCV.td b/llvm/include/llvm/IR/IntrinsicsRISCV.td index 32de1a10a4fc3..c02da5ccb4c0d 100644 --- a/llvm/include/llvm/IR/IntrinsicsRISCV.td +++ b/llvm/include/llvm/IR/IntrinsicsRISCV.td @@ -2017,3 +2017,4 @@ include "llvm/IR/IntrinsicsRISCVXsf.td" include "llvm/IR/IntrinsicsRISCVXCV.td" include "llvm/IR/IntrinsicsRISCVXAndes.td" include "llvm/IR/IntrinsicsRISCVXMIPS.td" +include "llvm/IR/IntrinsicsRISCVBuddyExt.td" diff --git a/llvm/include/llvm/IR/IntrinsicsRISCVBuddyExt.td b/llvm/include/llvm/IR/IntrinsicsRISCVBuddyExt.td new file mode 100644 index 0000000000000..496efec4c8d0e --- /dev/null +++ b/llvm/include/llvm/IR/IntrinsicsRISCVBuddyExt.td @@ -0,0 +1,407 @@ +//===- IntrinsicsRISCVBuddyExt.td -----------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This is the costum intrinsic definition file of RISC-V buddy extension. +// +//===----------------------------------------------------------------------===// +let TargetPrefix = "riscv" in +def int_riscv_mvin : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty],[]>; + +let TargetPrefix = "riscv" in +def int_riscv_mvin2 : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty],[]>; + +let TargetPrefix = "riscv" in +def int_riscv_mvin3 : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty],[]>; + +let TargetPrefix = "riscv" in +def int_riscv_mvout : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_flush : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_config : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_preload : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_compute_preloaded : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_compute_accumulated : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_loop_ws_config_bounds : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_loop_ws_config_addrs_ab : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_loop_ws_config_addrs_dc : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_loop_ws_config_strides_ab : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_loop_ws_config_strides_dc : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_loop_ws : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_loop_conv_ws : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_loop_conv_ws_config1 : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_loop_conv_ws_config2 : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_loop_conv_ws_config3 : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_loop_conv_ws_config4 : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_loop_conv_ws_config5 : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_loop_conv_ws_config6 : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +//===----------------------------------------------------------------------===// +// IME Extension Intrinsics +//===----------------------------------------------------------------------===// + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadotu : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadotsu : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadotus : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vfmadot : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot1 : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot1u : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot1su : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot1us : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot2 : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot2u : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot2su : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot2us : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot3 : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot3u : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot3su : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot3us : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadotn : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty, llvm_i64_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadotnu : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty, llvm_i64_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadotnsu : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty, llvm_i64_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadotnus : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty, llvm_i64_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vfmadot1 : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vfmadot2 : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vfmadot3 : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vfmadotn : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty, llvm_i64_ty], + [IntrNoMem]>; + +//===----------------------------------------------------------------------===// +// AME (RISC-V Matrix Extension) Intrinsics +//===----------------------------------------------------------------------===// + +// Matrix configuration intrinsics (with register) +let TargetPrefix = "riscv" in { + // msettype - Set matrix type configuration + def int_riscv_buddy_msettype : Intrinsic<[llvm_i64_ty], [llvm_i64_ty], [IntrNoMem]>; + + // msettilem - Set tile M dimension from register + def int_riscv_buddy_msettilem : Intrinsic<[llvm_i64_ty], [llvm_i64_ty], [IntrNoMem]>; + + // msettilen - Set tile N dimension from register + def int_riscv_buddy_msettilen : Intrinsic<[llvm_i64_ty], [llvm_i64_ty], [IntrNoMem]>; + + // msettilek - Set tile K dimension from register + def int_riscv_buddy_msettilek : Intrinsic<[llvm_i64_ty], [llvm_i64_ty], [IntrNoMem]>; +} + +// Matrix configuration intrinsics (with immediate) +let TargetPrefix = "riscv" in { + // msettilemi - Set tile M dimension with immediate + def int_riscv_buddy_msettilemi : Intrinsic<[], [llvm_i64_ty], + [IntrNoMem, IntrHasSideEffects, ImmArg>]>; + + // msettileni - Set tile N dimension with immediate + def int_riscv_buddy_msettileni : Intrinsic<[], [llvm_i64_ty], + [IntrNoMem, IntrHasSideEffects, ImmArg>]>; + + // msettileki - Set tile K dimension with immediate + def int_riscv_buddy_msettileki : Intrinsic<[], [llvm_i64_ty], + [IntrNoMem, IntrHasSideEffects, ImmArg>]>; +} + +// Matrix load intrinsics (load to tile register) +// Format: md = tile register index, base = address, stride = row byte stride +let TargetPrefix = "riscv" in { + // mlae32.m - Load 32-bit left matrix A to tile register + def int_riscv_buddy_mlae32_m : Intrinsic<[], + [llvm_i64_ty, llvm_ptr_ty, llvm_i64_ty], + [IntrReadMem, IntrHasSideEffects, ImmArg>]>; + + // mlae64.m - Load 64-bit left matrix A to tile register + def int_riscv_buddy_mlae64_m : Intrinsic<[], + [llvm_i64_ty, llvm_ptr_ty, llvm_i64_ty], + [IntrReadMem, IntrHasSideEffects, ImmArg>]>; + + // mlbe32.m - Load 32-bit right matrix B to tile register + def int_riscv_buddy_mlbe32_m : Intrinsic<[], + [llvm_i64_ty, llvm_ptr_ty, llvm_i64_ty], + [IntrReadMem, IntrHasSideEffects, ImmArg>]>; + + // mlbe64.m - Load 64-bit right matrix B to tile register + def int_riscv_buddy_mlbe64_m : Intrinsic<[], + [llvm_i64_ty, llvm_ptr_ty, llvm_i64_ty], + [IntrReadMem, IntrHasSideEffects, ImmArg>]>; + + // mlce32.m - Load 32-bit output matrix C to accumulator + def int_riscv_buddy_mlce32_m : Intrinsic<[], + [llvm_i64_ty, llvm_ptr_ty, llvm_i64_ty], + [IntrReadMem, IntrHasSideEffects, ImmArg>]>; + + // mlce64.m - Load 64-bit output matrix C to accumulator + def int_riscv_buddy_mlce64_m : Intrinsic<[], + [llvm_i64_ty, llvm_ptr_ty, llvm_i64_ty], + [IntrReadMem, IntrHasSideEffects, ImmArg>]>; +} + +// Matrix store intrinsics (store from tile register) +let TargetPrefix = "riscv" in { + // msce32.m - Store 32-bit output matrix C from accumulator + def int_riscv_buddy_msce32_m : Intrinsic<[], + [llvm_i64_ty, llvm_ptr_ty, llvm_i64_ty], + [IntrWriteMem, IntrHasSideEffects, ImmArg>]>; + + // msce64.m - Store 64-bit output matrix C from accumulator + def int_riscv_buddy_msce64_m : Intrinsic<[], + [llvm_i64_ty, llvm_ptr_ty, llvm_i64_ty], + [IntrWriteMem, IntrHasSideEffects, ImmArg>]>; +} + +// Matrix zero intrinsic +let TargetPrefix = "riscv" in { + def int_riscv_buddy_mzero : Intrinsic<[], [llvm_i64_ty], + [IntrNoMem, IntrHasSideEffects, ImmArg>]>; +} + +// Tile register matrix multiplication intrinsics (operate on tile registers) +let TargetPrefix = "riscv" in { + // mma.w.mm - int32 tile matrix multiply: md = md + ms1 x ms2 + def int_riscv_buddy_mma_w_mm_tile : Intrinsic<[], + [llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + [IntrHasSideEffects, ImmArg>, + ImmArg>, ImmArg>]>; + + // mma.dw.mm - int64 tile matrix multiply: md = md + ms1 x ms2 + def int_riscv_buddy_mma_dw_mm_tile : Intrinsic<[], + [llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + [IntrHasSideEffects, ImmArg>, + ImmArg>, ImmArg>]>; +} + +// Legacy matrix load/store intrinsics (for backward compatibility) +let TargetPrefix = "riscv" in { + // mlae - Load matrix A with element width + def int_riscv_buddy_mlae : Intrinsic<[llvm_anyvector_ty], + [llvm_ptr_ty, llvm_i64_ty, llvm_i64_ty], + [IntrReadMem]>; + + // mlbe - Load matrix B with element width + def int_riscv_buddy_mlbe : Intrinsic<[llvm_anyvector_ty], + [llvm_ptr_ty, llvm_i64_ty, llvm_i64_ty], + [IntrReadMem]>; + + // mlce - Load matrix C (accumulator) + def int_riscv_buddy_mlce : Intrinsic<[llvm_anyvector_ty], + [llvm_ptr_ty, llvm_i64_ty, llvm_i64_ty], + [IntrReadMem]>; + + // msce - Store matrix C (accumulator) + def int_riscv_buddy_msce : Intrinsic<[], + [llvm_anyvector_ty, llvm_ptr_ty, llvm_i64_ty, llvm_i64_ty], + [IntrWriteMem]>; +} + +// Signed integer matrix multiplication intrinsics +let TargetPrefix = "riscv" in { + // mqma.b.mm - int8 quad-widen matrix multiply (int8 x int8 -> int32) + def int_riscv_buddy_mqma_b_mm : Intrinsic<[], + [llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + []>; + + // mma.h.mm - int16 matrix multiply + def int_riscv_buddy_mma_h_mm : Intrinsic<[], + [llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + []>; + + // mma.w.mm - int32 matrix multiply + def int_riscv_buddy_mma_w_mm : Intrinsic<[], + [llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + []>; + + // mma.dw.mm - int64 matrix multiply + def int_riscv_buddy_mma_dw_mm : Intrinsic<[], + [llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + []>; + + // mwma.h.mm - int16 double-widen matrix multiply (int16 x int16 -> int32) + def int_riscv_buddy_mwma_h_mm : Intrinsic<[], + [llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + []>; +} + +// Unsigned integer matrix multiplication intrinsics +let TargetPrefix = "riscv" in { + // mqmau.b.mm - uint8 quad-widen matrix multiply + def int_riscv_buddy_mqmau_b_mm : Intrinsic<[], + [llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + []>; + + // mmau.h.mm - uint16 matrix multiply + def int_riscv_buddy_mmau_h_mm : Intrinsic<[], + [llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + []>; +} + +// Floating-point matrix multiplication intrinsics +let TargetPrefix = "riscv" in { + // mfma.f.mm - fp32 matrix multiply + def int_riscv_buddy_mfma_f_mm : Intrinsic<[], + [llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + []>; + + // mfma.hf.mm - fp16 matrix multiply + def int_riscv_buddy_mfma_hf_mm : Intrinsic<[], + [llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + []>; + + // mfwma.hf.mm - fp16 double-widen matrix multiply (fp16 x fp16 -> fp32) + def int_riscv_buddy_mfwma_hf_mm : Intrinsic<[], + [llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + []>; +} diff --git a/llvm/lib/Target/RISCV/Disassembler/RISCVDisassembler.cpp b/llvm/lib/Target/RISCV/Disassembler/RISCVDisassembler.cpp index 30a5d65a901d3..23fd4a133d863 100644 --- a/llvm/lib/Target/RISCV/Disassembler/RISCVDisassembler.cpp +++ b/llvm/lib/Target/RISCV/Disassembler/RISCVDisassembler.cpp @@ -93,6 +93,20 @@ static DecodeStatus DecodeSimpleRegisterClass(MCInst &Inst, uint32_t RegNo, return MCDisassembler::Success; } +static DecodeStatus DecodeTileRegRegisterClass(MCInst &Inst, uint32_t RegNo, + uint64_t Address, const MCDisassembler *Decoder) { + if (RegNo >= 8) + return MCDisassembler::Fail; + return DecodeGPRRegisterClass(Inst, RegNo, /*Address=*/0, Decoder); +} + +static DecodeStatus DecodeAccRegRegisterClass(MCInst &Inst, uint32_t RegNo, + uint64_t Address, const MCDisassembler *Decoder) { + if (RegNo >= 8) + return MCDisassembler::Fail; + return DecodeGPRRegisterClass(Inst, RegNo, /*Address=*/0, Decoder); +} + constexpr auto DecodeGPRRegisterClass = DecodeSimpleRegisterClass; diff --git a/llvm/lib/Target/RISCV/RISCVBuddyExt.td b/llvm/lib/Target/RISCV/RISCVBuddyExt.td new file mode 100644 index 0000000000000..192641e585398 --- /dev/null +++ b/llvm/lib/Target/RISCV/RISCVBuddyExt.td @@ -0,0 +1,132 @@ +//===- RISCVBuddyExt.td ---------------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This is the top file for target definition of RISC-V buddy extension. +// +//===----------------------------------------------------------------------===// + +def FeatureBuddyExt + : SubtargetFeature<"buddyext", "HasBuddyExt", "true", + "'BuddyExt' (Buddy RISC-V Extension)">; +def HasBuddyExt : Predicate<"Subtarget->hasBuddyExt()">, + AssemblerPredicate<(all_of FeatureBuddyExt), + "'BuddyExt' (Buddy RISC-V Extension)">; + +//===----------------------------------------------------------------------===// +// AME (RISC-V Matrix Extension) Register Definitions +//===----------------------------------------------------------------------===// +// Reference: RISC-V Matrix Extension Specification +// +// Matrix Registers: +// - 8 Tile Registers (tr0-tr7): For input matrices A and B +// Each tile register has MLEN bits of state +// - 8 Accumulation Registers (acc0-acc7): For output/accumulation matrix C +// Each accumulation register has MLEN × AMUL bits of state +// +// AMUL (Accumulation MULtiplier): +// - Can be fractional (1/8, 1/4, 1/2) or integer (1, 2, 4, 8) +// - Determines the width ratio between acc and tr registers +// - For mmi8i32 (int8→int32 quad-widen), AMUL ≥ 4 +// +// Data Flow: +// Memory → tr (via mlae/mlbe) → acc (via mma/mwma/mqma) → Memory (via msce) +//===----------------------------------------------------------------------===// + +let Namespace = "RISCV" in { + +//===----------------------------------------------------------------------===// +// AME Tile Registers (tr0-tr7) +// Used for input matrices A and B +// Size: MLEN bits per register (hardware-defined) +//===----------------------------------------------------------------------===// + +// Base class for Tile Registers +class AMETileReg Enc, string n> : Register { + let HWEncoding{2-0} = Enc; + let HWEncoding{4-3} = 0b00; // Distinguish from accumulation registers +} + +// Define 8 Tile Registers: tr0-tr7 +def TR0 : AMETileReg<0, "tr0">; +def TR1 : AMETileReg<1, "tr1">; +def TR2 : AMETileReg<2, "tr2">; +def TR3 : AMETileReg<3, "tr3">; +def TR4 : AMETileReg<4, "tr4">; +def TR5 : AMETileReg<5, "tr5">; +def TR6 : AMETileReg<6, "tr6">; +def TR7 : AMETileReg<7, "tr7">; + +//===----------------------------------------------------------------------===// +// AME Accumulation Registers (acc0-acc7) +// Used for output/accumulation matrix C +// Size: MLEN × AMUL bits per register (hardware-defined) +// +// Note: AMUL can be: +// - Fractional (1/8, 1/4, 1/2): For C = A × Bᵀ mode with large K +// - Integer (1, 2, 4, 8): For widening operations +// * AMUL=4: Required for mmi8i32 (int8→int32 quad-widen) +// * AMUL=2: Required for mmi16i32 (int16→int32 double-widen) +// * AMUL=8: Required for mmi4i32 (int4→int32 oct-widen) +//===----------------------------------------------------------------------===// + +// Base class for Accumulation Registers +class AMEAccReg Enc, string n> : Register { + let HWEncoding{2-0} = Enc; + let HWEncoding{4-3} = 0b01; // Distinguish from tile registers +} + +// Define 8 Accumulation Registers: acc0-acc7 +def ACC0 : AMEAccReg<0, "acc0">; +def ACC1 : AMEAccReg<1, "acc1">; +def ACC2 : AMEAccReg<2, "acc2">; +def ACC3 : AMEAccReg<3, "acc3">; +def ACC4 : AMEAccReg<4, "acc4">; +def ACC5 : AMEAccReg<5, "acc5">; +def ACC6 : AMEAccReg<6, "acc6">; +def ACC7 : AMEAccReg<7, "acc7">; + +} // End Namespace = "RISCV" + +//===----------------------------------------------------------------------===// +// AME Register Classes +//===----------------------------------------------------------------------===// +// These register classes define the operand types for AME instructions +// +// Usage in instructions: +// - TileReg: For ms1, ms2 (source operands in multiplication) +// - AccReg: For md (destination/accumulator in multiplication) +// - TileReg: For load/store of input matrices (A, B) +// - AccReg: For load/store of output/accumulator (C) +//===----------------------------------------------------------------------===// + +// Tile Register class (tr0-tr7) +// Used for input operands in matrix multiplication +// Note: Size is set to 256 as a placeholder; actual size depends on MLEN +def TileReg : RegisterClass<"RISCV", [untyped], 256, + (add TR0, TR1, TR2, TR3, TR4, TR5, TR6, TR7)> { + let Size = 256; // Placeholder: actual MLEN is hardware-defined +} + +// Accumulation Register class (acc0-acc7) +// Used for output/accumulator in matrix multiplication +// Note: Size can be 256×AMUL where AMUL ∈ {1/8, 1/4, 1/2, 1, 2, 4, 8} +// We use 1024 as a reasonable upper bound (256 × 4 for int8→int32) +def AccReg : RegisterClass<"RISCV", [untyped], 1024, + (add ACC0, ACC1, ACC2, ACC3, ACC4, ACC5, ACC6, ACC7)> { + let Size = 1024; // Placeholder: actual MLEN×AMUL is hardware-defined +} + +include "RISCVInstrInfoBuddyExt.td" diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.td b/llvm/lib/Target/RISCV/RISCVInstrInfo.td index ef56275118f2e..10aca8d3ba67a 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.td @@ -2412,6 +2412,7 @@ include "RISCVInstrInfoXRivos.td" include "RISCVInstrInfoXAndes.td" include "RISCVInstrInfoXSpacemiT.td" include "RISCVInstrInfoXAIF.td" +include "RISCVBuddyExt.td" //===----------------------------------------------------------------------===// // Global ISel diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoBuddyExt.td b/llvm/lib/Target/RISCV/RISCVInstrInfoBuddyExt.td new file mode 100644 index 0000000000000..6861f0248a313 --- /dev/null +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoBuddyExt.td @@ -0,0 +1,1492 @@ +//===- RISCVInstrInfoBuddyExt.td ------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This is the instruction information file of RISC-V buddy extension. +// +//===----------------------------------------------------------------------===// + +// include "llvm/IR/IntrinsicsRISCVBuddyExt.td" + +// Gemmini defines different values as func7 +// - https://github.com/ucb-bar/gemmini-rocc-tests/blob/e326e7c43457ff08669fe88edcaa395d846474d8/include/gemmini.h#L25 + +// Gemmini uses 0x3 (0b011) as func3 +// - https://github.com/IBM/rocc-software/blob/fddb795a0b52e82f8f4ce9ead9b1428440a62ab0/src/xcustom.h#L147 + +// Gemmini uses OPC_CUSTOM_3 +// - https://github.com/IBM/rocc-software/blob/fddb795a0b52e82f8f4ce9ead9b1428440a62ab0/src/xcustom.h#L123 + +let hasSideEffects = 1, mayLoad = 1, mayStore = 1, Predicates = [HasBuddyExt] in +def MVIN : RVInstR<0b0000010, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "mvin","$rs1, $rs2"> { + let rd = 0; +} + +let hasSideEffects = 1, mayLoad = 1, mayStore = 1, Predicates = [HasBuddyExt] in +def MVIN2 : RVInstR<0b0000001, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "mvin2","$rs1, $rs2"> { + let rd = 0; +} + +let hasSideEffects = 1, mayLoad = 1, mayStore = 1, Predicates = [HasBuddyExt] in +def MVIN3 : RVInstR<0b0001110, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "mvin3","$rs1, $rs2"> { + let rd = 0; +} + +let hasSideEffects = 1, mayLoad = 1, mayStore = 1, Predicates = [HasBuddyExt] in +def MVOUT : RVInstR<0b0000011, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "mvout","$rs1, $rs2">{ + let rd = 0; +} + +let hasSideEffects = 1, mayLoad = 1, mayStore = 1, Predicates = [HasBuddyExt] in +def FLUSH : RVInstR<0b0000111, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "flush", "$rs1"> { + let rd = 0; +} + +let Predicates = [HasBuddyExt] in +def CONFIG : RVInstR<0b0000000, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "config", "$rs1, $rs2"> { + let rd = 0; +} + +let hasSideEffects = 1, mayLoad = 1, mayStore =1, Predicates = [HasBuddyExt] in +def PRELOAD : RVInstR<0b0000110, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "preload", "$rs1, $rs2">{ + let rd = 0; +} + +let Predicates = [HasBuddyExt] in +def COMPUTE_PRELOADED : RVInstR<0b0000100, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "compute_preloaded", "$rs1, $rs2">{ + let rd = 0; +} + +let Predicates = [HasBuddyExt] in +def COMPUTE_ACCUMULATED : RVInstR<0b0000101, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "compute_accumulated", "$rs1, $rs2"> { + let rd = 0; +} + +let Predicates = [HasBuddyExt] in +def LOOP_WS_CONFIG_BOUNDS : RVInstR<0b0001001, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "loop_ws_config_bounds","$rs1, $rs2">{ + let rd = 0; +} + +let Predicates = [HasBuddyExt] in +def LOOP_WS_CONFIG_ADDRS_AB : RVInstR<0b0001010, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "loop_ws_config_addrs_ab", "$rs1, $rs2"> { + let rd = 0; +} + +let Predicates = [HasBuddyExt] in +def LOOP_WS_CONFIG_ADDRS_DC : RVInstR<0b0001011, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "loop_ws_config_addrs_dc", "$rs1, $rs2"> { + let rd = 0; +} + +let Predicates = [HasBuddyExt] in +def LOOP_WS_CONFIG_STRIDES_AB : RVInstR<0b0001100, 0b011, OPC_CUSTOM_3,(outs), + (ins GPR:$rs1, GPR:$rs2), "loop_ws_config_strides_ab", "$rs1, $rs2"> { + let rd = 0; +} + +let Predicates = [HasBuddyExt] in +def LOOP_WS_CONFIG_STRIDES_DC : RVInstR<0b0001101, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "loop_ws_config_strides_dc", "$rs1, $rs2"> { + let rd = 0; +} + +let Predicates = [HasBuddyExt] in +def LOOP_WS : RVInstR<0b0001000, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "loop_ws", "$rs1, $rs2"> { + let rd = 0; +} + +let hasSideEffects = 1, mayLoad = 1, mayStore =1, Predicates = [HasBuddyExt] in +def LOOP_CONV_WS : RVInstR<0b0001111, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "loop_conv_ws", "$rs1, $rs2"> { + let rd = 0; +} + +let Predicates = [HasBuddyExt] in +def LOOP_CONV_WS_CONFIG1 : RVInstR<0b0010000, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "loop_conv_ws_config1", "$rs1, $rs2"> { + let rd = 0; +} + +let Predicates = [HasBuddyExt] in +def LOOP_CONV_WS_CONFIG2 : RVInstR<0b0010001, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "loop_conv_ws_config2", "$rs1, $rs2"> { + let rd = 0; +} + +let Predicates = [HasBuddyExt] in +def LOOP_CONV_WS_CONFIG3 : RVInstR<0b0010010, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "loop_conv_ws_config3", "$rs1, $rs2"> { + let rd = 0; +} + +let Predicates = [HasBuddyExt] in +def LOOP_CONV_WS_CONFIG4 : RVInstR<0b0010011, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "loop_conv_ws_config4", "$rs1, $rs2"> { + let rd = 0; +} + +let Predicates = [HasBuddyExt] in +def LOOP_CONV_WS_CONFIG5 : RVInstR<0b0010100, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "loop_conv_ws_config5", "$rs1, $rs2"> { + let rd = 0; +} + +let Predicates = [HasBuddyExt] in +def LOOP_CONV_WS_CONFIG6 : RVInstR<0b0010101, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "loop_conv_ws_config6", "$rs1, $rs2"> { + let rd = 0; +} + +//===----------------------------------------------------------------------===// +// IME Extension Instructions +//===----------------------------------------------------------------------===// + +class RVInstIME funct7, bits<3> funct3, dag outs, dag ins, + string opcodestr, string argstr> + : RVInst { + bits<5> vs2; + bits<5> vs1; + bits<5> vd; + + let Inst{31-25} = funct7; + let Inst{24-20} = vs2; + let Inst{19-15} = vs1; + let Inst{14-12} = funct3; + let Inst{11-7} = vd; + let Inst{6-0} = OPC_CUSTOM_1.Value; + + let Uses = [VTYPE, VL]; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT : RVInstIME<0b1110001, 0b000, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM4:$vs1, VRM4:$vs2), + "vmadot", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOTU : RVInstIME<0b1110001, 0b011, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM4:$vs1, VRM4:$vs2), + "vmadotu", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOTSU : RVInstIME<0b1110001, 0b001, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM4:$vs1, VRM4:$vs2), + "vmadotsu", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOTUS : RVInstIME<0b1110001, 0b010, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM4:$vs1, VRM4:$vs2), + "vmadotus", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VFMADOT : RVInstIME<0b1110101, 0b000, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM4:$vs1, VRM4:$vs2), + "vfmadot", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +//===----------------------------------------------------------------------===// +// IME Sliding-Window Instructions +//===----------------------------------------------------------------------===// + +// Integer slide-1 instructions +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT1 : RVInstIME<0b1110010, 0b000, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot1", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT1U : RVInstIME<0b1110010, 0b011, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot1u", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT1SU : RVInstIME<0b1110010, 0b001, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot1su", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT1US : RVInstIME<0b1110010, 0b010, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot1us", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +// Integer slide-2 instructions +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT2 : RVInstIME<0b1110011, 0b000, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot2", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT2U : RVInstIME<0b1110011, 0b011, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot2u", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT2SU : RVInstIME<0b1110011, 0b001, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot2su", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT2US : RVInstIME<0b1110011, 0b010, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot2us", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +// Integer slide-3 instructions +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT3 : RVInstIME<0b1110100, 0b000, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot3", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT3U : RVInstIME<0b1110100, 0b011, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot3u", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT3SU : RVInstIME<0b1110100, 0b001, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot3su", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT3US : RVInstIME<0b1110100, 0b010, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot3us", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +// Floating-point slide instructions +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VFMADOT1 : RVInstIME<0b1110110, 0b000, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vfmadot1", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VFMADOT2 : RVInstIME<0b1110111, 0b000, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vfmadot2", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VFMADOT3 : RVInstIME<0b1111000, 0b000, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vfmadot3", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_mvin GPR:$rs1, GPR:$rs2), (MVIN GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_mvin2 GPR:$rs1, GPR:$rs2), (MVIN2 GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_mvin3 GPR:$rs1, GPR:$rs2), (MVIN3 GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_mvout GPR:$rs1, GPR:$rs2), (MVOUT GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_flush GPR:$rs1, GPR:$rs2), (FLUSH GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_config GPR:$rs1, GPR:$rs2), (CONFIG GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_preload GPR:$rs1, GPR:$rs2), (PRELOAD GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_compute_preloaded GPR:$rs1, GPR:$rs2), (COMPUTE_PRELOADED GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_compute_accumulated GPR:$rs1, GPR:$rs2), (COMPUTE_ACCUMULATED GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_loop_ws_config_bounds GPR:$rs1, GPR:$rs2), (LOOP_WS_CONFIG_BOUNDS GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_loop_ws_config_addrs_ab GPR:$rs1, GPR:$rs2), (LOOP_WS_CONFIG_ADDRS_AB GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_loop_ws_config_addrs_dc GPR:$rs1, GPR:$rs2), (LOOP_WS_CONFIG_ADDRS_DC GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_loop_ws_config_strides_ab GPR:$rs1, GPR:$rs2), (LOOP_WS_CONFIG_STRIDES_AB GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_loop_ws_config_strides_dc GPR:$rs1, GPR:$rs2), (LOOP_WS_CONFIG_STRIDES_DC GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_loop_ws GPR:$rs1, GPR:$rs2), (LOOP_WS GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_loop_conv_ws GPR:$rs1, GPR:$rs2), (LOOP_CONV_WS GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_loop_conv_ws_config1 GPR:$rs1, GPR:$rs2), (LOOP_CONV_WS_CONFIG1 GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_loop_conv_ws_config2 GPR:$rs1, GPR:$rs2), (LOOP_CONV_WS_CONFIG2 GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_loop_conv_ws_config3 GPR:$rs1, GPR:$rs2), (LOOP_CONV_WS_CONFIG3 GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_loop_conv_ws_config4 GPR:$rs1, GPR:$rs2), (LOOP_CONV_WS_CONFIG4 GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_loop_conv_ws_config5 GPR:$rs1, GPR:$rs2), (LOOP_CONV_WS_CONFIG5 GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_loop_conv_ws_config6 GPR:$rs1, GPR:$rs2), (LOOP_CONV_WS_CONFIG6 GPR:$rs1, GPR:$rs2)>; + +//===----------------------------------------------------------------------===// +// IME Extension Patterns +//===----------------------------------------------------------------------===// + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot nxv8i32:$vd, nxv32i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT VRM4:$vd, VRM4:$vs1, VRM4:$vs2)>; +} + +// int16 vmadot patterns +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot nxv8i32:$vd, nxv16i16:$vs1, nxv16i16:$vs2)), + (IME_VMADOT VRM4:$vd, VRM4:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadotu nxv8i32:$vd, nxv32i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOTU VRM4:$vd, VRM4:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadotsu nxv8i32:$vd, nxv32i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOTSU VRM4:$vd, VRM4:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadotus nxv8i32:$vd, nxv32i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOTUS VRM4:$vd, VRM4:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv16f16 (int_riscv_ime_vfmadot nxv16f16:$vd, nxv16f16:$vs1, nxv16f16:$vs2)), + (IME_VFMADOT VRM4:$vd, VRM4:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot1 nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT1 VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot1u nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT1U VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot1su nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT1SU VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot1us nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT1US VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot2 nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT2 VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot2u nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT2U VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot2su nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT2SU VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot2us nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT2US VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot3 nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT3 VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot3u nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT3U VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot3su nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT3SU VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot3us nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT3US VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv16f16 (int_riscv_ime_vfmadot1 nxv16f16:$vd, nxv32f16:$vs1, nxv16f16:$vs2)), + (IME_VFMADOT1 VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv16f16 (int_riscv_ime_vfmadot2 nxv16f16:$vd, nxv32f16:$vs1, nxv16f16:$vs2)), + (IME_VFMADOT2 VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv16f16 (int_riscv_ime_vfmadot3 nxv16f16:$vd, nxv32f16:$vs1, nxv16f16:$vs2)), + (IME_VFMADOT3 VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +class RVInstIMEN funct7, bits<3> funct3, dag outs, dag ins, + string opcodestr, string argstr> + : RVInst { + bits<5> vs2; + bits<5> rs1; // GPR for dynamic slide value + bits<5> vd; + + let Inst{31-25} = funct7; + let Inst{24-20} = vs2; + let Inst{19-15} = rs1; + let Inst{14-12} = funct3; + let Inst{11-7} = vd; + let Inst{6-0} = OPC_CUSTOM_1.Value; + + let Uses = [VTYPE, VL]; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOTN : RVInstIMEN<0b1111001, 0b000, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2, GPR:$rs1), + "vmadotn", "$vd, $vs1, $vs2, $rs1"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOTNU : RVInstIMEN<0b1111001, 0b011, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2, GPR:$rs1), + "vmadotnu", "$vd, $vs1, $vs2, $rs1"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOTNSU : RVInstIMEN<0b1111001, 0b001, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2, GPR:$rs1), + "vmadotnsu", "$vd, $vs1, $vs2, $rs1"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOTNUS : RVInstIMEN<0b1111001, 0b010, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2, GPR:$rs1), + "vmadotnus", "$vd, $vs1, $vs2, $rs1"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VFMADOTN : RVInstIMEN<0b1111010, 0b000, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2, GPR:$rs1), + "vfmadotn", "$vd, $vs1, $vs2, $rs1"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadotn nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2, GPR:$rs1)), + (IME_VMADOTN VRM4:$vd, VRM8:$vs1, VRM4:$vs2, GPR:$rs1)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadotnu nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2, GPR:$rs1)), + (IME_VMADOTNU VRM4:$vd, VRM8:$vs1, VRM4:$vs2, GPR:$rs1)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadotnsu nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2, GPR:$rs1)), + (IME_VMADOTNSU VRM4:$vd, VRM8:$vs1, VRM4:$vs2, GPR:$rs1)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadotnus nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2, GPR:$rs1)), + (IME_VMADOTNUS VRM4:$vd, VRM8:$vs1, VRM4:$vs2, GPR:$rs1)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv16f16 (int_riscv_ime_vfmadotn nxv16f16:$vd, nxv32f16:$vs1, nxv16f16:$vs2, GPR:$rs1)), + (IME_VFMADOTN VRM4:$vd, VRM8:$vs1, VRM4:$vs2, GPR:$rs1)>; +} + +//===----------------------------------------------------------------------===// +// AME (RISC-V Matrix Extension) 64-bit Instructions +//===----------------------------------------------------------------------===// +// Reference: RISC-V Matrix Extension Specification +// 64-bit encoding format with prefix 0111111 at bits [6:0] +// +// Register Model: +// - TileReg (tr0-tr7): Input matrices A and B, MLEN bits each +// - AccReg (acc0-acc7): Accumulator matrix C, MLEN×AMUL bits each +// +// Data Flow: +// Memory --[mlae/mlbe]--> TileReg --[mma/mwma/mqma]--> AccReg --[msce]--> Memory +// +// Matrix Multiplication Instruction Format (64-bit): +// | 63:59 | 58 | 57:55 | 54:52 | 51:49 | 48:47 | 46:44 | 43:39 | 38:32 | +// | sps | sp | typ2 | typ1 | typd | bma | frm | funct5 | opcode | +// | 31:26 | 25 | 24:20 | 19:15 | 14:12 | 11:7 | 6:0 | +// | funct6 | fp | ms2 | ms1 | funct3 | md | suffix | +// +// suffix = 0111111 (AME prefix) +// funct3 = 100 (matrix multiplication) +// +// Widening Instructions (required for AMUL > 1): +// - mqma.b.mm: int8 → int32 (quad-widen, AMUL ≥ 4, mmi8i32 MANDATORY) +// - mwma.h.mm: int16 → int32 (double-widen, AMUL ≥ 2) +// - mwma.w.mm: int32 → int64 (double-widen, AMUL ≥ 2) +//===----------------------------------------------------------------------===// + +// AME Opcode suffix (bits [6:0]) +def OPC_AME : RISCVOpcode<"OPC_AME", 0b0111111>; + +//===----------------------------------------------------------------------===// +// AME Custom Operand Types for Tile Register Indices +//===----------------------------------------------------------------------===// +// These operand types allow intrinsics to pass immediate indices (0-7) +// for tile and accumulator registers. The AsmString of pseudo instructions +// hardcodes "acc" and "tr" prefixes so that indices 0-7 are printed as +// acc0-acc7 and tr0-tr7 respectively, without modifying LLVM submodule. + +// AsmOperandClass for tile register index (0-7) +def AMETileIndexAsmOperand : AsmOperandClass { + let Name = "AMETileIndex"; + let RenderMethod = "addImmOperands"; + let PredicateMethod = "isUImm3"; + let DiagnosticType = "InvalidAMETileIndex"; +} + +// AsmOperandClass for accumulator register index (0-7) +def AMEAccIndexAsmOperand : AsmOperandClass { + let Name = "AMEAccIndex"; + let RenderMethod = "addImmOperands"; + let PredicateMethod = "isUImm3"; + let DiagnosticType = "InvalidAMEAccIndex"; +} + +// Operand type for TileReg index (0-7), printed with "tr" prefix in AsmString +def AMETileIndex : RISCVOp { + let ParserMatchClass = AMETileIndexAsmOperand; + let DecoderMethod = "decodeUImmOperand<3>"; + let OperandType = "OPERAND_UIMM3"; +} + +// Operand type for AccReg index (0-7), printed with "acc" prefix in AsmString +def AMEAccIndex : RISCVOp { + let ParserMatchClass = AMEAccIndexAsmOperand; + let DecoderMethod = "decodeUImmOperand<3>"; + let OperandType = "OPERAND_UIMM3"; +} + +//===----------------------------------------------------------------------===// +// AME 64-bit Instruction Format Base Class +//===----------------------------------------------------------------------===// + +// Base class for AME 64-bit matrix multiplication instructions +// Uses TileReg for inputs (ms1, ms2) and AccReg for output (md) +class RVInstAME64 + : RVInst64 { + // Low 32 bits + bits<5> md; + bits<5> ms1; + bits<5> ms2; + bits<6> funct6; + bit fp; + bits<3> funct3; + + // High 32 bits + bits<5> funct5; + bits<3> frm; + bits<2> bma; + bits<3> typd; + bits<3> typ1; + bits<3> typ2; + bit sp; + bits<5> sps; + bits<7> opcode_hi; // [38:32] + + // Encode low 32 bits (suffix word) + let Inst{6-0} = 0b0111111; // AME prefix + let Inst{11-7} = md; + let Inst{14-12} = funct3; + let Inst{19-15} = ms1; + let Inst{24-20} = ms2; + let Inst{25} = fp; + let Inst{31-26} = funct6; + + // Encode high 32 bits (opcode word) + let Inst{38-32} = opcode_hi; + let Inst{43-39} = funct5; + let Inst{46-44} = frm; + let Inst{48-47} = bma; + let Inst{51-49} = typd; + let Inst{54-52} = typ1; + let Inst{57-55} = typ2; + let Inst{58} = sp; + let Inst{63-59} = sps; +} + +//===----------------------------------------------------------------------===// +// AME Matrix Multiplication Instructions +//===----------------------------------------------------------------------===// +// Format: mma.{h|w|dw}.mm acc, tr, tr +// Semantics: acc = acc + tr1 * tr2 +// +// Data Flow: TileReg × TileReg → AccReg (accumulate) +// - ms1: TileReg for matrix A +// - ms2: TileReg for matrix B +// - md: AccReg for accumulation result C +// +// typ1/typ2/typd encoding: +// 000 = int8 (b), 001 = int16 (h), 010 = int32 (w), 011 = int64 (dw) +// 100 = use mtype.msew, 111 = int4 +// +// funct5 encoding: +// 00001 = mma (signed, no saturation) +// 00010 = mwma (double-widen) +// 00100 = mqma (quad-widen) +// 10001 = msma (signed, saturated) +//===----------------------------------------------------------------------===// + +// No-widen matrix multiply-accumulate: acc = acc + tr1 * tr2 +// Input and output have the same element width +class AME_MMA_MM typd_val, bits<3> typ1_val, bits<3> typ2_val, + string opcodestr> + : RVInstAME64<(outs AccReg:$md), + (ins AccReg:$md_in, TileReg:$ms1, TileReg:$ms2), + opcodestr, "$md, $ms1, $ms2"> { + let Constraints = "$md = $md_in"; // Accumulator constraint + let funct6 = 0b000000; + let fp = 0; + let funct3 = 0b100; // Matrix multiplication + let opcode_hi = 0b0000011; // xxyyy11 where xx=00, yyy=001 + let funct5 = 0b00001; // mma (signed, no saturation) + let frm = 0b000; + let bma = 0b00; // Default: not agnostic + let typd = typd_val; + let typ1 = typ1_val; + let typ2 = typ2_val; + let sp = 0; // No sparsity + let sps = 0b00000; +} + +// mma.h.mm - int16 × int16 → int16 accumulate (no widen) +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MMA_H_MM : AME_MMA_MM<0b001, 0b001, 0b001, "mma.h.mm">; + +// mma.w.mm - int32 × int32 → int32 accumulate (no widen) +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MMA_W_MM : AME_MMA_MM<0b010, 0b010, 0b010, "mma.w.mm">; + +// mma.dw.mm - int64 × int64 → int64 accumulate (no widen) +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MMA_DW_MM : AME_MMA_MM<0b011, 0b011, 0b011, "mma.dw.mm">; + +//===----------------------------------------------------------------------===// +// AME Widening Matrix Multiplication Instructions +//===----------------------------------------------------------------------===// +// Double-widen: output element is 2× width of input elements +// Quad-widen: output element is 4× width of input elements +// +// These require AMUL ≥ 2 (double) or AMUL ≥ 4 (quad) to ensure +// accumulator has sufficient width. +// +// Mandatory: mqma.b.mm (int8→int32) for mmi8i32 feature +//===----------------------------------------------------------------------===// + +// Double-widen: acc = acc + tr1 * tr2, output is 2× width +// Requires AMUL ≥ 2 +class AME_MWMA_MM typd_val, bits<3> typ1_val, bits<3> typ2_val, + string opcodestr> + : RVInstAME64<(outs AccReg:$md), + (ins AccReg:$md_in, TileReg:$ms1, TileReg:$ms2), + opcodestr, "$md, $ms1, $ms2"> { + let Constraints = "$md = $md_in"; + let funct6 = 0b000000; + let fp = 0; + let funct3 = 0b100; + let opcode_hi = 0b0000011; + let funct5 = 0b00010; // mwma (double-widen) + let frm = 0b000; + let bma = 0b00; + let typd = typd_val; + let typ1 = typ1_val; + let typ2 = typ2_val; + let sp = 0; + let sps = 0b00000; +} + +// mwma.h.mm - int16 × int16 → int32 accumulate (double-widen) +// Requires: AMUL ≥ 2, mmi16i32 feature +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MWMA_H_MM : AME_MWMA_MM<0b010, 0b001, 0b001, "mwma.h.mm">; + +// mwma.w.mm - int32 × int32 → int64 accumulate (double-widen) +// Requires: AMUL ≥ 2, mmi32i64 feature +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MWMA_W_MM : AME_MWMA_MM<0b011, 0b010, 0b010, "mwma.w.mm">; + +// Quad-widen: acc = acc + tr1 * tr2, output is 4× width +// Requires AMUL ≥ 4 +class AME_MQMA_MM typd_val, bits<3> typ1_val, bits<3> typ2_val, + string opcodestr> + : RVInstAME64<(outs AccReg:$md), + (ins AccReg:$md_in, TileReg:$ms1, TileReg:$ms2), + opcodestr, "$md, $ms1, $ms2"> { + let Constraints = "$md = $md_in"; + let funct6 = 0b000000; + let fp = 0; + let funct3 = 0b100; + let opcode_hi = 0b0000011; + let funct5 = 0b00100; // mqma (quad-widen) + let frm = 0b000; + let bma = 0b00; + let typd = typd_val; + let typ1 = typ1_val; + let typ2 = typ2_val; + let sp = 0; + let sps = 0b00000; +} + +// mqma.b.mm - int8 × int8 → int32 accumulate (quad-widen) +// MANDATORY for mmi8i32 feature (required by Spec) +// Requires: AMUL ≥ 4 +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MQMA_B_MM : AME_MQMA_MM<0b010, 0b000, 0b000, "mqma.b.mm">; + +//===----------------------------------------------------------------------===// +// AME Configuration Instructions +//===----------------------------------------------------------------------===// +// Configuration instruction format (64-bit): +// | 63:43 | 42:39 | 38:32 | +// | imm[31:11] | funct4 | opcode | +// | 31:26 | 25:20 | 19:15 | 14:12 | 11:7 | 6:0 | +// | funct6 | imm[10:5] | rs1 | funct3 | rd | suffix | +//===----------------------------------------------------------------------===// + +class RVInstAMEConfig64 funct6_val, string opcodestr> + : RVInst64<(outs GPR:$rd), (ins GPR:$rs1), + opcodestr, "$rd, $rs1", [], InstFormatOther> { + bits<5> rd; + bits<5> rs1; + + let Inst{6-0} = 0b0111111; // AME prefix + let Inst{11-7} = rd; + let Inst{14-12} = 0b000; // funct3 for config + let Inst{19-15} = rs1; + let Inst{25-20} = 0b000000; + let Inst{31-26} = funct6_val; + let Inst{38-32} = 0b0000011; // opcode + let Inst{42-39} = 0b0000; // funct4 + let Inst{63-43} = 0; // imm[31:11] = 0 +} + +class RVInstAMEConfigImm64 funct6_val, string opcodestr> + : RVInst64<(outs GPR:$rd), (ins uimm32:$imm), + opcodestr, "$rd, $imm", [], InstFormatOther> { + bits<5> rd; + bits<32> imm; + + let Inst{6-0} = 0b0111111; // AME prefix + let Inst{11-7} = rd; + let Inst{14-12} = 0b000; // funct3 for config + let Inst{19-15} = imm{4-0}; // imm[4:0] in rs1 field + let Inst{25-20} = imm{10-5}; // imm[10:5] + let Inst{31-26} = funct6_val; + let Inst{38-32} = 0b0000011; // opcode + let Inst{42-39} = 0b0000; // funct4 + let Inst{63-43} = imm{31-11}; // imm[31:11] +} + +// msettilem - set tile M dimension from register +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MSETTILEM : RVInstAMEConfig64<0b000100, "msettilem">; + +// msettilemi - set tile M dimension from immediate +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MSETTILEMI : RVInstAMEConfigImm64<0b000101, "msettilemi">; + +// msettilen - set tile N dimension from register +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MSETTILEN : RVInstAMEConfig64<0b001000, "msettilen">; + +// msettileni - set tile N dimension from immediate +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MSETTILENI : RVInstAMEConfigImm64<0b001001, "msettileni">; + +// msettilek - set tile K dimension from register +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MSETTILEK : RVInstAMEConfig64<0b001100, "msettilek">; + +// msettileki - set tile K dimension from immediate +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MSETTILEKI : RVInstAMEConfigImm64<0b001101, "msettileki">; + +//===----------------------------------------------------------------------===// +// AME Load/Store Instructions +//===----------------------------------------------------------------------===// +// Load/Store instruction format (64-bit): +// | 63:51 | 50:49 | 48:47 | 46:44 | 43:39 | 38:32 | +// | resv | mt | bma | eew | funct5 | opcode | +// | 31:26 | 25 | 24:20 | 19:15 | 14:12 | 11:7 | 6:0 | +// | funct6 | ls | rs2 | rs1 | funct3 | ms3/md | suffix | +// +// mt (matrix type): 00=accumulator(C), 01=left(A), 10=right(B), 11=result +// eew (element width): 000=8b, 001=16b, 010=32b, 011=64b +// +// Register Usage: +// - mt=01 (A) or mt=10 (B): Uses TileReg +// - mt=00 (C): Uses AccReg +// +// Data Flow Examples: +// mlae32.m tr0, (a0), a1 # Load matrix A into TileReg +// mlbe32.m tr1, (a0), a1 # Load matrix B into TileReg +// mqma.b.mm acc0, tr0, tr1 # Compute: AccReg = AccReg + TileReg × TileReg +// msce32.m acc0, (a0), a1 # Store AccReg to memory +//===----------------------------------------------------------------------===// + +// Load into TileReg (for matrix A and B) +class RVInstAMELoadTile64 mt_val, bits<3> eew_val, string opcodestr> + : RVInst64<(outs TileReg:$md), (ins GPR:$rs1, GPR:$rs2), + opcodestr, "$md, $rs1, $rs2", [], InstFormatOther> { + bits<5> md; + bits<5> rs1; + bits<5> rs2; + + let Inst{6-0} = 0b0111111; // AME prefix + let Inst{11-7} = md; + let Inst{14-12} = 0b001; // funct3 for load/store + let Inst{19-15} = rs1; + let Inst{24-20} = rs2; + let Inst{25} = 0; // ls = 0 for load + let Inst{31-26} = 0b000000; // funct6 + let Inst{38-32} = 0b0000011; // opcode + let Inst{43-39} = 0b00000; // funct5 + let Inst{46-44} = eew_val; // eew + let Inst{48-47} = 0b00; // bma + let Inst{50-49} = mt_val; // mt + let Inst{63-51} = 0; // reserved +} + +// Load into AccReg (for accumulator C) +class RVInstAMELoadAcc64 mt_val, bits<3> eew_val, string opcodestr> + : RVInst64<(outs AccReg:$md), (ins GPR:$rs1, GPR:$rs2), + opcodestr, "$md, $rs1, $rs2", [], InstFormatOther> { + bits<5> md; + bits<5> rs1; + bits<5> rs2; + + let Inst{6-0} = 0b0111111; + let Inst{11-7} = md; + let Inst{14-12} = 0b001; + let Inst{19-15} = rs1; + let Inst{24-20} = rs2; + let Inst{25} = 0; + let Inst{31-26} = 0b000000; + let Inst{38-32} = 0b0000011; + let Inst{43-39} = 0b00000; + let Inst{46-44} = eew_val; + let Inst{48-47} = 0b00; + let Inst{50-49} = mt_val; + let Inst{63-51} = 0; +} + +// Store from TileReg (for matrix A and B) +class RVInstAMEStoreTile64 mt_val, bits<3> eew_val, string opcodestr> + : RVInst64<(outs), (ins TileReg:$ms3, GPR:$rs1, GPR:$rs2), + opcodestr, "$ms3, $rs1, $rs2", [], InstFormatOther> { + bits<5> ms3; + bits<5> rs1; + bits<5> rs2; + + let Inst{6-0} = 0b0111111; + let Inst{11-7} = ms3; + let Inst{14-12} = 0b001; + let Inst{19-15} = rs1; + let Inst{24-20} = rs2; + let Inst{25} = 1; // ls = 1 for store + let Inst{31-26} = 0b000000; + let Inst{38-32} = 0b0000011; + let Inst{43-39} = 0b00000; + let Inst{46-44} = eew_val; + let Inst{48-47} = 0b00; + let Inst{50-49} = mt_val; + let Inst{63-51} = 0; +} + +// Store from AccReg (for accumulator C) +class RVInstAMEStoreAcc64 mt_val, bits<3> eew_val, string opcodestr> + : RVInst64<(outs), (ins AccReg:$ms3, GPR:$rs1, GPR:$rs2), + opcodestr, "$ms3, $rs1, $rs2", [], InstFormatOther> { + bits<5> ms3; + bits<5> rs1; + bits<5> rs2; + + let Inst{6-0} = 0b0111111; + let Inst{11-7} = ms3; + let Inst{14-12} = 0b001; + let Inst{19-15} = rs1; + let Inst{24-20} = rs2; + let Inst{25} = 1; + let Inst{31-26} = 0b000000; + let Inst{38-32} = 0b0000011; + let Inst{43-39} = 0b00000; + let Inst{46-44} = eew_val; + let Inst{48-47} = 0b00; + let Inst{50-49} = mt_val; + let Inst{63-51} = 0; +} + +//===----------------------------------------------------------------------===// +// Load matrix A (left operand) into TileReg - mlae*.m +// Syntax: mlae{8|16|32|64}.m tr, (rs1), rs2 +// tr: Destination TileReg +// rs1: Base address (GPR) +// rs2: Row stride in bytes (GPR) +//===----------------------------------------------------------------------===// +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 1, mayStore = 0 in { + def AME_MLAE8_M : RVInstAMELoadTile64<0b01, 0b000, "mlae8.m">; + def AME_MLAE16_M : RVInstAMELoadTile64<0b01, 0b001, "mlae16.m">; + def AME_MLAE32_M : RVInstAMELoadTile64<0b01, 0b010, "mlae32.m">; + def AME_MLAE64_M : RVInstAMELoadTile64<0b01, 0b011, "mlae64.m">; +} + +//===----------------------------------------------------------------------===// +// Load matrix B (right operand) into TileReg - mlbe*.m +// Syntax: mlbe{8|16|32|64}.m tr, (rs1), rs2 +//===----------------------------------------------------------------------===// +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 1, mayStore = 0 in { + def AME_MLBE8_M : RVInstAMELoadTile64<0b10, 0b000, "mlbe8.m">; + def AME_MLBE16_M : RVInstAMELoadTile64<0b10, 0b001, "mlbe16.m">; + def AME_MLBE32_M : RVInstAMELoadTile64<0b10, 0b010, "mlbe32.m">; + def AME_MLBE64_M : RVInstAMELoadTile64<0b10, 0b011, "mlbe64.m">; +} + +//===----------------------------------------------------------------------===// +// Load matrix C (accumulator) into AccReg - mlce*.m +// Syntax: mlce{8|16|32|64}.m acc, (rs1), rs2 +// acc: Destination AccReg (MLEN×AMUL bits) +// Note: eew here refers to the output element width after widening +//===----------------------------------------------------------------------===// +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 1, mayStore = 0 in { + def AME_MLCE8_M : RVInstAMELoadAcc64<0b00, 0b000, "mlce8.m">; + def AME_MLCE16_M : RVInstAMELoadAcc64<0b00, 0b001, "mlce16.m">; + def AME_MLCE32_M : RVInstAMELoadAcc64<0b00, 0b010, "mlce32.m">; + def AME_MLCE64_M : RVInstAMELoadAcc64<0b00, 0b011, "mlce64.m">; +} + +//===----------------------------------------------------------------------===// +// Store matrix A from TileReg - msae*.m +// Syntax: msae{8|16|32|64}.m tr, (rs1), rs2 +//===----------------------------------------------------------------------===// +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 1 in { + def AME_MSAE8_M : RVInstAMEStoreTile64<0b01, 0b000, "msae8.m">; + def AME_MSAE16_M : RVInstAMEStoreTile64<0b01, 0b001, "msae16.m">; + def AME_MSAE32_M : RVInstAMEStoreTile64<0b01, 0b010, "msae32.m">; + def AME_MSAE64_M : RVInstAMEStoreTile64<0b01, 0b011, "msae64.m">; +} + +//===----------------------------------------------------------------------===// +// Store matrix B from TileReg - msbe*.m +// Syntax: msbe{8|16|32|64}.m tr, (rs1), rs2 +//===----------------------------------------------------------------------===// +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 1 in { + def AME_MSBE8_M : RVInstAMEStoreTile64<0b10, 0b000, "msbe8.m">; + def AME_MSBE16_M : RVInstAMEStoreTile64<0b10, 0b001, "msbe16.m">; + def AME_MSBE32_M : RVInstAMEStoreTile64<0b10, 0b010, "msbe32.m">; + def AME_MSBE64_M : RVInstAMEStoreTile64<0b10, 0b011, "msbe64.m">; +} + +//===----------------------------------------------------------------------===// +// Store matrix C (accumulator) from AccReg - msce*.m +// Syntax: msce{8|16|32|64}.m acc, (rs1), rs2 +// acc: Source AccReg (MLEN×AMUL bits) +// This is the primary store for computation results +//===----------------------------------------------------------------------===// +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 1 in { + def AME_MSCE8_M : RVInstAMEStoreAcc64<0b00, 0b000, "msce8.m">; + def AME_MSCE16_M : RVInstAMEStoreAcc64<0b00, 0b001, "msce16.m">; + def AME_MSCE32_M : RVInstAMEStoreAcc64<0b00, 0b010, "msce32.m">; + def AME_MSCE64_M : RVInstAMEStoreAcc64<0b00, 0b011, "msce64.m">; +} + +//===----------------------------------------------------------------------===// +// AME Extension Pattern Matching +//===----------------------------------------------------------------------===// +// Connect LLVM intrinsics to AME machine instructions +// +// Register Model: +// - TileReg (tr0-tr7): For input matrices A and B +// - AccReg (acc0-acc7): For output/accumulator C +// +// Typical Data Flow for Matrix Multiplication (e.g., int8→int32): +// 1. msettilem/n/k: Configure tile dimensions +// 2. mlae8.m tr0, (a0), stride_a: Load matrix A into TileReg +// 3. mlbe8.m tr1, (a1), stride_b: Load matrix B into TileReg +// 4. mqma.b.mm acc0, tr0, tr1: Compute acc0 = acc0 + tr0 × tr1 +// 5. msce32.m acc0, (a2), stride_c: Store AccReg to memory +// +// Note: For quad-widen (int8→int32), input uses 8-bit load (mlae8/mlbe8) +// but output uses 32-bit store (msce32) because AMUL=4 widens the output. +//===----------------------------------------------------------------------===// + +// Configuration instruction patterns +// These use GPR operands and return the actual configured value +let Predicates = [HasBuddyExt] in { + // msettilem - set tile M dimension from GPR, returns actual value + def : Pat<(i64 (int_riscv_buddy_msettilem GPR:$rs1)), + (AME_MSETTILEM GPR:$rs1)>; + + // msettilen - set tile N dimension from GPR, returns actual value + def : Pat<(i64 (int_riscv_buddy_msettilen GPR:$rs1)), + (AME_MSETTILEN GPR:$rs1)>; + + // msettilek - set tile K dimension from GPR, returns actual value + def : Pat<(i64 (int_riscv_buddy_msettilek GPR:$rs1)), + (AME_MSETTILEK GPR:$rs1)>; +} + +//===----------------------------------------------------------------------===// +// AME Additional Integer Matrix Operations +//===----------------------------------------------------------------------===// +// Unsigned and mixed-sign matrix multiplication variants +// +// Naming convention: +// - mma: signed × signed +// - mmau: unsigned × unsigned +// - mmasu: signed × unsigned +// - mmaus: unsigned × signed +//===----------------------------------------------------------------------===// + +// Unsigned no-widen multiply-accumulate +class AME_MMAU_MM typd_val, bits<3> typ1_val, bits<3> typ2_val, + string opcodestr> + : RVInstAME64<(outs AccReg:$md), + (ins AccReg:$md_in, TileReg:$ms1, TileReg:$ms2), + opcodestr, "$md, $ms1, $ms2"> { + let Constraints = "$md = $md_in"; + let funct6 = 0b000000; + let fp = 0; + let funct3 = 0b100; + let opcode_hi = 0b0000011; + let funct5 = 0b01001; // mmau (unsigned) + let frm = 0b000; + let bma = 0b00; + let typd = typd_val; + let typ1 = typ1_val; + let typ2 = typ2_val; + let sp = 0; + let sps = 0b00000; +} + +// mmau.b.mm - uint8 × uint8 → uint32 (for unsigned int8) +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MMAU_B_MM : AME_MMAU_MM<0b010, 0b000, 0b000, "mmau.b.mm">; + +// Unsigned quad-widen multiply-accumulate +class AME_MQMAU_MM typd_val, bits<3> typ1_val, bits<3> typ2_val, + string opcodestr> + : RVInstAME64<(outs AccReg:$md), + (ins AccReg:$md_in, TileReg:$ms1, TileReg:$ms2), + opcodestr, "$md, $ms1, $ms2"> { + let Constraints = "$md = $md_in"; + let funct6 = 0b000000; + let fp = 0; + let funct3 = 0b100; + let opcode_hi = 0b0000011; + let funct5 = 0b01100; // mqmau (unsigned quad-widen) + let frm = 0b000; + let bma = 0b00; + let typd = typd_val; + let typ1 = typ1_val; + let typ2 = typ2_val; + let sp = 0; + let sps = 0b00000; +} + +// mqmau.b.mm - uint8 × uint8 → uint32 (quad-widen unsigned) +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MQMAU_B_MM : AME_MQMAU_MM<0b010, 0b000, 0b000, "mqmau.b.mm">; + +// Mixed-sign: signed × unsigned +class AME_MQMASU_MM typd_val, bits<3> typ1_val, bits<3> typ2_val, + string opcodestr> + : RVInstAME64<(outs AccReg:$md), + (ins AccReg:$md_in, TileReg:$ms1, TileReg:$ms2), + opcodestr, "$md, $ms1, $ms2"> { + let Constraints = "$md = $md_in"; + let funct6 = 0b000000; + let fp = 0; + let funct3 = 0b100; + let opcode_hi = 0b0000011; + let funct5 = 0b00101; // mqmasu (signed × unsigned quad-widen) + let frm = 0b000; + let bma = 0b00; + let typd = typd_val; + let typ1 = typ1_val; + let typ2 = typ2_val; + let sp = 0; + let sps = 0b00000; +} + +// mqmasu.b.mm - int8 × uint8 → int32 +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MQMASU_B_MM : AME_MQMASU_MM<0b010, 0b000, 0b000, "mqmasu.b.mm">; + +// Mixed-sign: unsigned × signed +class AME_MQMAUS_MM typd_val, bits<3> typ1_val, bits<3> typ2_val, + string opcodestr> + : RVInstAME64<(outs AccReg:$md), + (ins AccReg:$md_in, TileReg:$ms1, TileReg:$ms2), + opcodestr, "$md, $ms1, $ms2"> { + let Constraints = "$md = $md_in"; + let funct6 = 0b000000; + let fp = 0; + let funct3 = 0b100; + let opcode_hi = 0b0000011; + let funct5 = 0b00110; // mqmaus (unsigned × signed quad-widen) + let frm = 0b000; + let bma = 0b00; + let typd = typd_val; + let typ1 = typ1_val; + let typ2 = typ2_val; + let sp = 0; + let sps = 0b00000; +} + +// mqmaus.b.mm - uint8 × int8 → int32 +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MQMAUS_B_MM : AME_MQMAUS_MM<0b010, 0b000, 0b000, "mqmaus.b.mm">; + +//===----------------------------------------------------------------------===// +// AME Zero/Initialize Instructions +//===----------------------------------------------------------------------===// +// mzero - Zero out an accumulation register +// Useful for initializing before accumulation loop + +class AME_MZERO + : RVInst64<(outs AccReg:$md), (ins), + opcodestr, "$md", [], InstFormatOther> { + bits<5> md; + + let Inst{6-0} = 0b0111111; + let Inst{11-7} = md; + let Inst{14-12} = 0b101; // funct3 for arithmetic + let Inst{19-15} = 0b00000; + let Inst{24-20} = 0b00000; + let Inst{25} = 0; + let Inst{31-26} = 0b000000; + let Inst{38-32} = 0b0000011; + let Inst{43-39} = 0b10000; // funct5 for zero + let Inst{63-44} = 0; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MZERO_M : AME_MZERO<"mzero.m">; + +//===----------------------------------------------------------------------===// +// AME Intrinsic Pattern Matching +//===----------------------------------------------------------------------===// +// These patterns map LLVM intrinsics to AME machine instructions. +// +// Note: AME intrinsics use i64 indices for tile registers instead of +// actual register operands. This is because the MLIR lowering generates +// calls with constant indices that get mapped to physical registers +// at the final code generation stage. +// +// For tile-based operations, the tile register index (0-7) is encoded +// directly into the instruction's register field. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Configuration Instruction Patterns +//===----------------------------------------------------------------------===// + +// msettilem - set tile M dimension from GPR +let Predicates = [HasBuddyExt] in { + def : Pat<(i64 (int_riscv_buddy_msettilem i64:$rs1)), + (AME_MSETTILEM GPR:$rs1)>; +} + +// msettilen - set tile N dimension from GPR +let Predicates = [HasBuddyExt] in { + def : Pat<(i64 (int_riscv_buddy_msettilen i64:$rs1)), + (AME_MSETTILEN GPR:$rs1)>; +} + +// msettilek - set tile K dimension from GPR +let Predicates = [HasBuddyExt] in { + def : Pat<(i64 (int_riscv_buddy_msettilek i64:$rs1)), + (AME_MSETTILEK GPR:$rs1)>; +} + +//===----------------------------------------------------------------------===// +// Pseudo Instructions for Index-Based Operations +//===----------------------------------------------------------------------===// +// These pseudo instructions accept i64 indices and have AsmString for direct +// assembly output. This allows the pseudo instructions to be printed directly +// without needing complex expansion logic. +// +// For load/store/mma instructions, the AsmString hardcodes "acc" and "tr" +// prefixes so that indices 0-7 are printed as acc0-acc7 and tr0-tr7. +//===----------------------------------------------------------------------===// + +// Pseudo instruction for msettilemi with i64 immediate +// Output: msettilemi x0, +let Predicates = [HasBuddyExt], hasSideEffects = 1, mayLoad = 0, mayStore = 0, + isCodeGenOnly = 1 in +def AME_MSETTILEMI_PSEUDO : Pseudo<(outs), (ins i64imm:$imm), []> { + let AsmString = "msettilemi\tx0, $imm"; +} + +// Pseudo instruction for msettileni with i64 immediate +let Predicates = [HasBuddyExt], hasSideEffects = 1, mayLoad = 0, mayStore = 0, + isCodeGenOnly = 1 in +def AME_MSETTILENI_PSEUDO : Pseudo<(outs), (ins i64imm:$imm), []> { + let AsmString = "msettileni\tx0, $imm"; +} + +// Pseudo instruction for msettileki with i64 immediate +let Predicates = [HasBuddyExt], hasSideEffects = 1, mayLoad = 0, mayStore = 0, + isCodeGenOnly = 1 in +def AME_MSETTILEKI_PSEUDO : Pseudo<(outs), (ins i64imm:$imm), []> { + let AsmString = "msettileki\tx0, $imm"; +} + +// Pseudo instruction for mzero with accumulator index +let Predicates = [HasBuddyExt], hasSideEffects = 1, mayLoad = 0, mayStore = 0, + isCodeGenOnly = 1 in +def AME_MZERO_PSEUDO : Pseudo<(outs), (ins AMEAccIndex:$md), []> { + let AsmString = "mzero.m\tacc$md"; +} + +// Pseudo instruction for mlae32.m with tile index +let Predicates = [HasBuddyExt], hasSideEffects = 1, mayLoad = 1, mayStore = 0, + isCodeGenOnly = 1 in +def AME_MLAE32_M_PSEUDO : Pseudo<(outs), (ins AMETileIndex:$md, GPR:$rs1, GPR:$rs2), []> { + let AsmString = "mlae32.m\ttr$md, ($rs1), $rs2"; +} + +// Pseudo instruction for mlbe32.m with tile index +let Predicates = [HasBuddyExt], hasSideEffects = 1, mayLoad = 1, mayStore = 0, + isCodeGenOnly = 1 in +def AME_MLBE32_M_PSEUDO : Pseudo<(outs), (ins AMETileIndex:$md, GPR:$rs1, GPR:$rs2), []> { + let AsmString = "mlbe32.m\ttr$md, ($rs1), $rs2"; +} + +// Pseudo instruction for msce32.m with accumulator index +let Predicates = [HasBuddyExt], hasSideEffects = 1, mayLoad = 0, mayStore = 1, + isCodeGenOnly = 1 in +def AME_MSCE32_M_PSEUDO : Pseudo<(outs), (ins AMEAccIndex:$ms3, GPR:$rs1, GPR:$rs2), []> { + let AsmString = "msce32.m\tacc$ms3, ($rs1), $rs2"; +} + +// Pseudo instruction for mma.w.mm.tile with indices +// md: AccReg index (0-7), ms1/ms2: TileReg indices (0-7) +let Predicates = [HasBuddyExt], hasSideEffects = 1, mayLoad = 1, mayStore = 1, + isCodeGenOnly = 1 in +def AME_MMA_W_MM_TILE_PSEUDO : Pseudo<(outs), + (ins AMEAccIndex:$md, AMETileIndex:$ms1, AMETileIndex:$ms2), []> { + let AsmString = "mma.w.mm\tacc$md, tr$ms1, tr$ms2"; +} + +//===----------------------------------------------------------------------===// +// Pattern Matching for Immediate Configuration Instructions +//===----------------------------------------------------------------------===// + +// msettilemi - set tile M dimension with immediate +let Predicates = [HasBuddyExt] in { + def : Pat<(int_riscv_buddy_msettilemi timm:$imm), + (AME_MSETTILEMI_PSEUDO timm:$imm)>; +} + +// msettileni - set tile N dimension with immediate +let Predicates = [HasBuddyExt] in { + def : Pat<(int_riscv_buddy_msettileni timm:$imm), + (AME_MSETTILENI_PSEUDO timm:$imm)>; +} + +// msettileki - set tile K dimension with immediate +let Predicates = [HasBuddyExt] in { + def : Pat<(int_riscv_buddy_msettileki timm:$imm), + (AME_MSETTILEKI_PSEUDO timm:$imm)>; +} + +//===----------------------------------------------------------------------===// +// Pattern Matching for Zero Instruction +//===----------------------------------------------------------------------===// + +let Predicates = [HasBuddyExt] in { + def : Pat<(int_riscv_buddy_mzero timm:$md), + (AME_MZERO_PSEUDO timm:$md)>; +} + +//===----------------------------------------------------------------------===// +// Pattern Matching for Load Instructions +//===----------------------------------------------------------------------===// + +// mlae32.m - Load 32-bit left matrix A to tile register +let Predicates = [HasBuddyExt] in { + def : Pat<(int_riscv_buddy_mlae32_m timm:$md, iPTR:$rs1, i64:$rs2), + (AME_MLAE32_M_PSEUDO timm:$md, GPR:$rs1, GPR:$rs2)>; +} + +// mlbe32.m - Load 32-bit right matrix B to tile register +let Predicates = [HasBuddyExt] in { + def : Pat<(int_riscv_buddy_mlbe32_m timm:$md, iPTR:$rs1, i64:$rs2), + (AME_MLBE32_M_PSEUDO timm:$md, GPR:$rs1, GPR:$rs2)>; +} + +//===----------------------------------------------------------------------===// +// Pattern Matching for Store Instructions +//===----------------------------------------------------------------------===// + +// msce32.m - Store 32-bit output matrix C from accumulator +let Predicates = [HasBuddyExt] in { + def : Pat<(int_riscv_buddy_msce32_m timm:$ms3, iPTR:$rs1, i64:$rs2), + (AME_MSCE32_M_PSEUDO timm:$ms3, GPR:$rs1, GPR:$rs2)>; +} + +//===----------------------------------------------------------------------===// +// Pattern Matching for Tile-Based Matrix Multiply +//===----------------------------------------------------------------------===// + +// mma.w.mm.tile - int32 tile matrix multiply: md = md + ms1 x ms2 +let Predicates = [HasBuddyExt] in { + def : Pat<(int_riscv_buddy_mma_w_mm_tile timm:$md, timm:$ms1, timm:$ms2), + (AME_MMA_W_MM_TILE_PSEUDO timm:$md, timm:$ms1, timm:$ms2)>; +} + +//===----------------------------------------------------------------------===// +// AME Summary +//===----------------------------------------------------------------------===// +// Complete instruction set for basic matrix operations: +// +// Configuration: +// - msettilem, msettilemi: Set M dimension +// - msettilen, msettileni: Set N dimension +// - msettilek, msettileki: Set K dimension +// +// Load (Memory → Register): +// - mlae{8|16|32|64}.m: Load A into TileReg +// - mlbe{8|16|32|64}.m: Load B into TileReg +// - mlce{8|16|32|64}.m: Load C into AccReg +// +// Compute (TileReg × TileReg → AccReg): +// No-widen: +// - mma.{h|w|dw}.mm: Signed int16/32/64 +// Double-widen (AMUL ≥ 2): +// - mwma.{h|w}.mm: int16→int32, int32→int64 +// Quad-widen (AMUL ≥ 4, mmi8i32 MANDATORY): +// - mqma.b.mm: int8 → int32 (signed) +// - mqmau.b.mm: uint8 → uint32 (unsigned) +// - mqmasu.b.mm: int8 × uint8 → int32 +// - mqmaus.b.mm: uint8 × int8 → int32 +// +// Store (Register → Memory): +// - msae{8|16|32|64}.m: Store TileReg A +// - msbe{8|16|32|64}.m: Store TileReg B +// - msce{8|16|32|64}.m: Store AccReg C +// +// Utility: +// - mzero.m: Zero out AccReg +//===----------------------------------------------------------------------===//