diff --git a/llvm/include/llvm/IR/IntrinsicsRISCV.td b/llvm/include/llvm/IR/IntrinsicsRISCV.td index 32de1a10a4fc3..c02da5ccb4c0d 100644 --- a/llvm/include/llvm/IR/IntrinsicsRISCV.td +++ b/llvm/include/llvm/IR/IntrinsicsRISCV.td @@ -2017,3 +2017,4 @@ include "llvm/IR/IntrinsicsRISCVXsf.td" include "llvm/IR/IntrinsicsRISCVXCV.td" include "llvm/IR/IntrinsicsRISCVXAndes.td" include "llvm/IR/IntrinsicsRISCVXMIPS.td" +include "llvm/IR/IntrinsicsRISCVBuddyExt.td" diff --git a/llvm/include/llvm/IR/IntrinsicsRISCVBuddyExt.td b/llvm/include/llvm/IR/IntrinsicsRISCVBuddyExt.td new file mode 100644 index 0000000000000..496efec4c8d0e --- /dev/null +++ b/llvm/include/llvm/IR/IntrinsicsRISCVBuddyExt.td @@ -0,0 +1,407 @@ +//===- IntrinsicsRISCVBuddyExt.td -----------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This is the costum intrinsic definition file of RISC-V buddy extension. +// +//===----------------------------------------------------------------------===// +let TargetPrefix = "riscv" in +def int_riscv_mvin : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty],[]>; + +let TargetPrefix = "riscv" in +def int_riscv_mvin2 : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty],[]>; + +let TargetPrefix = "riscv" in +def int_riscv_mvin3 : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty],[]>; + +let TargetPrefix = "riscv" in +def int_riscv_mvout : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_flush : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_config : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_preload : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_compute_preloaded : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_compute_accumulated : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_loop_ws_config_bounds : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_loop_ws_config_addrs_ab : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_loop_ws_config_addrs_dc : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_loop_ws_config_strides_ab : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_loop_ws_config_strides_dc : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_loop_ws : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_loop_conv_ws : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_loop_conv_ws_config1 : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_loop_conv_ws_config2 : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_loop_conv_ws_config3 : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_loop_conv_ws_config4 : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_loop_conv_ws_config5 : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +let TargetPrefix = "riscv" in +def int_riscv_loop_conv_ws_config6 : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + +//===----------------------------------------------------------------------===// +// IME Extension Intrinsics +//===----------------------------------------------------------------------===// + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadotu : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadotsu : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadotus : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vfmadot : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot1 : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot1u : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot1su : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot1us : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot2 : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot2u : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot2su : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot2us : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot3 : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot3u : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot3su : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot3us : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadotn : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty, llvm_i64_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadotnu : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty, llvm_i64_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadotnsu : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty, llvm_i64_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadotnus : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty, llvm_i64_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vfmadot1 : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vfmadot2 : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vfmadot3 : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vfmadotn : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty, llvm_i64_ty], + [IntrNoMem]>; + +//===----------------------------------------------------------------------===// +// AME (RISC-V Matrix Extension) Intrinsics +//===----------------------------------------------------------------------===// + +// Matrix configuration intrinsics (with register) +let TargetPrefix = "riscv" in { + // msettype - Set matrix type configuration + def int_riscv_buddy_msettype : Intrinsic<[llvm_i64_ty], [llvm_i64_ty], [IntrNoMem]>; + + // msettilem - Set tile M dimension from register + def int_riscv_buddy_msettilem : Intrinsic<[llvm_i64_ty], [llvm_i64_ty], [IntrNoMem]>; + + // msettilen - Set tile N dimension from register + def int_riscv_buddy_msettilen : Intrinsic<[llvm_i64_ty], [llvm_i64_ty], [IntrNoMem]>; + + // msettilek - Set tile K dimension from register + def int_riscv_buddy_msettilek : Intrinsic<[llvm_i64_ty], [llvm_i64_ty], [IntrNoMem]>; +} + +// Matrix configuration intrinsics (with immediate) +let TargetPrefix = "riscv" in { + // msettilemi - Set tile M dimension with immediate + def int_riscv_buddy_msettilemi : Intrinsic<[], [llvm_i64_ty], + [IntrNoMem, IntrHasSideEffects, ImmArg>]>; + + // msettileni - Set tile N dimension with immediate + def int_riscv_buddy_msettileni : Intrinsic<[], [llvm_i64_ty], + [IntrNoMem, IntrHasSideEffects, ImmArg>]>; + + // msettileki - Set tile K dimension with immediate + def int_riscv_buddy_msettileki : Intrinsic<[], [llvm_i64_ty], + [IntrNoMem, IntrHasSideEffects, ImmArg>]>; +} + +// Matrix load intrinsics (load to tile register) +// Format: md = tile register index, base = address, stride = row byte stride +let TargetPrefix = "riscv" in { + // mlae32.m - Load 32-bit left matrix A to tile register + def int_riscv_buddy_mlae32_m : Intrinsic<[], + [llvm_i64_ty, llvm_ptr_ty, llvm_i64_ty], + [IntrReadMem, IntrHasSideEffects, ImmArg>]>; + + // mlae64.m - Load 64-bit left matrix A to tile register + def int_riscv_buddy_mlae64_m : Intrinsic<[], + [llvm_i64_ty, llvm_ptr_ty, llvm_i64_ty], + [IntrReadMem, IntrHasSideEffects, ImmArg>]>; + + // mlbe32.m - Load 32-bit right matrix B to tile register + def int_riscv_buddy_mlbe32_m : Intrinsic<[], + [llvm_i64_ty, llvm_ptr_ty, llvm_i64_ty], + [IntrReadMem, IntrHasSideEffects, ImmArg>]>; + + // mlbe64.m - Load 64-bit right matrix B to tile register + def int_riscv_buddy_mlbe64_m : Intrinsic<[], + [llvm_i64_ty, llvm_ptr_ty, llvm_i64_ty], + [IntrReadMem, IntrHasSideEffects, ImmArg>]>; + + // mlce32.m - Load 32-bit output matrix C to accumulator + def int_riscv_buddy_mlce32_m : Intrinsic<[], + [llvm_i64_ty, llvm_ptr_ty, llvm_i64_ty], + [IntrReadMem, IntrHasSideEffects, ImmArg>]>; + + // mlce64.m - Load 64-bit output matrix C to accumulator + def int_riscv_buddy_mlce64_m : Intrinsic<[], + [llvm_i64_ty, llvm_ptr_ty, llvm_i64_ty], + [IntrReadMem, IntrHasSideEffects, ImmArg>]>; +} + +// Matrix store intrinsics (store from tile register) +let TargetPrefix = "riscv" in { + // msce32.m - Store 32-bit output matrix C from accumulator + def int_riscv_buddy_msce32_m : Intrinsic<[], + [llvm_i64_ty, llvm_ptr_ty, llvm_i64_ty], + [IntrWriteMem, IntrHasSideEffects, ImmArg>]>; + + // msce64.m - Store 64-bit output matrix C from accumulator + def int_riscv_buddy_msce64_m : Intrinsic<[], + [llvm_i64_ty, llvm_ptr_ty, llvm_i64_ty], + [IntrWriteMem, IntrHasSideEffects, ImmArg>]>; +} + +// Matrix zero intrinsic +let TargetPrefix = "riscv" in { + def int_riscv_buddy_mzero : Intrinsic<[], [llvm_i64_ty], + [IntrNoMem, IntrHasSideEffects, ImmArg>]>; +} + +// Tile register matrix multiplication intrinsics (operate on tile registers) +let TargetPrefix = "riscv" in { + // mma.w.mm - int32 tile matrix multiply: md = md + ms1 x ms2 + def int_riscv_buddy_mma_w_mm_tile : Intrinsic<[], + [llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + [IntrHasSideEffects, ImmArg>, + ImmArg>, ImmArg>]>; + + // mma.dw.mm - int64 tile matrix multiply: md = md + ms1 x ms2 + def int_riscv_buddy_mma_dw_mm_tile : Intrinsic<[], + [llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + [IntrHasSideEffects, ImmArg>, + ImmArg>, ImmArg>]>; +} + +// Legacy matrix load/store intrinsics (for backward compatibility) +let TargetPrefix = "riscv" in { + // mlae - Load matrix A with element width + def int_riscv_buddy_mlae : Intrinsic<[llvm_anyvector_ty], + [llvm_ptr_ty, llvm_i64_ty, llvm_i64_ty], + [IntrReadMem]>; + + // mlbe - Load matrix B with element width + def int_riscv_buddy_mlbe : Intrinsic<[llvm_anyvector_ty], + [llvm_ptr_ty, llvm_i64_ty, llvm_i64_ty], + [IntrReadMem]>; + + // mlce - Load matrix C (accumulator) + def int_riscv_buddy_mlce : Intrinsic<[llvm_anyvector_ty], + [llvm_ptr_ty, llvm_i64_ty, llvm_i64_ty], + [IntrReadMem]>; + + // msce - Store matrix C (accumulator) + def int_riscv_buddy_msce : Intrinsic<[], + [llvm_anyvector_ty, llvm_ptr_ty, llvm_i64_ty, llvm_i64_ty], + [IntrWriteMem]>; +} + +// Signed integer matrix multiplication intrinsics +let TargetPrefix = "riscv" in { + // mqma.b.mm - int8 quad-widen matrix multiply (int8 x int8 -> int32) + def int_riscv_buddy_mqma_b_mm : Intrinsic<[], + [llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + []>; + + // mma.h.mm - int16 matrix multiply + def int_riscv_buddy_mma_h_mm : Intrinsic<[], + [llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + []>; + + // mma.w.mm - int32 matrix multiply + def int_riscv_buddy_mma_w_mm : Intrinsic<[], + [llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + []>; + + // mma.dw.mm - int64 matrix multiply + def int_riscv_buddy_mma_dw_mm : Intrinsic<[], + [llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + []>; + + // mwma.h.mm - int16 double-widen matrix multiply (int16 x int16 -> int32) + def int_riscv_buddy_mwma_h_mm : Intrinsic<[], + [llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + []>; +} + +// Unsigned integer matrix multiplication intrinsics +let TargetPrefix = "riscv" in { + // mqmau.b.mm - uint8 quad-widen matrix multiply + def int_riscv_buddy_mqmau_b_mm : Intrinsic<[], + [llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + []>; + + // mmau.h.mm - uint16 matrix multiply + def int_riscv_buddy_mmau_h_mm : Intrinsic<[], + [llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + []>; +} + +// Floating-point matrix multiplication intrinsics +let TargetPrefix = "riscv" in { + // mfma.f.mm - fp32 matrix multiply + def int_riscv_buddy_mfma_f_mm : Intrinsic<[], + [llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + []>; + + // mfma.hf.mm - fp16 matrix multiply + def int_riscv_buddy_mfma_hf_mm : Intrinsic<[], + [llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + []>; + + // mfwma.hf.mm - fp16 double-widen matrix multiply (fp16 x fp16 -> fp32) + def int_riscv_buddy_mfwma_hf_mm : Intrinsic<[], + [llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + []>; +} diff --git a/llvm/lib/Target/RISCV/Disassembler/RISCVDisassembler.cpp b/llvm/lib/Target/RISCV/Disassembler/RISCVDisassembler.cpp index 30a5d65a901d3..23fd4a133d863 100644 --- a/llvm/lib/Target/RISCV/Disassembler/RISCVDisassembler.cpp +++ b/llvm/lib/Target/RISCV/Disassembler/RISCVDisassembler.cpp @@ -93,6 +93,20 @@ static DecodeStatus DecodeSimpleRegisterClass(MCInst &Inst, uint32_t RegNo, return MCDisassembler::Success; } +static DecodeStatus DecodeTileRegRegisterClass(MCInst &Inst, uint32_t RegNo, + uint64_t Address, const MCDisassembler *Decoder) { + if (RegNo >= 8) + return MCDisassembler::Fail; + return DecodeGPRRegisterClass(Inst, RegNo, /*Address=*/0, Decoder); +} + +static DecodeStatus DecodeAccRegRegisterClass(MCInst &Inst, uint32_t RegNo, + uint64_t Address, const MCDisassembler *Decoder) { + if (RegNo >= 8) + return MCDisassembler::Fail; + return DecodeGPRRegisterClass(Inst, RegNo, /*Address=*/0, Decoder); +} + constexpr auto DecodeGPRRegisterClass = DecodeSimpleRegisterClass; diff --git a/llvm/lib/Target/RISCV/RISCVBuddyExt.td b/llvm/lib/Target/RISCV/RISCVBuddyExt.td new file mode 100644 index 0000000000000..192641e585398 --- /dev/null +++ b/llvm/lib/Target/RISCV/RISCVBuddyExt.td @@ -0,0 +1,132 @@ +//===- RISCVBuddyExt.td ---------------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This is the top file for target definition of RISC-V buddy extension. +// +//===----------------------------------------------------------------------===// + +def FeatureBuddyExt + : SubtargetFeature<"buddyext", "HasBuddyExt", "true", + "'BuddyExt' (Buddy RISC-V Extension)">; +def HasBuddyExt : Predicate<"Subtarget->hasBuddyExt()">, + AssemblerPredicate<(all_of FeatureBuddyExt), + "'BuddyExt' (Buddy RISC-V Extension)">; + +//===----------------------------------------------------------------------===// +// AME (RISC-V Matrix Extension) Register Definitions +//===----------------------------------------------------------------------===// +// Reference: RISC-V Matrix Extension Specification +// +// Matrix Registers: +// - 8 Tile Registers (tr0-tr7): For input matrices A and B +// Each tile register has MLEN bits of state +// - 8 Accumulation Registers (acc0-acc7): For output/accumulation matrix C +// Each accumulation register has MLEN × AMUL bits of state +// +// AMUL (Accumulation MULtiplier): +// - Can be fractional (1/8, 1/4, 1/2) or integer (1, 2, 4, 8) +// - Determines the width ratio between acc and tr registers +// - For mmi8i32 (int8→int32 quad-widen), AMUL ≥ 4 +// +// Data Flow: +// Memory → tr (via mlae/mlbe) → acc (via mma/mwma/mqma) → Memory (via msce) +//===----------------------------------------------------------------------===// + +let Namespace = "RISCV" in { + +//===----------------------------------------------------------------------===// +// AME Tile Registers (tr0-tr7) +// Used for input matrices A and B +// Size: MLEN bits per register (hardware-defined) +//===----------------------------------------------------------------------===// + +// Base class for Tile Registers +class AMETileReg Enc, string n> : Register { + let HWEncoding{2-0} = Enc; + let HWEncoding{4-3} = 0b00; // Distinguish from accumulation registers +} + +// Define 8 Tile Registers: tr0-tr7 +def TR0 : AMETileReg<0, "tr0">; +def TR1 : AMETileReg<1, "tr1">; +def TR2 : AMETileReg<2, "tr2">; +def TR3 : AMETileReg<3, "tr3">; +def TR4 : AMETileReg<4, "tr4">; +def TR5 : AMETileReg<5, "tr5">; +def TR6 : AMETileReg<6, "tr6">; +def TR7 : AMETileReg<7, "tr7">; + +//===----------------------------------------------------------------------===// +// AME Accumulation Registers (acc0-acc7) +// Used for output/accumulation matrix C +// Size: MLEN × AMUL bits per register (hardware-defined) +// +// Note: AMUL can be: +// - Fractional (1/8, 1/4, 1/2): For C = A × Bᵀ mode with large K +// - Integer (1, 2, 4, 8): For widening operations +// * AMUL=4: Required for mmi8i32 (int8→int32 quad-widen) +// * AMUL=2: Required for mmi16i32 (int16→int32 double-widen) +// * AMUL=8: Required for mmi4i32 (int4→int32 oct-widen) +//===----------------------------------------------------------------------===// + +// Base class for Accumulation Registers +class AMEAccReg Enc, string n> : Register { + let HWEncoding{2-0} = Enc; + let HWEncoding{4-3} = 0b01; // Distinguish from tile registers +} + +// Define 8 Accumulation Registers: acc0-acc7 +def ACC0 : AMEAccReg<0, "acc0">; +def ACC1 : AMEAccReg<1, "acc1">; +def ACC2 : AMEAccReg<2, "acc2">; +def ACC3 : AMEAccReg<3, "acc3">; +def ACC4 : AMEAccReg<4, "acc4">; +def ACC5 : AMEAccReg<5, "acc5">; +def ACC6 : AMEAccReg<6, "acc6">; +def ACC7 : AMEAccReg<7, "acc7">; + +} // End Namespace = "RISCV" + +//===----------------------------------------------------------------------===// +// AME Register Classes +//===----------------------------------------------------------------------===// +// These register classes define the operand types for AME instructions +// +// Usage in instructions: +// - TileReg: For ms1, ms2 (source operands in multiplication) +// - AccReg: For md (destination/accumulator in multiplication) +// - TileReg: For load/store of input matrices (A, B) +// - AccReg: For load/store of output/accumulator (C) +//===----------------------------------------------------------------------===// + +// Tile Register class (tr0-tr7) +// Used for input operands in matrix multiplication +// Note: Size is set to 256 as a placeholder; actual size depends on MLEN +def TileReg : RegisterClass<"RISCV", [untyped], 256, + (add TR0, TR1, TR2, TR3, TR4, TR5, TR6, TR7)> { + let Size = 256; // Placeholder: actual MLEN is hardware-defined +} + +// Accumulation Register class (acc0-acc7) +// Used for output/accumulator in matrix multiplication +// Note: Size can be 256×AMUL where AMUL ∈ {1/8, 1/4, 1/2, 1, 2, 4, 8} +// We use 1024 as a reasonable upper bound (256 × 4 for int8→int32) +def AccReg : RegisterClass<"RISCV", [untyped], 1024, + (add ACC0, ACC1, ACC2, ACC3, ACC4, ACC5, ACC6, ACC7)> { + let Size = 1024; // Placeholder: actual MLEN×AMUL is hardware-defined +} + +include "RISCVInstrInfoBuddyExt.td" diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.td b/llvm/lib/Target/RISCV/RISCVInstrInfo.td index ef56275118f2e..10aca8d3ba67a 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.td @@ -2412,6 +2412,7 @@ include "RISCVInstrInfoXRivos.td" include "RISCVInstrInfoXAndes.td" include "RISCVInstrInfoXSpacemiT.td" include "RISCVInstrInfoXAIF.td" +include "RISCVBuddyExt.td" //===----------------------------------------------------------------------===// // Global ISel diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoBuddyExt.td b/llvm/lib/Target/RISCV/RISCVInstrInfoBuddyExt.td new file mode 100644 index 0000000000000..6861f0248a313 --- /dev/null +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoBuddyExt.td @@ -0,0 +1,1492 @@ +//===- RISCVInstrInfoBuddyExt.td ------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This is the instruction information file of RISC-V buddy extension. +// +//===----------------------------------------------------------------------===// + +// include "llvm/IR/IntrinsicsRISCVBuddyExt.td" + +// Gemmini defines different values as func7 +// - https://github.com/ucb-bar/gemmini-rocc-tests/blob/e326e7c43457ff08669fe88edcaa395d846474d8/include/gemmini.h#L25 + +// Gemmini uses 0x3 (0b011) as func3 +// - https://github.com/IBM/rocc-software/blob/fddb795a0b52e82f8f4ce9ead9b1428440a62ab0/src/xcustom.h#L147 + +// Gemmini uses OPC_CUSTOM_3 +// - https://github.com/IBM/rocc-software/blob/fddb795a0b52e82f8f4ce9ead9b1428440a62ab0/src/xcustom.h#L123 + +let hasSideEffects = 1, mayLoad = 1, mayStore = 1, Predicates = [HasBuddyExt] in +def MVIN : RVInstR<0b0000010, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "mvin","$rs1, $rs2"> { + let rd = 0; +} + +let hasSideEffects = 1, mayLoad = 1, mayStore = 1, Predicates = [HasBuddyExt] in +def MVIN2 : RVInstR<0b0000001, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "mvin2","$rs1, $rs2"> { + let rd = 0; +} + +let hasSideEffects = 1, mayLoad = 1, mayStore = 1, Predicates = [HasBuddyExt] in +def MVIN3 : RVInstR<0b0001110, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "mvin3","$rs1, $rs2"> { + let rd = 0; +} + +let hasSideEffects = 1, mayLoad = 1, mayStore = 1, Predicates = [HasBuddyExt] in +def MVOUT : RVInstR<0b0000011, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "mvout","$rs1, $rs2">{ + let rd = 0; +} + +let hasSideEffects = 1, mayLoad = 1, mayStore = 1, Predicates = [HasBuddyExt] in +def FLUSH : RVInstR<0b0000111, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "flush", "$rs1"> { + let rd = 0; +} + +let Predicates = [HasBuddyExt] in +def CONFIG : RVInstR<0b0000000, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "config", "$rs1, $rs2"> { + let rd = 0; +} + +let hasSideEffects = 1, mayLoad = 1, mayStore =1, Predicates = [HasBuddyExt] in +def PRELOAD : RVInstR<0b0000110, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "preload", "$rs1, $rs2">{ + let rd = 0; +} + +let Predicates = [HasBuddyExt] in +def COMPUTE_PRELOADED : RVInstR<0b0000100, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "compute_preloaded", "$rs1, $rs2">{ + let rd = 0; +} + +let Predicates = [HasBuddyExt] in +def COMPUTE_ACCUMULATED : RVInstR<0b0000101, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "compute_accumulated", "$rs1, $rs2"> { + let rd = 0; +} + +let Predicates = [HasBuddyExt] in +def LOOP_WS_CONFIG_BOUNDS : RVInstR<0b0001001, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "loop_ws_config_bounds","$rs1, $rs2">{ + let rd = 0; +} + +let Predicates = [HasBuddyExt] in +def LOOP_WS_CONFIG_ADDRS_AB : RVInstR<0b0001010, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "loop_ws_config_addrs_ab", "$rs1, $rs2"> { + let rd = 0; +} + +let Predicates = [HasBuddyExt] in +def LOOP_WS_CONFIG_ADDRS_DC : RVInstR<0b0001011, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "loop_ws_config_addrs_dc", "$rs1, $rs2"> { + let rd = 0; +} + +let Predicates = [HasBuddyExt] in +def LOOP_WS_CONFIG_STRIDES_AB : RVInstR<0b0001100, 0b011, OPC_CUSTOM_3,(outs), + (ins GPR:$rs1, GPR:$rs2), "loop_ws_config_strides_ab", "$rs1, $rs2"> { + let rd = 0; +} + +let Predicates = [HasBuddyExt] in +def LOOP_WS_CONFIG_STRIDES_DC : RVInstR<0b0001101, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "loop_ws_config_strides_dc", "$rs1, $rs2"> { + let rd = 0; +} + +let Predicates = [HasBuddyExt] in +def LOOP_WS : RVInstR<0b0001000, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "loop_ws", "$rs1, $rs2"> { + let rd = 0; +} + +let hasSideEffects = 1, mayLoad = 1, mayStore =1, Predicates = [HasBuddyExt] in +def LOOP_CONV_WS : RVInstR<0b0001111, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "loop_conv_ws", "$rs1, $rs2"> { + let rd = 0; +} + +let Predicates = [HasBuddyExt] in +def LOOP_CONV_WS_CONFIG1 : RVInstR<0b0010000, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "loop_conv_ws_config1", "$rs1, $rs2"> { + let rd = 0; +} + +let Predicates = [HasBuddyExt] in +def LOOP_CONV_WS_CONFIG2 : RVInstR<0b0010001, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "loop_conv_ws_config2", "$rs1, $rs2"> { + let rd = 0; +} + +let Predicates = [HasBuddyExt] in +def LOOP_CONV_WS_CONFIG3 : RVInstR<0b0010010, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "loop_conv_ws_config3", "$rs1, $rs2"> { + let rd = 0; +} + +let Predicates = [HasBuddyExt] in +def LOOP_CONV_WS_CONFIG4 : RVInstR<0b0010011, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "loop_conv_ws_config4", "$rs1, $rs2"> { + let rd = 0; +} + +let Predicates = [HasBuddyExt] in +def LOOP_CONV_WS_CONFIG5 : RVInstR<0b0010100, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "loop_conv_ws_config5", "$rs1, $rs2"> { + let rd = 0; +} + +let Predicates = [HasBuddyExt] in +def LOOP_CONV_WS_CONFIG6 : RVInstR<0b0010101, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "loop_conv_ws_config6", "$rs1, $rs2"> { + let rd = 0; +} + +//===----------------------------------------------------------------------===// +// IME Extension Instructions +//===----------------------------------------------------------------------===// + +class RVInstIME funct7, bits<3> funct3, dag outs, dag ins, + string opcodestr, string argstr> + : RVInst { + bits<5> vs2; + bits<5> vs1; + bits<5> vd; + + let Inst{31-25} = funct7; + let Inst{24-20} = vs2; + let Inst{19-15} = vs1; + let Inst{14-12} = funct3; + let Inst{11-7} = vd; + let Inst{6-0} = OPC_CUSTOM_1.Value; + + let Uses = [VTYPE, VL]; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT : RVInstIME<0b1110001, 0b000, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM4:$vs1, VRM4:$vs2), + "vmadot", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOTU : RVInstIME<0b1110001, 0b011, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM4:$vs1, VRM4:$vs2), + "vmadotu", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOTSU : RVInstIME<0b1110001, 0b001, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM4:$vs1, VRM4:$vs2), + "vmadotsu", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOTUS : RVInstIME<0b1110001, 0b010, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM4:$vs1, VRM4:$vs2), + "vmadotus", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VFMADOT : RVInstIME<0b1110101, 0b000, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM4:$vs1, VRM4:$vs2), + "vfmadot", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +//===----------------------------------------------------------------------===// +// IME Sliding-Window Instructions +//===----------------------------------------------------------------------===// + +// Integer slide-1 instructions +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT1 : RVInstIME<0b1110010, 0b000, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot1", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT1U : RVInstIME<0b1110010, 0b011, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot1u", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT1SU : RVInstIME<0b1110010, 0b001, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot1su", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT1US : RVInstIME<0b1110010, 0b010, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot1us", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +// Integer slide-2 instructions +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT2 : RVInstIME<0b1110011, 0b000, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot2", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT2U : RVInstIME<0b1110011, 0b011, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot2u", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT2SU : RVInstIME<0b1110011, 0b001, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot2su", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT2US : RVInstIME<0b1110011, 0b010, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot2us", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +// Integer slide-3 instructions +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT3 : RVInstIME<0b1110100, 0b000, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot3", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT3U : RVInstIME<0b1110100, 0b011, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot3u", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT3SU : RVInstIME<0b1110100, 0b001, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot3su", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT3US : RVInstIME<0b1110100, 0b010, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot3us", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +// Floating-point slide instructions +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VFMADOT1 : RVInstIME<0b1110110, 0b000, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vfmadot1", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VFMADOT2 : RVInstIME<0b1110111, 0b000, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vfmadot2", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VFMADOT3 : RVInstIME<0b1111000, 0b000, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vfmadot3", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_mvin GPR:$rs1, GPR:$rs2), (MVIN GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_mvin2 GPR:$rs1, GPR:$rs2), (MVIN2 GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_mvin3 GPR:$rs1, GPR:$rs2), (MVIN3 GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_mvout GPR:$rs1, GPR:$rs2), (MVOUT GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_flush GPR:$rs1, GPR:$rs2), (FLUSH GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_config GPR:$rs1, GPR:$rs2), (CONFIG GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_preload GPR:$rs1, GPR:$rs2), (PRELOAD GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_compute_preloaded GPR:$rs1, GPR:$rs2), (COMPUTE_PRELOADED GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_compute_accumulated GPR:$rs1, GPR:$rs2), (COMPUTE_ACCUMULATED GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_loop_ws_config_bounds GPR:$rs1, GPR:$rs2), (LOOP_WS_CONFIG_BOUNDS GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_loop_ws_config_addrs_ab GPR:$rs1, GPR:$rs2), (LOOP_WS_CONFIG_ADDRS_AB GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_loop_ws_config_addrs_dc GPR:$rs1, GPR:$rs2), (LOOP_WS_CONFIG_ADDRS_DC GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_loop_ws_config_strides_ab GPR:$rs1, GPR:$rs2), (LOOP_WS_CONFIG_STRIDES_AB GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_loop_ws_config_strides_dc GPR:$rs1, GPR:$rs2), (LOOP_WS_CONFIG_STRIDES_DC GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_loop_ws GPR:$rs1, GPR:$rs2), (LOOP_WS GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_loop_conv_ws GPR:$rs1, GPR:$rs2), (LOOP_CONV_WS GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_loop_conv_ws_config1 GPR:$rs1, GPR:$rs2), (LOOP_CONV_WS_CONFIG1 GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_loop_conv_ws_config2 GPR:$rs1, GPR:$rs2), (LOOP_CONV_WS_CONFIG2 GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_loop_conv_ws_config3 GPR:$rs1, GPR:$rs2), (LOOP_CONV_WS_CONFIG3 GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_loop_conv_ws_config4 GPR:$rs1, GPR:$rs2), (LOOP_CONV_WS_CONFIG4 GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_loop_conv_ws_config5 GPR:$rs1, GPR:$rs2), (LOOP_CONV_WS_CONFIG5 GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_loop_conv_ws_config6 GPR:$rs1, GPR:$rs2), (LOOP_CONV_WS_CONFIG6 GPR:$rs1, GPR:$rs2)>; + +//===----------------------------------------------------------------------===// +// IME Extension Patterns +//===----------------------------------------------------------------------===// + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot nxv8i32:$vd, nxv32i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT VRM4:$vd, VRM4:$vs1, VRM4:$vs2)>; +} + +// int16 vmadot patterns +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot nxv8i32:$vd, nxv16i16:$vs1, nxv16i16:$vs2)), + (IME_VMADOT VRM4:$vd, VRM4:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadotu nxv8i32:$vd, nxv32i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOTU VRM4:$vd, VRM4:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadotsu nxv8i32:$vd, nxv32i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOTSU VRM4:$vd, VRM4:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadotus nxv8i32:$vd, nxv32i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOTUS VRM4:$vd, VRM4:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv16f16 (int_riscv_ime_vfmadot nxv16f16:$vd, nxv16f16:$vs1, nxv16f16:$vs2)), + (IME_VFMADOT VRM4:$vd, VRM4:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot1 nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT1 VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot1u nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT1U VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot1su nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT1SU VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot1us nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT1US VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot2 nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT2 VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot2u nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT2U VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot2su nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT2SU VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot2us nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT2US VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot3 nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT3 VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot3u nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT3U VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot3su nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT3SU VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot3us nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT3US VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv16f16 (int_riscv_ime_vfmadot1 nxv16f16:$vd, nxv32f16:$vs1, nxv16f16:$vs2)), + (IME_VFMADOT1 VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv16f16 (int_riscv_ime_vfmadot2 nxv16f16:$vd, nxv32f16:$vs1, nxv16f16:$vs2)), + (IME_VFMADOT2 VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv16f16 (int_riscv_ime_vfmadot3 nxv16f16:$vd, nxv32f16:$vs1, nxv16f16:$vs2)), + (IME_VFMADOT3 VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +class RVInstIMEN funct7, bits<3> funct3, dag outs, dag ins, + string opcodestr, string argstr> + : RVInst { + bits<5> vs2; + bits<5> rs1; // GPR for dynamic slide value + bits<5> vd; + + let Inst{31-25} = funct7; + let Inst{24-20} = vs2; + let Inst{19-15} = rs1; + let Inst{14-12} = funct3; + let Inst{11-7} = vd; + let Inst{6-0} = OPC_CUSTOM_1.Value; + + let Uses = [VTYPE, VL]; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOTN : RVInstIMEN<0b1111001, 0b000, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2, GPR:$rs1), + "vmadotn", "$vd, $vs1, $vs2, $rs1"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOTNU : RVInstIMEN<0b1111001, 0b011, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2, GPR:$rs1), + "vmadotnu", "$vd, $vs1, $vs2, $rs1"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOTNSU : RVInstIMEN<0b1111001, 0b001, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2, GPR:$rs1), + "vmadotnsu", "$vd, $vs1, $vs2, $rs1"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOTNUS : RVInstIMEN<0b1111001, 0b010, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2, GPR:$rs1), + "vmadotnus", "$vd, $vs1, $vs2, $rs1"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VFMADOTN : RVInstIMEN<0b1111010, 0b000, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2, GPR:$rs1), + "vfmadotn", "$vd, $vs1, $vs2, $rs1"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadotn nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2, GPR:$rs1)), + (IME_VMADOTN VRM4:$vd, VRM8:$vs1, VRM4:$vs2, GPR:$rs1)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadotnu nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2, GPR:$rs1)), + (IME_VMADOTNU VRM4:$vd, VRM8:$vs1, VRM4:$vs2, GPR:$rs1)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadotnsu nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2, GPR:$rs1)), + (IME_VMADOTNSU VRM4:$vd, VRM8:$vs1, VRM4:$vs2, GPR:$rs1)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadotnus nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2, GPR:$rs1)), + (IME_VMADOTNUS VRM4:$vd, VRM8:$vs1, VRM4:$vs2, GPR:$rs1)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv16f16 (int_riscv_ime_vfmadotn nxv16f16:$vd, nxv32f16:$vs1, nxv16f16:$vs2, GPR:$rs1)), + (IME_VFMADOTN VRM4:$vd, VRM8:$vs1, VRM4:$vs2, GPR:$rs1)>; +} + +//===----------------------------------------------------------------------===// +// AME (RISC-V Matrix Extension) 64-bit Instructions +//===----------------------------------------------------------------------===// +// Reference: RISC-V Matrix Extension Specification +// 64-bit encoding format with prefix 0111111 at bits [6:0] +// +// Register Model: +// - TileReg (tr0-tr7): Input matrices A and B, MLEN bits each +// - AccReg (acc0-acc7): Accumulator matrix C, MLEN×AMUL bits each +// +// Data Flow: +// Memory --[mlae/mlbe]--> TileReg --[mma/mwma/mqma]--> AccReg --[msce]--> Memory +// +// Matrix Multiplication Instruction Format (64-bit): +// | 63:59 | 58 | 57:55 | 54:52 | 51:49 | 48:47 | 46:44 | 43:39 | 38:32 | +// | sps | sp | typ2 | typ1 | typd | bma | frm | funct5 | opcode | +// | 31:26 | 25 | 24:20 | 19:15 | 14:12 | 11:7 | 6:0 | +// | funct6 | fp | ms2 | ms1 | funct3 | md | suffix | +// +// suffix = 0111111 (AME prefix) +// funct3 = 100 (matrix multiplication) +// +// Widening Instructions (required for AMUL > 1): +// - mqma.b.mm: int8 → int32 (quad-widen, AMUL ≥ 4, mmi8i32 MANDATORY) +// - mwma.h.mm: int16 → int32 (double-widen, AMUL ≥ 2) +// - mwma.w.mm: int32 → int64 (double-widen, AMUL ≥ 2) +//===----------------------------------------------------------------------===// + +// AME Opcode suffix (bits [6:0]) +def OPC_AME : RISCVOpcode<"OPC_AME", 0b0111111>; + +//===----------------------------------------------------------------------===// +// AME Custom Operand Types for Tile Register Indices +//===----------------------------------------------------------------------===// +// These operand types allow intrinsics to pass immediate indices (0-7) +// for tile and accumulator registers. The AsmString of pseudo instructions +// hardcodes "acc" and "tr" prefixes so that indices 0-7 are printed as +// acc0-acc7 and tr0-tr7 respectively, without modifying LLVM submodule. + +// AsmOperandClass for tile register index (0-7) +def AMETileIndexAsmOperand : AsmOperandClass { + let Name = "AMETileIndex"; + let RenderMethod = "addImmOperands"; + let PredicateMethod = "isUImm3"; + let DiagnosticType = "InvalidAMETileIndex"; +} + +// AsmOperandClass for accumulator register index (0-7) +def AMEAccIndexAsmOperand : AsmOperandClass { + let Name = "AMEAccIndex"; + let RenderMethod = "addImmOperands"; + let PredicateMethod = "isUImm3"; + let DiagnosticType = "InvalidAMEAccIndex"; +} + +// Operand type for TileReg index (0-7), printed with "tr" prefix in AsmString +def AMETileIndex : RISCVOp { + let ParserMatchClass = AMETileIndexAsmOperand; + let DecoderMethod = "decodeUImmOperand<3>"; + let OperandType = "OPERAND_UIMM3"; +} + +// Operand type for AccReg index (0-7), printed with "acc" prefix in AsmString +def AMEAccIndex : RISCVOp { + let ParserMatchClass = AMEAccIndexAsmOperand; + let DecoderMethod = "decodeUImmOperand<3>"; + let OperandType = "OPERAND_UIMM3"; +} + +//===----------------------------------------------------------------------===// +// AME 64-bit Instruction Format Base Class +//===----------------------------------------------------------------------===// + +// Base class for AME 64-bit matrix multiplication instructions +// Uses TileReg for inputs (ms1, ms2) and AccReg for output (md) +class RVInstAME64 + : RVInst64 { + // Low 32 bits + bits<5> md; + bits<5> ms1; + bits<5> ms2; + bits<6> funct6; + bit fp; + bits<3> funct3; + + // High 32 bits + bits<5> funct5; + bits<3> frm; + bits<2> bma; + bits<3> typd; + bits<3> typ1; + bits<3> typ2; + bit sp; + bits<5> sps; + bits<7> opcode_hi; // [38:32] + + // Encode low 32 bits (suffix word) + let Inst{6-0} = 0b0111111; // AME prefix + let Inst{11-7} = md; + let Inst{14-12} = funct3; + let Inst{19-15} = ms1; + let Inst{24-20} = ms2; + let Inst{25} = fp; + let Inst{31-26} = funct6; + + // Encode high 32 bits (opcode word) + let Inst{38-32} = opcode_hi; + let Inst{43-39} = funct5; + let Inst{46-44} = frm; + let Inst{48-47} = bma; + let Inst{51-49} = typd; + let Inst{54-52} = typ1; + let Inst{57-55} = typ2; + let Inst{58} = sp; + let Inst{63-59} = sps; +} + +//===----------------------------------------------------------------------===// +// AME Matrix Multiplication Instructions +//===----------------------------------------------------------------------===// +// Format: mma.{h|w|dw}.mm acc, tr, tr +// Semantics: acc = acc + tr1 * tr2 +// +// Data Flow: TileReg × TileReg → AccReg (accumulate) +// - ms1: TileReg for matrix A +// - ms2: TileReg for matrix B +// - md: AccReg for accumulation result C +// +// typ1/typ2/typd encoding: +// 000 = int8 (b), 001 = int16 (h), 010 = int32 (w), 011 = int64 (dw) +// 100 = use mtype.msew, 111 = int4 +// +// funct5 encoding: +// 00001 = mma (signed, no saturation) +// 00010 = mwma (double-widen) +// 00100 = mqma (quad-widen) +// 10001 = msma (signed, saturated) +//===----------------------------------------------------------------------===// + +// No-widen matrix multiply-accumulate: acc = acc + tr1 * tr2 +// Input and output have the same element width +class AME_MMA_MM typd_val, bits<3> typ1_val, bits<3> typ2_val, + string opcodestr> + : RVInstAME64<(outs AccReg:$md), + (ins AccReg:$md_in, TileReg:$ms1, TileReg:$ms2), + opcodestr, "$md, $ms1, $ms2"> { + let Constraints = "$md = $md_in"; // Accumulator constraint + let funct6 = 0b000000; + let fp = 0; + let funct3 = 0b100; // Matrix multiplication + let opcode_hi = 0b0000011; // xxyyy11 where xx=00, yyy=001 + let funct5 = 0b00001; // mma (signed, no saturation) + let frm = 0b000; + let bma = 0b00; // Default: not agnostic + let typd = typd_val; + let typ1 = typ1_val; + let typ2 = typ2_val; + let sp = 0; // No sparsity + let sps = 0b00000; +} + +// mma.h.mm - int16 × int16 → int16 accumulate (no widen) +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MMA_H_MM : AME_MMA_MM<0b001, 0b001, 0b001, "mma.h.mm">; + +// mma.w.mm - int32 × int32 → int32 accumulate (no widen) +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MMA_W_MM : AME_MMA_MM<0b010, 0b010, 0b010, "mma.w.mm">; + +// mma.dw.mm - int64 × int64 → int64 accumulate (no widen) +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MMA_DW_MM : AME_MMA_MM<0b011, 0b011, 0b011, "mma.dw.mm">; + +//===----------------------------------------------------------------------===// +// AME Widening Matrix Multiplication Instructions +//===----------------------------------------------------------------------===// +// Double-widen: output element is 2× width of input elements +// Quad-widen: output element is 4× width of input elements +// +// These require AMUL ≥ 2 (double) or AMUL ≥ 4 (quad) to ensure +// accumulator has sufficient width. +// +// Mandatory: mqma.b.mm (int8→int32) for mmi8i32 feature +//===----------------------------------------------------------------------===// + +// Double-widen: acc = acc + tr1 * tr2, output is 2× width +// Requires AMUL ≥ 2 +class AME_MWMA_MM typd_val, bits<3> typ1_val, bits<3> typ2_val, + string opcodestr> + : RVInstAME64<(outs AccReg:$md), + (ins AccReg:$md_in, TileReg:$ms1, TileReg:$ms2), + opcodestr, "$md, $ms1, $ms2"> { + let Constraints = "$md = $md_in"; + let funct6 = 0b000000; + let fp = 0; + let funct3 = 0b100; + let opcode_hi = 0b0000011; + let funct5 = 0b00010; // mwma (double-widen) + let frm = 0b000; + let bma = 0b00; + let typd = typd_val; + let typ1 = typ1_val; + let typ2 = typ2_val; + let sp = 0; + let sps = 0b00000; +} + +// mwma.h.mm - int16 × int16 → int32 accumulate (double-widen) +// Requires: AMUL ≥ 2, mmi16i32 feature +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MWMA_H_MM : AME_MWMA_MM<0b010, 0b001, 0b001, "mwma.h.mm">; + +// mwma.w.mm - int32 × int32 → int64 accumulate (double-widen) +// Requires: AMUL ≥ 2, mmi32i64 feature +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MWMA_W_MM : AME_MWMA_MM<0b011, 0b010, 0b010, "mwma.w.mm">; + +// Quad-widen: acc = acc + tr1 * tr2, output is 4× width +// Requires AMUL ≥ 4 +class AME_MQMA_MM typd_val, bits<3> typ1_val, bits<3> typ2_val, + string opcodestr> + : RVInstAME64<(outs AccReg:$md), + (ins AccReg:$md_in, TileReg:$ms1, TileReg:$ms2), + opcodestr, "$md, $ms1, $ms2"> { + let Constraints = "$md = $md_in"; + let funct6 = 0b000000; + let fp = 0; + let funct3 = 0b100; + let opcode_hi = 0b0000011; + let funct5 = 0b00100; // mqma (quad-widen) + let frm = 0b000; + let bma = 0b00; + let typd = typd_val; + let typ1 = typ1_val; + let typ2 = typ2_val; + let sp = 0; + let sps = 0b00000; +} + +// mqma.b.mm - int8 × int8 → int32 accumulate (quad-widen) +// MANDATORY for mmi8i32 feature (required by Spec) +// Requires: AMUL ≥ 4 +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MQMA_B_MM : AME_MQMA_MM<0b010, 0b000, 0b000, "mqma.b.mm">; + +//===----------------------------------------------------------------------===// +// AME Configuration Instructions +//===----------------------------------------------------------------------===// +// Configuration instruction format (64-bit): +// | 63:43 | 42:39 | 38:32 | +// | imm[31:11] | funct4 | opcode | +// | 31:26 | 25:20 | 19:15 | 14:12 | 11:7 | 6:0 | +// | funct6 | imm[10:5] | rs1 | funct3 | rd | suffix | +//===----------------------------------------------------------------------===// + +class RVInstAMEConfig64 funct6_val, string opcodestr> + : RVInst64<(outs GPR:$rd), (ins GPR:$rs1), + opcodestr, "$rd, $rs1", [], InstFormatOther> { + bits<5> rd; + bits<5> rs1; + + let Inst{6-0} = 0b0111111; // AME prefix + let Inst{11-7} = rd; + let Inst{14-12} = 0b000; // funct3 for config + let Inst{19-15} = rs1; + let Inst{25-20} = 0b000000; + let Inst{31-26} = funct6_val; + let Inst{38-32} = 0b0000011; // opcode + let Inst{42-39} = 0b0000; // funct4 + let Inst{63-43} = 0; // imm[31:11] = 0 +} + +class RVInstAMEConfigImm64 funct6_val, string opcodestr> + : RVInst64<(outs GPR:$rd), (ins uimm32:$imm), + opcodestr, "$rd, $imm", [], InstFormatOther> { + bits<5> rd; + bits<32> imm; + + let Inst{6-0} = 0b0111111; // AME prefix + let Inst{11-7} = rd; + let Inst{14-12} = 0b000; // funct3 for config + let Inst{19-15} = imm{4-0}; // imm[4:0] in rs1 field + let Inst{25-20} = imm{10-5}; // imm[10:5] + let Inst{31-26} = funct6_val; + let Inst{38-32} = 0b0000011; // opcode + let Inst{42-39} = 0b0000; // funct4 + let Inst{63-43} = imm{31-11}; // imm[31:11] +} + +// msettilem - set tile M dimension from register +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MSETTILEM : RVInstAMEConfig64<0b000100, "msettilem">; + +// msettilemi - set tile M dimension from immediate +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MSETTILEMI : RVInstAMEConfigImm64<0b000101, "msettilemi">; + +// msettilen - set tile N dimension from register +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MSETTILEN : RVInstAMEConfig64<0b001000, "msettilen">; + +// msettileni - set tile N dimension from immediate +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MSETTILENI : RVInstAMEConfigImm64<0b001001, "msettileni">; + +// msettilek - set tile K dimension from register +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MSETTILEK : RVInstAMEConfig64<0b001100, "msettilek">; + +// msettileki - set tile K dimension from immediate +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MSETTILEKI : RVInstAMEConfigImm64<0b001101, "msettileki">; + +//===----------------------------------------------------------------------===// +// AME Load/Store Instructions +//===----------------------------------------------------------------------===// +// Load/Store instruction format (64-bit): +// | 63:51 | 50:49 | 48:47 | 46:44 | 43:39 | 38:32 | +// | resv | mt | bma | eew | funct5 | opcode | +// | 31:26 | 25 | 24:20 | 19:15 | 14:12 | 11:7 | 6:0 | +// | funct6 | ls | rs2 | rs1 | funct3 | ms3/md | suffix | +// +// mt (matrix type): 00=accumulator(C), 01=left(A), 10=right(B), 11=result +// eew (element width): 000=8b, 001=16b, 010=32b, 011=64b +// +// Register Usage: +// - mt=01 (A) or mt=10 (B): Uses TileReg +// - mt=00 (C): Uses AccReg +// +// Data Flow Examples: +// mlae32.m tr0, (a0), a1 # Load matrix A into TileReg +// mlbe32.m tr1, (a0), a1 # Load matrix B into TileReg +// mqma.b.mm acc0, tr0, tr1 # Compute: AccReg = AccReg + TileReg × TileReg +// msce32.m acc0, (a0), a1 # Store AccReg to memory +//===----------------------------------------------------------------------===// + +// Load into TileReg (for matrix A and B) +class RVInstAMELoadTile64 mt_val, bits<3> eew_val, string opcodestr> + : RVInst64<(outs TileReg:$md), (ins GPR:$rs1, GPR:$rs2), + opcodestr, "$md, $rs1, $rs2", [], InstFormatOther> { + bits<5> md; + bits<5> rs1; + bits<5> rs2; + + let Inst{6-0} = 0b0111111; // AME prefix + let Inst{11-7} = md; + let Inst{14-12} = 0b001; // funct3 for load/store + let Inst{19-15} = rs1; + let Inst{24-20} = rs2; + let Inst{25} = 0; // ls = 0 for load + let Inst{31-26} = 0b000000; // funct6 + let Inst{38-32} = 0b0000011; // opcode + let Inst{43-39} = 0b00000; // funct5 + let Inst{46-44} = eew_val; // eew + let Inst{48-47} = 0b00; // bma + let Inst{50-49} = mt_val; // mt + let Inst{63-51} = 0; // reserved +} + +// Load into AccReg (for accumulator C) +class RVInstAMELoadAcc64 mt_val, bits<3> eew_val, string opcodestr> + : RVInst64<(outs AccReg:$md), (ins GPR:$rs1, GPR:$rs2), + opcodestr, "$md, $rs1, $rs2", [], InstFormatOther> { + bits<5> md; + bits<5> rs1; + bits<5> rs2; + + let Inst{6-0} = 0b0111111; + let Inst{11-7} = md; + let Inst{14-12} = 0b001; + let Inst{19-15} = rs1; + let Inst{24-20} = rs2; + let Inst{25} = 0; + let Inst{31-26} = 0b000000; + let Inst{38-32} = 0b0000011; + let Inst{43-39} = 0b00000; + let Inst{46-44} = eew_val; + let Inst{48-47} = 0b00; + let Inst{50-49} = mt_val; + let Inst{63-51} = 0; +} + +// Store from TileReg (for matrix A and B) +class RVInstAMEStoreTile64 mt_val, bits<3> eew_val, string opcodestr> + : RVInst64<(outs), (ins TileReg:$ms3, GPR:$rs1, GPR:$rs2), + opcodestr, "$ms3, $rs1, $rs2", [], InstFormatOther> { + bits<5> ms3; + bits<5> rs1; + bits<5> rs2; + + let Inst{6-0} = 0b0111111; + let Inst{11-7} = ms3; + let Inst{14-12} = 0b001; + let Inst{19-15} = rs1; + let Inst{24-20} = rs2; + let Inst{25} = 1; // ls = 1 for store + let Inst{31-26} = 0b000000; + let Inst{38-32} = 0b0000011; + let Inst{43-39} = 0b00000; + let Inst{46-44} = eew_val; + let Inst{48-47} = 0b00; + let Inst{50-49} = mt_val; + let Inst{63-51} = 0; +} + +// Store from AccReg (for accumulator C) +class RVInstAMEStoreAcc64 mt_val, bits<3> eew_val, string opcodestr> + : RVInst64<(outs), (ins AccReg:$ms3, GPR:$rs1, GPR:$rs2), + opcodestr, "$ms3, $rs1, $rs2", [], InstFormatOther> { + bits<5> ms3; + bits<5> rs1; + bits<5> rs2; + + let Inst{6-0} = 0b0111111; + let Inst{11-7} = ms3; + let Inst{14-12} = 0b001; + let Inst{19-15} = rs1; + let Inst{24-20} = rs2; + let Inst{25} = 1; + let Inst{31-26} = 0b000000; + let Inst{38-32} = 0b0000011; + let Inst{43-39} = 0b00000; + let Inst{46-44} = eew_val; + let Inst{48-47} = 0b00; + let Inst{50-49} = mt_val; + let Inst{63-51} = 0; +} + +//===----------------------------------------------------------------------===// +// Load matrix A (left operand) into TileReg - mlae*.m +// Syntax: mlae{8|16|32|64}.m tr, (rs1), rs2 +// tr: Destination TileReg +// rs1: Base address (GPR) +// rs2: Row stride in bytes (GPR) +//===----------------------------------------------------------------------===// +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 1, mayStore = 0 in { + def AME_MLAE8_M : RVInstAMELoadTile64<0b01, 0b000, "mlae8.m">; + def AME_MLAE16_M : RVInstAMELoadTile64<0b01, 0b001, "mlae16.m">; + def AME_MLAE32_M : RVInstAMELoadTile64<0b01, 0b010, "mlae32.m">; + def AME_MLAE64_M : RVInstAMELoadTile64<0b01, 0b011, "mlae64.m">; +} + +//===----------------------------------------------------------------------===// +// Load matrix B (right operand) into TileReg - mlbe*.m +// Syntax: mlbe{8|16|32|64}.m tr, (rs1), rs2 +//===----------------------------------------------------------------------===// +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 1, mayStore = 0 in { + def AME_MLBE8_M : RVInstAMELoadTile64<0b10, 0b000, "mlbe8.m">; + def AME_MLBE16_M : RVInstAMELoadTile64<0b10, 0b001, "mlbe16.m">; + def AME_MLBE32_M : RVInstAMELoadTile64<0b10, 0b010, "mlbe32.m">; + def AME_MLBE64_M : RVInstAMELoadTile64<0b10, 0b011, "mlbe64.m">; +} + +//===----------------------------------------------------------------------===// +// Load matrix C (accumulator) into AccReg - mlce*.m +// Syntax: mlce{8|16|32|64}.m acc, (rs1), rs2 +// acc: Destination AccReg (MLEN×AMUL bits) +// Note: eew here refers to the output element width after widening +//===----------------------------------------------------------------------===// +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 1, mayStore = 0 in { + def AME_MLCE8_M : RVInstAMELoadAcc64<0b00, 0b000, "mlce8.m">; + def AME_MLCE16_M : RVInstAMELoadAcc64<0b00, 0b001, "mlce16.m">; + def AME_MLCE32_M : RVInstAMELoadAcc64<0b00, 0b010, "mlce32.m">; + def AME_MLCE64_M : RVInstAMELoadAcc64<0b00, 0b011, "mlce64.m">; +} + +//===----------------------------------------------------------------------===// +// Store matrix A from TileReg - msae*.m +// Syntax: msae{8|16|32|64}.m tr, (rs1), rs2 +//===----------------------------------------------------------------------===// +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 1 in { + def AME_MSAE8_M : RVInstAMEStoreTile64<0b01, 0b000, "msae8.m">; + def AME_MSAE16_M : RVInstAMEStoreTile64<0b01, 0b001, "msae16.m">; + def AME_MSAE32_M : RVInstAMEStoreTile64<0b01, 0b010, "msae32.m">; + def AME_MSAE64_M : RVInstAMEStoreTile64<0b01, 0b011, "msae64.m">; +} + +//===----------------------------------------------------------------------===// +// Store matrix B from TileReg - msbe*.m +// Syntax: msbe{8|16|32|64}.m tr, (rs1), rs2 +//===----------------------------------------------------------------------===// +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 1 in { + def AME_MSBE8_M : RVInstAMEStoreTile64<0b10, 0b000, "msbe8.m">; + def AME_MSBE16_M : RVInstAMEStoreTile64<0b10, 0b001, "msbe16.m">; + def AME_MSBE32_M : RVInstAMEStoreTile64<0b10, 0b010, "msbe32.m">; + def AME_MSBE64_M : RVInstAMEStoreTile64<0b10, 0b011, "msbe64.m">; +} + +//===----------------------------------------------------------------------===// +// Store matrix C (accumulator) from AccReg - msce*.m +// Syntax: msce{8|16|32|64}.m acc, (rs1), rs2 +// acc: Source AccReg (MLEN×AMUL bits) +// This is the primary store for computation results +//===----------------------------------------------------------------------===// +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 1 in { + def AME_MSCE8_M : RVInstAMEStoreAcc64<0b00, 0b000, "msce8.m">; + def AME_MSCE16_M : RVInstAMEStoreAcc64<0b00, 0b001, "msce16.m">; + def AME_MSCE32_M : RVInstAMEStoreAcc64<0b00, 0b010, "msce32.m">; + def AME_MSCE64_M : RVInstAMEStoreAcc64<0b00, 0b011, "msce64.m">; +} + +//===----------------------------------------------------------------------===// +// AME Extension Pattern Matching +//===----------------------------------------------------------------------===// +// Connect LLVM intrinsics to AME machine instructions +// +// Register Model: +// - TileReg (tr0-tr7): For input matrices A and B +// - AccReg (acc0-acc7): For output/accumulator C +// +// Typical Data Flow for Matrix Multiplication (e.g., int8→int32): +// 1. msettilem/n/k: Configure tile dimensions +// 2. mlae8.m tr0, (a0), stride_a: Load matrix A into TileReg +// 3. mlbe8.m tr1, (a1), stride_b: Load matrix B into TileReg +// 4. mqma.b.mm acc0, tr0, tr1: Compute acc0 = acc0 + tr0 × tr1 +// 5. msce32.m acc0, (a2), stride_c: Store AccReg to memory +// +// Note: For quad-widen (int8→int32), input uses 8-bit load (mlae8/mlbe8) +// but output uses 32-bit store (msce32) because AMUL=4 widens the output. +//===----------------------------------------------------------------------===// + +// Configuration instruction patterns +// These use GPR operands and return the actual configured value +let Predicates = [HasBuddyExt] in { + // msettilem - set tile M dimension from GPR, returns actual value + def : Pat<(i64 (int_riscv_buddy_msettilem GPR:$rs1)), + (AME_MSETTILEM GPR:$rs1)>; + + // msettilen - set tile N dimension from GPR, returns actual value + def : Pat<(i64 (int_riscv_buddy_msettilen GPR:$rs1)), + (AME_MSETTILEN GPR:$rs1)>; + + // msettilek - set tile K dimension from GPR, returns actual value + def : Pat<(i64 (int_riscv_buddy_msettilek GPR:$rs1)), + (AME_MSETTILEK GPR:$rs1)>; +} + +//===----------------------------------------------------------------------===// +// AME Additional Integer Matrix Operations +//===----------------------------------------------------------------------===// +// Unsigned and mixed-sign matrix multiplication variants +// +// Naming convention: +// - mma: signed × signed +// - mmau: unsigned × unsigned +// - mmasu: signed × unsigned +// - mmaus: unsigned × signed +//===----------------------------------------------------------------------===// + +// Unsigned no-widen multiply-accumulate +class AME_MMAU_MM typd_val, bits<3> typ1_val, bits<3> typ2_val, + string opcodestr> + : RVInstAME64<(outs AccReg:$md), + (ins AccReg:$md_in, TileReg:$ms1, TileReg:$ms2), + opcodestr, "$md, $ms1, $ms2"> { + let Constraints = "$md = $md_in"; + let funct6 = 0b000000; + let fp = 0; + let funct3 = 0b100; + let opcode_hi = 0b0000011; + let funct5 = 0b01001; // mmau (unsigned) + let frm = 0b000; + let bma = 0b00; + let typd = typd_val; + let typ1 = typ1_val; + let typ2 = typ2_val; + let sp = 0; + let sps = 0b00000; +} + +// mmau.b.mm - uint8 × uint8 → uint32 (for unsigned int8) +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MMAU_B_MM : AME_MMAU_MM<0b010, 0b000, 0b000, "mmau.b.mm">; + +// Unsigned quad-widen multiply-accumulate +class AME_MQMAU_MM typd_val, bits<3> typ1_val, bits<3> typ2_val, + string opcodestr> + : RVInstAME64<(outs AccReg:$md), + (ins AccReg:$md_in, TileReg:$ms1, TileReg:$ms2), + opcodestr, "$md, $ms1, $ms2"> { + let Constraints = "$md = $md_in"; + let funct6 = 0b000000; + let fp = 0; + let funct3 = 0b100; + let opcode_hi = 0b0000011; + let funct5 = 0b01100; // mqmau (unsigned quad-widen) + let frm = 0b000; + let bma = 0b00; + let typd = typd_val; + let typ1 = typ1_val; + let typ2 = typ2_val; + let sp = 0; + let sps = 0b00000; +} + +// mqmau.b.mm - uint8 × uint8 → uint32 (quad-widen unsigned) +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MQMAU_B_MM : AME_MQMAU_MM<0b010, 0b000, 0b000, "mqmau.b.mm">; + +// Mixed-sign: signed × unsigned +class AME_MQMASU_MM typd_val, bits<3> typ1_val, bits<3> typ2_val, + string opcodestr> + : RVInstAME64<(outs AccReg:$md), + (ins AccReg:$md_in, TileReg:$ms1, TileReg:$ms2), + opcodestr, "$md, $ms1, $ms2"> { + let Constraints = "$md = $md_in"; + let funct6 = 0b000000; + let fp = 0; + let funct3 = 0b100; + let opcode_hi = 0b0000011; + let funct5 = 0b00101; // mqmasu (signed × unsigned quad-widen) + let frm = 0b000; + let bma = 0b00; + let typd = typd_val; + let typ1 = typ1_val; + let typ2 = typ2_val; + let sp = 0; + let sps = 0b00000; +} + +// mqmasu.b.mm - int8 × uint8 → int32 +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MQMASU_B_MM : AME_MQMASU_MM<0b010, 0b000, 0b000, "mqmasu.b.mm">; + +// Mixed-sign: unsigned × signed +class AME_MQMAUS_MM typd_val, bits<3> typ1_val, bits<3> typ2_val, + string opcodestr> + : RVInstAME64<(outs AccReg:$md), + (ins AccReg:$md_in, TileReg:$ms1, TileReg:$ms2), + opcodestr, "$md, $ms1, $ms2"> { + let Constraints = "$md = $md_in"; + let funct6 = 0b000000; + let fp = 0; + let funct3 = 0b100; + let opcode_hi = 0b0000011; + let funct5 = 0b00110; // mqmaus (unsigned × signed quad-widen) + let frm = 0b000; + let bma = 0b00; + let typd = typd_val; + let typ1 = typ1_val; + let typ2 = typ2_val; + let sp = 0; + let sps = 0b00000; +} + +// mqmaus.b.mm - uint8 × int8 → int32 +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MQMAUS_B_MM : AME_MQMAUS_MM<0b010, 0b000, 0b000, "mqmaus.b.mm">; + +//===----------------------------------------------------------------------===// +// AME Zero/Initialize Instructions +//===----------------------------------------------------------------------===// +// mzero - Zero out an accumulation register +// Useful for initializing before accumulation loop + +class AME_MZERO + : RVInst64<(outs AccReg:$md), (ins), + opcodestr, "$md", [], InstFormatOther> { + bits<5> md; + + let Inst{6-0} = 0b0111111; + let Inst{11-7} = md; + let Inst{14-12} = 0b101; // funct3 for arithmetic + let Inst{19-15} = 0b00000; + let Inst{24-20} = 0b00000; + let Inst{25} = 0; + let Inst{31-26} = 0b000000; + let Inst{38-32} = 0b0000011; + let Inst{43-39} = 0b10000; // funct5 for zero + let Inst{63-44} = 0; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MZERO_M : AME_MZERO<"mzero.m">; + +//===----------------------------------------------------------------------===// +// AME Intrinsic Pattern Matching +//===----------------------------------------------------------------------===// +// These patterns map LLVM intrinsics to AME machine instructions. +// +// Note: AME intrinsics use i64 indices for tile registers instead of +// actual register operands. This is because the MLIR lowering generates +// calls with constant indices that get mapped to physical registers +// at the final code generation stage. +// +// For tile-based operations, the tile register index (0-7) is encoded +// directly into the instruction's register field. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Configuration Instruction Patterns +//===----------------------------------------------------------------------===// + +// msettilem - set tile M dimension from GPR +let Predicates = [HasBuddyExt] in { + def : Pat<(i64 (int_riscv_buddy_msettilem i64:$rs1)), + (AME_MSETTILEM GPR:$rs1)>; +} + +// msettilen - set tile N dimension from GPR +let Predicates = [HasBuddyExt] in { + def : Pat<(i64 (int_riscv_buddy_msettilen i64:$rs1)), + (AME_MSETTILEN GPR:$rs1)>; +} + +// msettilek - set tile K dimension from GPR +let Predicates = [HasBuddyExt] in { + def : Pat<(i64 (int_riscv_buddy_msettilek i64:$rs1)), + (AME_MSETTILEK GPR:$rs1)>; +} + +//===----------------------------------------------------------------------===// +// Pseudo Instructions for Index-Based Operations +//===----------------------------------------------------------------------===// +// These pseudo instructions accept i64 indices and have AsmString for direct +// assembly output. This allows the pseudo instructions to be printed directly +// without needing complex expansion logic. +// +// For load/store/mma instructions, the AsmString hardcodes "acc" and "tr" +// prefixes so that indices 0-7 are printed as acc0-acc7 and tr0-tr7. +//===----------------------------------------------------------------------===// + +// Pseudo instruction for msettilemi with i64 immediate +// Output: msettilemi x0, +let Predicates = [HasBuddyExt], hasSideEffects = 1, mayLoad = 0, mayStore = 0, + isCodeGenOnly = 1 in +def AME_MSETTILEMI_PSEUDO : Pseudo<(outs), (ins i64imm:$imm), []> { + let AsmString = "msettilemi\tx0, $imm"; +} + +// Pseudo instruction for msettileni with i64 immediate +let Predicates = [HasBuddyExt], hasSideEffects = 1, mayLoad = 0, mayStore = 0, + isCodeGenOnly = 1 in +def AME_MSETTILENI_PSEUDO : Pseudo<(outs), (ins i64imm:$imm), []> { + let AsmString = "msettileni\tx0, $imm"; +} + +// Pseudo instruction for msettileki with i64 immediate +let Predicates = [HasBuddyExt], hasSideEffects = 1, mayLoad = 0, mayStore = 0, + isCodeGenOnly = 1 in +def AME_MSETTILEKI_PSEUDO : Pseudo<(outs), (ins i64imm:$imm), []> { + let AsmString = "msettileki\tx0, $imm"; +} + +// Pseudo instruction for mzero with accumulator index +let Predicates = [HasBuddyExt], hasSideEffects = 1, mayLoad = 0, mayStore = 0, + isCodeGenOnly = 1 in +def AME_MZERO_PSEUDO : Pseudo<(outs), (ins AMEAccIndex:$md), []> { + let AsmString = "mzero.m\tacc$md"; +} + +// Pseudo instruction for mlae32.m with tile index +let Predicates = [HasBuddyExt], hasSideEffects = 1, mayLoad = 1, mayStore = 0, + isCodeGenOnly = 1 in +def AME_MLAE32_M_PSEUDO : Pseudo<(outs), (ins AMETileIndex:$md, GPR:$rs1, GPR:$rs2), []> { + let AsmString = "mlae32.m\ttr$md, ($rs1), $rs2"; +} + +// Pseudo instruction for mlbe32.m with tile index +let Predicates = [HasBuddyExt], hasSideEffects = 1, mayLoad = 1, mayStore = 0, + isCodeGenOnly = 1 in +def AME_MLBE32_M_PSEUDO : Pseudo<(outs), (ins AMETileIndex:$md, GPR:$rs1, GPR:$rs2), []> { + let AsmString = "mlbe32.m\ttr$md, ($rs1), $rs2"; +} + +// Pseudo instruction for msce32.m with accumulator index +let Predicates = [HasBuddyExt], hasSideEffects = 1, mayLoad = 0, mayStore = 1, + isCodeGenOnly = 1 in +def AME_MSCE32_M_PSEUDO : Pseudo<(outs), (ins AMEAccIndex:$ms3, GPR:$rs1, GPR:$rs2), []> { + let AsmString = "msce32.m\tacc$ms3, ($rs1), $rs2"; +} + +// Pseudo instruction for mma.w.mm.tile with indices +// md: AccReg index (0-7), ms1/ms2: TileReg indices (0-7) +let Predicates = [HasBuddyExt], hasSideEffects = 1, mayLoad = 1, mayStore = 1, + isCodeGenOnly = 1 in +def AME_MMA_W_MM_TILE_PSEUDO : Pseudo<(outs), + (ins AMEAccIndex:$md, AMETileIndex:$ms1, AMETileIndex:$ms2), []> { + let AsmString = "mma.w.mm\tacc$md, tr$ms1, tr$ms2"; +} + +//===----------------------------------------------------------------------===// +// Pattern Matching for Immediate Configuration Instructions +//===----------------------------------------------------------------------===// + +// msettilemi - set tile M dimension with immediate +let Predicates = [HasBuddyExt] in { + def : Pat<(int_riscv_buddy_msettilemi timm:$imm), + (AME_MSETTILEMI_PSEUDO timm:$imm)>; +} + +// msettileni - set tile N dimension with immediate +let Predicates = [HasBuddyExt] in { + def : Pat<(int_riscv_buddy_msettileni timm:$imm), + (AME_MSETTILENI_PSEUDO timm:$imm)>; +} + +// msettileki - set tile K dimension with immediate +let Predicates = [HasBuddyExt] in { + def : Pat<(int_riscv_buddy_msettileki timm:$imm), + (AME_MSETTILEKI_PSEUDO timm:$imm)>; +} + +//===----------------------------------------------------------------------===// +// Pattern Matching for Zero Instruction +//===----------------------------------------------------------------------===// + +let Predicates = [HasBuddyExt] in { + def : Pat<(int_riscv_buddy_mzero timm:$md), + (AME_MZERO_PSEUDO timm:$md)>; +} + +//===----------------------------------------------------------------------===// +// Pattern Matching for Load Instructions +//===----------------------------------------------------------------------===// + +// mlae32.m - Load 32-bit left matrix A to tile register +let Predicates = [HasBuddyExt] in { + def : Pat<(int_riscv_buddy_mlae32_m timm:$md, iPTR:$rs1, i64:$rs2), + (AME_MLAE32_M_PSEUDO timm:$md, GPR:$rs1, GPR:$rs2)>; +} + +// mlbe32.m - Load 32-bit right matrix B to tile register +let Predicates = [HasBuddyExt] in { + def : Pat<(int_riscv_buddy_mlbe32_m timm:$md, iPTR:$rs1, i64:$rs2), + (AME_MLBE32_M_PSEUDO timm:$md, GPR:$rs1, GPR:$rs2)>; +} + +//===----------------------------------------------------------------------===// +// Pattern Matching for Store Instructions +//===----------------------------------------------------------------------===// + +// msce32.m - Store 32-bit output matrix C from accumulator +let Predicates = [HasBuddyExt] in { + def : Pat<(int_riscv_buddy_msce32_m timm:$ms3, iPTR:$rs1, i64:$rs2), + (AME_MSCE32_M_PSEUDO timm:$ms3, GPR:$rs1, GPR:$rs2)>; +} + +//===----------------------------------------------------------------------===// +// Pattern Matching for Tile-Based Matrix Multiply +//===----------------------------------------------------------------------===// + +// mma.w.mm.tile - int32 tile matrix multiply: md = md + ms1 x ms2 +let Predicates = [HasBuddyExt] in { + def : Pat<(int_riscv_buddy_mma_w_mm_tile timm:$md, timm:$ms1, timm:$ms2), + (AME_MMA_W_MM_TILE_PSEUDO timm:$md, timm:$ms1, timm:$ms2)>; +} + +//===----------------------------------------------------------------------===// +// AME Summary +//===----------------------------------------------------------------------===// +// Complete instruction set for basic matrix operations: +// +// Configuration: +// - msettilem, msettilemi: Set M dimension +// - msettilen, msettileni: Set N dimension +// - msettilek, msettileki: Set K dimension +// +// Load (Memory → Register): +// - mlae{8|16|32|64}.m: Load A into TileReg +// - mlbe{8|16|32|64}.m: Load B into TileReg +// - mlce{8|16|32|64}.m: Load C into AccReg +// +// Compute (TileReg × TileReg → AccReg): +// No-widen: +// - mma.{h|w|dw}.mm: Signed int16/32/64 +// Double-widen (AMUL ≥ 2): +// - mwma.{h|w}.mm: int16→int32, int32→int64 +// Quad-widen (AMUL ≥ 4, mmi8i32 MANDATORY): +// - mqma.b.mm: int8 → int32 (signed) +// - mqmau.b.mm: uint8 → uint32 (unsigned) +// - mqmasu.b.mm: int8 × uint8 → int32 +// - mqmaus.b.mm: uint8 × int8 → int32 +// +// Store (Register → Memory): +// - msae{8|16|32|64}.m: Store TileReg A +// - msbe{8|16|32|64}.m: Store TileReg B +// - msce{8|16|32|64}.m: Store AccReg C +// +// Utility: +// - mzero.m: Zero out AccReg +//===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h index 907f470c0d248..d0c123c3bc600 100644 --- a/mlir/include/mlir/IR/Visitors.h +++ b/mlir/include/mlir/IR/Visitors.h @@ -120,9 +120,10 @@ void walk(Operation *op, function_ref callback, WalkOrder order) { for (auto ®ion : Iterator::makeIterable(*op)) { // Early increment here in the case where the block is erased. - // PostOrderTraversal keeps state outside of iterators, so store it here. - auto &&It = Iterator::makeIterable(region); - for (auto &block : llvm::make_early_inc_range(It)) { + // Store the block range to ensure the iteratable (e.g., + // PostOrderTraversal) outlives the iterators of make_early_inc_range. + auto &&blockRange = Iterator::makeIterable(region); + for (auto &block : llvm::make_early_inc_range(blockRange)) { if (order == WalkOrder::PreOrder) callback(&block); for (auto &nestedOp : Iterator::makeIterable(block)) @@ -196,9 +197,10 @@ WalkResult walk(Operation *op, function_ref callback, WalkOrder order) { for (auto ®ion : Iterator::makeIterable(*op)) { // Early increment here in the case where the block is erased. - // PostOrderTraversal keeps state outside of iterators, so store it here. - auto &&It = Iterator::makeIterable(region); - for (auto &block : llvm::make_early_inc_range(It)) { + // Store the block range to ensure the iteratable (e.g., + // PostOrderTraversal) outlives the iterators of make_early_inc_range. + auto &&blockRange = Iterator::makeIterable(region); + for (auto &block : llvm::make_early_inc_range(blockRange)) { if (order == WalkOrder::PreOrder) { WalkResult result = callback(&block); if (result.wasSkipped()) diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 43a6816c6d863..ae0982830e138 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -48,6 +48,9 @@ bool shape::isExtentTensorType(Type type) { LogicalResult shape::getShapeVec(Value input, SmallVectorImpl &shapeValues) { + // Look through tensor.cast operations to find the underlying shape. + if (auto castOp = input.getDefiningOp()) + return getShapeVec(castOp.getSource(), shapeValues); if (auto inputOp = input.getDefiningOp()) { auto type = llvm::cast(inputOp.getArg().getType()); if (!type.hasRank()) @@ -799,27 +802,27 @@ struct CanonicalizeCastExtentTensorOperandsPattern LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - // Canonicalize operands. + // Canonicalize operands by stripping information-losing tensor.cast ops. + SmallVector newOperands; bool anyChange = false; - auto canonicalizeOperand = [&](Value operand) -> Value { + for (Value operand : op.getShapes()) { if (auto castOp = operand.getDefiningOp()) { // Only eliminate the cast if it holds no shape information. - bool isInformationLoosingCast = - llvm::cast(castOp.getType()).isDynamicDim(0); - if (isInformationLoosingCast) { + if (llvm::cast(castOp.getType()).isDynamicDim(0)) { anyChange = true; - return castOp.getSource(); + newOperands.push_back(castOp.getSource()); + continue; } } - return operand; - }; - auto newOperands = - llvm::map_to_vector<8>(op.getOperands(), canonicalizeOperand); + newOperands.push_back(operand); + } // Rewrite op if any change required. if (!anyChange) return failure(); - rewriter.replaceOpWithNewOp(op, op->getResultTypes(), newOperands); + rewriter.modifyOpInPlace(op, [&]() { + op.getShapesMutable().assign(newOperands); + }); return success(); } }; @@ -1017,6 +1020,22 @@ OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) { }()) return BoolAttr::get(getContext(), true); + // No broadcasting is needed if all operands but one are scalar, using + // getShapeVec to look through tensor.cast and shape_of ops. + if ([&] { + bool nonScalarSeen = false; + for (auto shapeValue : getShapes()) { + SmallVector extents; + if (failed(getShapeVec(shapeValue, extents)) || !extents.empty()) { + if (nonScalarSeen) + return false; + nonScalarSeen = true; + } + } + return true; + }()) + return BoolAttr::get(getContext(), true); + // Because a failing witness result here represents an eventual assertion // failure, we do not replace it with a constant witness. return nullptr; diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp index 9b16e09124aa3..f9fa904a30ee8 100644 --- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp @@ -10,19 +10,21 @@ // JIT engine. // //===----------------------------------------------------------------------===// -#include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Support/FileUtilities.h" #include "mlir/Target/LLVMIR/Export.h" #include "llvm/ExecutionEngine/JITEventListener.h" +#include "llvm/ExecutionEngine/JITLink/JITLinkMemoryManager.h" #include "llvm/ExecutionEngine/ObjectCache.h" #include "llvm/ExecutionEngine/Orc/CompileUtils.h" #include "llvm/ExecutionEngine/Orc/ExecutionUtils.h" #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" #include "llvm/ExecutionEngine/Orc/IRTransformLayer.h" #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h" #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" #include "llvm/IR/IRBuilder.h" #include "llvm/MC/TargetRegistry.h" @@ -314,27 +316,57 @@ ExecutionEngine::create(Operation *m, const ExecutionEngineOptions &options, auto objectLinkingLayerCreator = [&](ExecutionSession &session) { // Needed to respect AArch64 ABI requirements on the distance between // TEXT and GOT sections. - bool reserveAlloc = llvmModule->getTargetTriple().isAArch64(); - auto objectLayer = std::make_unique( - session, [sectionMemoryMapper = options.sectionMemoryMapper, - reserveAlloc](const MemoryBuffer &) { - return std::make_unique(sectionMemoryMapper, - reserveAlloc); - }); - - // Register JIT event listeners if they are enabled. - if (engine->gdbListener) - objectLayer->registerJITEventListener(*engine->gdbListener); - if (engine->perfListener) - objectLayer->registerJITEventListener(*engine->perfListener); + + // Check if we should use ObjectLinkingLayer (JITLink) + // JITLink supports modern architectures like RISC-V, AArch64 + // RuntimeDyld is older and provides better compatibility with legacy + // platforms + + // Decide which layer to use + bool useJITLink = llvmModule->getTargetTriple().isAArch64() || + llvmModule->getTargetTriple().isRISCV(); + + std::unique_ptr objectLayer; + + if (useJITLink) { + // JITLink path + objectLayer = std::make_unique(session); + + LLVM_DEBUG(llvm::dbgs() << "Using ObjectLinkingLayer (JITLink)\n"); + + } else { + // RuntimeDyld path + auto rtDyldLayer = std::make_unique( + session, + [sectionMemoryMapper = + options.sectionMemoryMapper](const llvm::MemoryBuffer &) + -> std::unique_ptr { + return std::make_unique(sectionMemoryMapper); + }); + + // Only RTDyld supports listener + if (engine->gdbListener) + rtDyldLayer->registerJITEventListener(*engine->gdbListener); + + if (engine->perfListener) + rtDyldLayer->registerJITEventListener(*engine->perfListener); + + LLVM_DEBUG(llvm::dbgs() << "Using RTDyldObjectLinkingLayer\n"); + + // Upcast + objectLayer = std::move(rtDyldLayer); + } // COFF format binaries (Windows) need special handling to deal with // exported symbol visibility. // cf llvm/lib/ExecutionEngine/Orc/LLJIT.cpp LLJIT::createObjectLinkingLayer const llvm::Triple &targetTriple = llvmModule->getTargetTriple(); - if (targetTriple.isOSBinFormatCOFF()) { - objectLayer->setOverrideObjectFlagsWithResponsibilityFlags(true); - objectLayer->setAutoClaimResponsibilityForObjectSymbols(true); + if (!useJITLink && targetTriple.isOSBinFormatCOFF()) { + if (auto *rtDyldLayer = dyn_cast( + objectLayer.get())) { + rtDyldLayer->setOverrideObjectFlagsWithResponsibilityFlags(true); + rtDyldLayer->setAutoClaimResponsibilityForObjectSymbols(true); + } } // Resolve symbols from shared libraries. diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp index db0516533afcb..4b52669df9a6e 100644 --- a/mlir/lib/ExecutionEngine/JitRunner.cpp +++ b/mlir/lib/ExecutionEngine/JitRunner.cpp @@ -31,6 +31,7 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/LegacyPassNameParser.h" +#include "llvm/Support/CodeGen.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FileUtilities.h" @@ -361,6 +362,11 @@ int mlir::JitRunnerMain(int argc, char **argv, const DialectRegistry ®istry, tmBuilderOrError->getTargetTriple().setArchName(options.mArch); } + if (tmBuilderOrError->getTargetTriple().isRISCV()){ + tmBuilderOrError->setRelocationModel(llvm::Reloc::PIC_); + tmBuilderOrError->setCodeModel(llvm::CodeModel::Medium); + } + // Build TargetMachine auto tmOrError = tmBuilderOrError->createTargetMachine(); diff --git a/mlir/test/mlir-runner/test-expand-math-approx.mlir b/mlir/test/mlir-runner/test-expand-math-approx.mlir index 06b3171a2349e..5ba0f7106e101 100644 --- a/mlir/test/mlir-runner/test-expand-math-approx.mlir +++ b/mlir/test/mlir-runner/test-expand-math-approx.mlir @@ -233,12 +233,12 @@ func.func @powf() { %g_p = arith.constant 23598.0 : f64 call @func_powff64(%g, %g_p) : (f64, f64) -> () - // CHECK-NEXT: -nan + // CHECK-NEXT: {{-?}}nan %h = arith.constant 1.0 : f64 %h_p = arith.constant 0xfff0000001000000 : f64 call @func_powff64(%h, %h_p) : (f64, f64) -> () - // CHECK-NEXT: -nan + // CHECK-NEXT: {{-?}}nan %i = arith.constant 1.0 : f32 %i_p = arith.constant 0xffffffff : f32 call @func_powff32(%i, %i_p) : (f32, f32) -> () diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py index 8eff573f98ad3..665879acb5cc9 100644 --- a/mlir/test/python/integration/dialects/linalg/opsrun.py +++ b/mlir/test/python/integration/dialects/linalg/opsrun.py @@ -19,6 +19,11 @@ def log(*args): sys.stderr.flush() +def run(f): + f() + return f + + fill_boiler = """ func.func @main() -> i32 attributes {llvm.emit_c_interface} { %O0 = memref.alloc() : memref @@ -177,7 +182,7 @@ def fill_2d_on_buffers(value, out): # CHECK: RESULT: 6 -test_fill_builtin() +run(test_fill_builtin) def test_fill_generic(): @@ -211,7 +216,7 @@ def fill_2d_on_buffers(value, out): # CHECK: RESULT: 6 -test_fill_generic() +run(test_fill_generic) def test_fill_rng_builtin(): @@ -238,7 +243,7 @@ def fill_rng_on_buffers(min, max, seed, out): # CHECK: RESULT: -480 -test_fill_rng_builtin() +run(test_fill_rng_builtin) def test_fill_rng_generic(): @@ -265,7 +270,7 @@ def fill_rng_on_buffers(min, max, seed, out): # CHECK: RESULT: -480 -test_fill_rng_generic() +run(test_fill_rng_generic) def test_max_pooling_builtin(): @@ -299,7 +304,7 @@ def pooling_on_buffers(input, shape, output): # CHECK: RESULT: 42 -test_max_pooling_builtin() +run(test_max_pooling_builtin) def test_max_pooling_generic(): @@ -338,7 +343,7 @@ def pooling_on_buffers(input, shape, output): # CHECK: RESULT: 42 -test_max_pooling_generic() +run(test_max_pooling_generic) def test_min_pooling_builtin(): @@ -370,7 +375,7 @@ def pooling_on_buffers(input, shape, output): # CHECK: RESULT: -13 -test_min_pooling_builtin() +run(test_min_pooling_builtin) def test_min_pooling_generic(): @@ -404,4 +409,4 @@ def pooling_on_buffers(input, shape, output): # CHECK: RESULT: -13 -test_min_pooling_generic() +run(test_min_pooling_generic)