diff --git a/build_tools/llvm_version.txt b/build_tools/llvm_version.txt index c93d17e96..e1bb882e1 100644 --- a/build_tools/llvm_version.txt +++ b/build_tools/llvm_version.txt @@ -1 +1 @@ -19b28074618c92fa4c4281eeee67c715abebbfd7 +50351218b3a4687079fe79f932aac0e00d5d990f diff --git a/build_tools/patches/0001-Add-support-for-VectorAnyINTEL-capability.patch b/build_tools/patches/0001-Add-support-for-VectorAnyINTEL-capability.patch index 1102c888d..3b5d4eb98 100644 --- a/build_tools/patches/0001-Add-support-for-VectorAnyINTEL-capability.patch +++ b/build_tools/patches/0001-Add-support-for-VectorAnyINTEL-capability.patch @@ -1,7 +1,7 @@ -From 887ae8599b205921bc4fd34da3c1de767f5568ae Mon Sep 17 00:00:00 2001 +From 1dd1130e8647473448ca9ce808e7bc6efa424675 Mon Sep 17 00:00:00 2001 From: Garra1980 -Date: Tue, 23 Sep 2025 21:22:18 +0200 -Subject: [PATCH] Add support for VectorAnyINTEL capability +Date: Mon, 5 Jan 2026 17:46:15 +0100 +Subject: [PATCH] Add-support-for-VectorAnyINTEL-capability --- .../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 11 +- @@ -24,7 +24,7 @@ Subject: [PATCH] Add support for VectorAnyINTEL capability 17 files changed, 324 insertions(+), 68 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td -index 0e42d08cdb1f..f821b0d2e59b 100644 +index ecbbf39a534e..d6a72472bd1b 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4240,7 +4240,14 @@ def SPIRV_BFloat16KHR : TypeAlias; @@ -53,10 +53,10 @@ index 0e42d08cdb1f..f821b0d2e59b 100644 class SPIRV_ScalarOrVectorOf : AnyTypeOf<[type, SPIRV_VectorOf]>; diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td -index 6b4e3dd60319..987b33c055e9 100644 +index 0fb4837e528b..a33b18e8c868 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td -@@ -654,6 +654,92 @@ class 
ScalableVectorOfRankAndLengthAndType allowedRanks, +@@ -696,6 +696,92 @@ class ScalableVectorOfRankAndLengthAndType allowedRanks, ScalableVectorOfLength.summary, "::mlir::VectorType">; @@ -150,7 +150,7 @@ index 6b4e3dd60319..987b33c055e9 100644 // Negative values for `n` index in reverse. class ShapedTypeWithNthDimOfSize allowedSizes> : Type< diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp -index c8efdf009422..5236dc299f81 100644 +index 22b57d6c0821..8f02dd856d4e 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -186,9 +186,12 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect, @@ -169,7 +169,7 @@ index c8efdf009422..5236dc299f81 100644 return Type(); } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp -index 7e9a80e7d73a..1db6233cf73f 100644 +index 53a48abe5ad0..4c39a7c83281 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -186,9 +186,10 @@ bool CompositeType::classof(Type type) { @@ -178,15 +178,15 @@ index 7e9a80e7d73a..1db6233cf73f 100644 bool CompositeType::isValid(VectorType type) { - return type.getRank() == 1 && - llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) && -- llvm::isa(type.getElementType()); +- isa(type.getElementType()); + // Number of elements should be between [2 to 2^32 - 1]. 
-+ return type.getRank() == 1 && mlir::isa(type.getElementType()) && ++ return type.getRank() == 1 && isa(type.getElementType()) && + type.getNumElements() >= 2 && + type.getNumElements() <= std::numeric_limits::max(); } Type CompositeType::getElementType(unsigned index) const { -@@ -221,8 +222,23 @@ void TypeCapabilityVisitor::addConcrete(VectorType type) { +@@ -218,8 +219,23 @@ void TypeCapabilityVisitor::addConcrete(VectorType type) { int64_t vecSize = type.getNumElements(); if (vecSize == 8 || vecSize == 16) { @@ -213,7 +213,7 @@ index 7e9a80e7d73a..1db6233cf73f 100644 } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp -index 122f61e0a66a..c6f37e9345ed 100644 +index 816226749463..31c590efcda3 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -84,9 +84,13 @@ static std::optional> getTargetShape(VectorType vecType) { @@ -385,7 +385,7 @@ index 122f61e0a66a..c6f37e9345ed 100644 } static Type -@@ -1693,16 +1780,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) { +@@ -1694,16 +1781,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) { SmallVector, 4> typeExtensions; SmallVector, 8> typeCapabilities; for (Type valueType : valueTypes) { diff --git a/build_tools/patches/0006-Add-SPIRV_ExecutionModeAttributesAttr.patch b/build_tools/patches/0006-Add-SPIRV_ExecutionModeAttributesAttr.patch index 2625040d9..025296757 100644 --- a/build_tools/patches/0006-Add-SPIRV_ExecutionModeAttributesAttr.patch +++ b/build_tools/patches/0006-Add-SPIRV_ExecutionModeAttributesAttr.patch @@ -1,9 +1,8 @@ -From a18280b49a33d421477db322d642aff187b029f8 Mon Sep 17 00:00:00 2001 -From: Dimple Prajapati -Date: Tue, 7 May 2024 23:26:34 +0000 -Subject: [PATCH] Add SPIRV_ExecutionModeAttributesAttr +From ab51a97f5358afe72bfce6483a3416a50dc40018 Mon Sep 17 00:00:00 2001 +From: Garra1980 +Date: Mon, 5 Jan 2026 17:51:39 +0100 
+Subject: [PATCH] Add-SPIRV_ExecutionModeAttributesAttr -add spirv.ExecutionMode Op during GPUToSPIRV Pass lowering --- .../mlir/Dialect/SPIRV/IR/SPIRVAttributes.td | 11 +++++++++++ .../mlir/Dialect/SPIRV/IR/TargetAndABI.h | 8 ++++++++ @@ -13,7 +12,7 @@ add spirv.ExecutionMode Op during GPUToSPIRV Pass lowering 5 files changed, 55 insertions(+) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td -index 3a11284da051..1267ecd251ae 100644 +index 1bc3c63646fd..cf5b5ffa451d 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td @@ -56,6 +56,17 @@ def SPIRV_LinkageAttributesAttr : SPIRV_Attr<"LinkageAttributes", "linkage_attri @@ -54,10 +53,10 @@ index 24574bfaf619..f64f99294038 100644 } // namespace mlir diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp -index d7885e035959..5195035f088f 100644 +index c33a903d0339..a6578465abac 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp -@@ -343,6 +343,20 @@ LogicalResult GPUFuncOpConversion::matchAndRewrite( +@@ -346,6 +346,20 @@ LogicalResult GPUFuncOpConversion::matchAndRewrite( return failure(); newFuncOp->removeAttr( rewriter.getStringAttr(gpu::GPUDialect::getKernelFuncAttrName())); @@ -79,25 +78,25 @@ index d7885e035959..5195035f088f 100644 } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp -index 65aaafa55386..b9f906ada3ee 100644 +index 8f02dd856d4e..dfb13c173596 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp -@@ -948,6 +948,10 @@ LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op, +@@ -1026,6 +1026,10 @@ LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op, } else if (symbol == spirv::getTargetEnvAttrName()) { - if 
(!llvm::isa(attr)) + if (!isa(attr)) return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr"; + } else if (symbol == spirv::getExecutionModeFuncAttrName()) { -+ if (!llvm::isa(attr)) ++ if (!isa(attr)) + return op->emitError("'") + << symbol << "' must be a spirv::ExecutionModeFuncAttributeAttr"; } else { return op->emitError("found unsupported '") << symbol << "' attribute on operation"; diff --git a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp -index bbc318e17300..56251a8f3990 100644 +index 8c52ba8b8583..864fc75c53b1 100644 --- a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp -@@ -242,3 +242,21 @@ spirv::getMemoryModel(spirv::TargetEnvAttr targetAttr) { +@@ -238,3 +238,21 @@ spirv::getMemoryModel(spirv::TargetEnvAttr targetAttr) { } return failure(); } diff --git a/build_tools/patches/0014-Add-32-64-and-128-length-vector-as-supported-vectors.patch b/build_tools/patches/0014-Add-32-64-and-128-length-vector-as-supported-vectors.patch index 19316b6a7..6f3bfa59d 100644 --- a/build_tools/patches/0014-Add-32-64-and-128-length-vector-as-supported-vectors.patch +++ b/build_tools/patches/0014-Add-32-64-and-128-length-vector-as-supported-vectors.patch @@ -1,18 +1,14 @@ -From 27bd5d19a0f122d9acded716fe936d07416e1308 Mon Sep 17 00:00:00 2001 -From: "Shahneous Bari, Md Abdullah" -Date: Wed, 17 Dec 2025 14:21:48 +0000 -Subject: [PATCH] Add 32, 64 and 128 length vector as supported vectors. +From 0f49dfa51b379d1197b290e810ce2d3d178b5dde Mon Sep 17 00:00:00 2001 +From: Garra1980 +Date: Mon, 5 Jan 2026 15:48:16 +0100 +Subject: [PATCH] Add-32-64-and-128-length-vector-as-supported-vectors -This is needed to support large loads/stores using OpenCL intrinsics in -MLIR workflow. -THIS IS A HACK and temporary solution, -need to re-visit this with better solution later. 
--- - llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp | 81 ++++++++++++++------ - 1 file changed, 59 insertions(+), 22 deletions(-) + llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp | 82 +++++++++++++++----- + 1 file changed, 61 insertions(+), 21 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp -index 30703ee40be0..2f3baa1b6c7e 100644 +index 590182731b00..0d9f4df374f9 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp @@ -50,6 +50,28 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { @@ -44,20 +40,15 @@ index 30703ee40be0..2f3baa1b6c7e 100644 const LLT v16s64 = LLT::fixed_vector(16, 64); const LLT v16s32 = LLT::fixed_vector(16, 32); const LLT v16s16 = LLT::fixed_vector(16, 16); -@@ -99,41 +121,53 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { +@@ -100,16 +122,22 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { // TODO: remove copy-pasting here by using concatenation in some way. 
auto allPtrsScalarsAndVectors = { -- p0, p1, p2, p3, p4, p5, p6, p7, p8, -- p9, p10, p11, p12, s1, s8, s16, s32, s64, -- v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, v3s32, -- v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8, v8s16, -- v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64}; -- -- auto allVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, -- v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, -- v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, -- v16s8, v16s16, v16s32, v16s64}; +- p0, p1, p2, p3, p4, p5, p6, p7, p8, +- p9, p10, p11, p12, p13, s1, s8, s16, s32, +- s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, +- v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8, +- v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64}; + p0, p1, p2, p3, p4, p5, p6, p7, p8, + p9, p10, p11, p12, s1, s8, s16, s32, s64, + s128, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, @@ -74,9 +65,15 @@ index 30703ee40be0..2f3baa1b6c7e 100644 + v32s32, v32s64, v64s1, v64s8, v64s16, v64s32, v64s64, + v128s1, v128s8, v128s16, v128s32, v128s64}; +- auto allVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, +- v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, +- v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, +- v16s8, v16s16, v16s32, v16s64}; + auto allShaderVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, v3s32, v3s64, - v4s1, v4s8, v4s16, v4s32, v4s64}; +@@ -118,25 +146,32 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { + auto allScalars = {s1, s8, s16, s32, s64}; auto allScalarsAndVectors = { - s1, s8, s16, s32, s64, s128, v2s1, v2s8, @@ -116,9 +113,9 @@ index 30703ee40be0..2f3baa1b6c7e 100644 + v32s1, v32s8, v32s16, v32s32, v32s64, v64s1, v64s8, v64s16, v64s32, + v64s64, v128s1, v128s8, v128s16, v128s32, v128s64}; - auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1, p2, p3, p4, - p5, p6, p7, p8, p9, p10, p11, p12}; -@@ -170,7 +204,9 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) 
{ + auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1, + p2, p3, p4, p5, p6, p7, +@@ -174,7 +209,9 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { // shader execution models, vector sizes are strictly limited to 4. In // non-shader contexts, vector sizes of 8 and 16 are also permitted, but // arbitrary sizes (e.g., 6 or 11) are not. @@ -126,18 +123,20 @@ index 30703ee40be0..2f3baa1b6c7e 100644 + + // @IMEX, make the max vector size to be 128 for now. + uint32_t MaxVectorSize = ST.isShader() ? 4 : 128; + LLVM_DEBUG(dbgs() << "MaxVectorSize: " << MaxVectorSize << "\n"); for (auto Opc : getTypeFoldingSupportedOpcodes()) { - if (Opc != G_EXTRACT_VECTOR_ELT) -@@ -531,7 +567,8 @@ bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper, - LLT DstTy = MRI.getType(DstReg); - LLT SrcTy = MRI.getType(SrcReg); - -- int32_t MaxVectorSize = ST.isShader() ? 4 : 16; -+ // @IMEX make the max vector size to be 128 -+ int32_t MaxVectorSize = ST.isShader() ? 4 : 128; - - bool DstNeedsLegalization = false; - bool SrcNeedsLegalization = false; +@@ -579,7 +616,10 @@ static bool needsVectorLegalization(const LLT &Ty, const SPIRVSubtarget &ST) { + if (!Ty.isVector()) + return false; + unsigned NumElements = Ty.getNumElements(); +- unsigned MaxVectorSize = ST.isShader() ? 4 : 16; ++ ++ // @IMEX make the max vector size to be 128 ++ int32_t MaxVectorSize = ST.isShader() ? 
4 : 128; ++ + return (NumElements > 4 && !isPowerOf2_32(NumElements)) || + NumElements > MaxVectorSize; + } -- -2.43.0 +2.34.1 diff --git a/build_tools/patches/relaxing_xegpu-propagation.patch b/build_tools/patches/relaxing_xegpu-propagation.patch deleted file mode 100644 index 16fac6c32..000000000 --- a/build_tools/patches/relaxing_xegpu-propagation.patch +++ /dev/null @@ -1,33 +0,0 @@ -From 8f85a84fdad1c1ee6c39a3b04e74e0d8513a6abf Mon Sep 17 00:00:00 2001 -From: Garra1980 -Date: Tue, 25 Nov 2025 22:45:25 +0100 -Subject: [PATCH] relaxing_xegpu propagation - ---- - mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp | 4 ---- - 1 file changed, 4 deletions(-) - -diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp -index 6b3ba5a5981c..ecd6a3358f4a 100644 ---- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp -+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp -@@ -127,8 +127,6 @@ public: - SmallVector LayoutInfo::getLaneLayout() const { - if (!isAssigned()) - return {}; -- assert(storage.getEffectiveLaneLayoutAsInt().size() && -- "Expected lane layout to be assigned"); - return llvm::map_to_vector(storage.getEffectiveLaneLayoutAsInt(), - [](int64_t val) { return static_cast(val); }); - } -@@ -136,8 +134,6 @@ SmallVector LayoutInfo::getLaneLayout() const { - SmallVector LayoutInfo::getLaneData() const { - if (!isAssigned()) - return {}; -- assert(storage.getEffectiveLaneDataAsInt().size() && -- "Expected lane data to be assigned"); - return llvm::map_to_vector(storage.getEffectiveLaneDataAsInt(), - [](int64_t val) { return static_cast(val); }); - } --- -2.34.1 diff --git a/build_tools/patches/wg_fa_support.patch b/build_tools/patches/wg_fa_support.patch index c703fc0e7..39508836a 100644 --- a/build_tools/patches/wg_fa_support.patch +++ b/build_tools/patches/wg_fa_support.patch @@ -1,5 +1,16 @@ +From 4342fab14630d1ee774ec78dea66fd78f200a3ad Mon Sep 17 
00:00:00 2001 +From: Garra1980 +Date: Mon, 22 Dec 2025 16:40:20 +0100 +Subject: [PATCH] wg_fa_support + +--- + .../XeGPU/Transforms/XeGPUBlocking.cpp | 6 +- + .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 59 +++++++++++++++++++ + .../Transforms/XeGPUWgToSgDistribute.cpp | 12 ++-- + 3 files changed, 68 insertions(+), 9 deletions(-) + diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp -index ec5feb8bc8c4..c8b9f0eb6a06 100644 +index ba2753f517ce..dc4f05c0d914 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp @@ -228,7 +228,7 @@ XeGPUBlockingPass::getTileShape(Operation *op) const { @@ -11,10 +22,10 @@ index ec5feb8bc8c4..c8b9f0eb6a06 100644 return getTileShape(op->getOpResult(0)); return std::nullopt; -@@ -413,14 +413,14 @@ void XeGPUBlockingPass::runOnOperation() { +@@ -415,14 +415,14 @@ void XeGPUBlockingPass::runOnOperation() { // Remove the layout attributes cached per operands. for (OpOperand &opr : op->getOpOperands()) { - std::string name = xegpu::getLayoutName(opr); + std::string name = xegpu::getTemporaryLayoutName(opr); - if (op->hasAttrOfType(name)) + if (op->hasAttrOfType(name)) op->removeAttr(name); @@ -22,17 +33,17 @@ index ec5feb8bc8c4..c8b9f0eb6a06 100644 // Update the layout attributes per result. 
for (OpResult result : op->getOpResults()) { - std::string name = xegpu::getLayoutName(result); + std::string name = xegpu::getTemporaryLayoutName(result); - if (auto layout = op->getAttrOfType(name)) { + if (auto layout = op->getAttrOfType(name)) { op->removeAttr(name); if (!isa(op)) xegpu::setDistributeLayoutAttr(result, layout.dropInstData()); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp -index 574fe29f4bab..64b4436d7016 100644 +index 7fc75e7294ea..94f31e653511 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp -@@ -1273,6 +1273,60 @@ static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder, +@@ -1336,6 +1336,60 @@ static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder, return success(); } @@ -41,14 +52,14 @@ index 574fe29f4bab..64b4436d7016 100644 + // Load op has a conflict if tensor desc layout is different from the its + // result layout. 
+ auto getResultLayout = [](OpResult result) { -+ auto resultLayoutName = xegpu::getLayoutName(result); ++ auto resultLayoutName = xegpu::getTemporaryLayoutName(result); + return result.getOwner()->getAttrOfType( + resultLayoutName); + }; + auto hasConflict = [&getResultLayout](xegpu::LoadNdOp loadNdOp) -> bool { + auto tdescType = loadNdOp.getTensorDescType(); + auto tdescLayout = tdescType.getLayout(); -+ // auto resultLayoutName = xegpu::getLayoutName(loadNdOp->getOpResult(0)); ++ // auto resultLayoutName = xegpu::getTemporaryLayoutName(loadNdOp->getOpResult(0)); + auto resultLayout = getResultLayout(loadNdOp->getOpResult(0)); + return tdescLayout && resultLayout && tdescLayout != resultLayout; + }; @@ -93,7 +104,7 @@ index 574fe29f4bab..64b4436d7016 100644 namespace { struct XeGPUPropagateLayoutPass final : public xegpu::impl::XeGPUPropagateLayoutBase { -@@ -1346,4 +1400,9 @@ void XeGPUPropagateLayoutPass::runOnOperation() { +@@ -1411,4 +1465,9 @@ void XeGPUPropagateLayoutPass::runOnOperation() { signalPassFailure(); return; } @@ -104,10 +115,10 @@ index 574fe29f4bab..64b4436d7016 100644 + } } diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp -index 48bd0662b03f..76ed73fdfaef 100644 +index 07572a495076..448d78f4dc4f 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp -@@ -1238,13 +1238,13 @@ struct WgToSgVectorTransposeOp +@@ -1270,13 +1270,13 @@ struct WgToSgVectorTransposeOp SmallVector sourceSgLayout = sourceLayout.getEffectiveSgLayoutAsInt(); SmallVector resultSgLayout = layout.getEffectiveSgLayoutAsInt(); @@ -127,3 +138,5 @@ index 48bd0662b03f..76ed73fdfaef 100644 ArrayRef permutation = op.getPermutation(); size_t permutationSize = permutation.size(); +-- +2.34.1 diff --git a/lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp b/lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp index 
9b48a54a2..90b12ce3c 100644 --- a/lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp +++ b/lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp @@ -282,6 +282,10 @@ class BlockingAnalysisImpl void visitBranchOperand(mlir::OpOperand &operand) override {} + void visitNonControlFlowArguments( + mlir::RegionSuccessor &successor, + mlir::ArrayRef arguments) override{}; + void visitCallOperand(mlir::OpOperand &operand) override {} void setToExitState(BlockingLattice *lattice) override {} diff --git a/lib/Transforms/VnniTransformation.cpp b/lib/Transforms/VnniTransformation.cpp index 0b17ade66..b87e9b6c7 100644 --- a/lib/Transforms/VnniTransformation.cpp +++ b/lib/Transforms/VnniTransformation.cpp @@ -231,6 +231,10 @@ class LayoutAnalysisImpl void visitCallOperand(mlir::OpOperand &operand) override {} + void visitNonControlFlowArguments( + mlir::RegionSuccessor &successor, + mlir::ArrayRef arguments) override{}; + void setToExitState(LayoutLattice *lattice) override { (void)lattice->meet(false); } diff --git a/test/Integration/Dialect/XeGPU/WG/arith_maximumf.mlir b/test/Integration/Dialect/XeGPU/WG/arith_maximumf.mlir index 48624dc0c..a748d8e6b 100644 --- a/test/Integration/Dialect/XeGPU/WG/arith_maximumf.mlir +++ b/test/Integration/Dialect/XeGPU/WG/arith_maximumf.mlir @@ -45,12 +45,12 @@ module @gemm attributes {gpu.container_module} { %m = arith.muli %block_id_x, %c256 : index %n = arith.muli %block_id_y, %c256 : index %input_tdesc_1 = xegpu.create_nd_tdesc %input1_gpu : memref<256x256xf32> -> !xegpu.tensor_desc<256x256xf32, #map> - %input_val_1 = xegpu.load_nd %input_tdesc_1[%m, %n] : !xegpu.tensor_desc<256x256xf32, #map> -> vector<256x256xf32> + %input_val_1 = xegpu.load_nd %input_tdesc_1[%m, %n] {layout = #map}: !xegpu.tensor_desc<256x256xf32, #map> -> vector<256x256xf32> %input_tdesc_2 = xegpu.create_nd_tdesc %input2_gpu : memref<256x256xf32> -> !xegpu.tensor_desc<256x256xf32, #map> - %input_val_2 = xegpu.load_nd %input_tdesc_2[%m, %n] : 
!xegpu.tensor_desc<256x256xf32, #map> -> vector<256x256xf32> + %input_val_2 = xegpu.load_nd %input_tdesc_2[%m, %n] {layout = #map}: !xegpu.tensor_desc<256x256xf32, #map> -> vector<256x256xf32> %result_val = arith.maximumf %input_val_1, %input_val_2 {layout_result_0 = #map} : vector<256x256xf32> %result_tdesc = xegpu.create_nd_tdesc %result_gpu : memref<256x256xf32> -> !xegpu.tensor_desc<256x256xf32, #map> - xegpu.store_nd %result_val, %result_tdesc[%m, %n] : vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #map> + xegpu.store_nd %result_val, %result_tdesc[%m, %n] {layout = #map}: vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #map> gpu.return } } @@ -64,12 +64,12 @@ module @gemm attributes {gpu.container_module} { %m = arith.muli %block_id_x, %c256 : index %n = arith.muli %block_id_y, %c256 : index %input_tdesc_1 = xegpu.create_nd_tdesc %input1_gpu : memref<256x256xf32> -> !xegpu.tensor_desc<256x256xf32, #map> - %input_val_1 = xegpu.load_nd %input_tdesc_1[%m, %n] : !xegpu.tensor_desc<256x256xf32, #map> -> vector<256x256xf32> + %input_val_1 = xegpu.load_nd %input_tdesc_1[%m, %n] {layout = #map}: !xegpu.tensor_desc<256x256xf32, #map> -> vector<256x256xf32> %input_tdesc_2 = xegpu.create_nd_tdesc %input2_gpu : memref<256x256xf32> -> !xegpu.tensor_desc<256x256xf32, #map> - %input_val_2 = xegpu.load_nd %input_tdesc_2[%m, %n] : !xegpu.tensor_desc<256x256xf32, #map> -> vector<256x256xf32> + %input_val_2 = xegpu.load_nd %input_tdesc_2[%m, %n] {layout = #map}: !xegpu.tensor_desc<256x256xf32, #map> -> vector<256x256xf32> %result_val = arith.maximumf %input_val_1, %input_val_2 fastmath {layout_result_0 = #map} : vector<256x256xf32> %result_tdesc = xegpu.create_nd_tdesc %result_gpu : memref<256x256xf32> -> !xegpu.tensor_desc<256x256xf32, #map> - xegpu.store_nd %result_val, %result_tdesc[%m, %n] : vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #map> + xegpu.store_nd %result_val, %result_tdesc[%m, %n] {layout = #map}: vector<256x256xf32>, 
!xegpu.tensor_desc<256x256xf32, #map> gpu.return } } diff --git a/test/Integration/Dialect/XeGPU/WG/flash_attention_fwd.mlir b/test/Integration/Dialect/XeGPU/WG/flash_attention_fwd.mlir index f8480b6a3..3886a8e7e 100644 --- a/test/Integration/Dialect/XeGPU/WG/flash_attention_fwd.mlir +++ b/test/Integration/Dialect/XeGPU/WG/flash_attention_fwd.mlir @@ -68,17 +68,17 @@ module @flash_attention attributes {gpu.container_module} { // For prefetch SG layout is 4x2. Each SG prefetch 16x32xf16 tile. // Note that prefetch x offset is same as Q x offset. This is because WGs in same batch colloborate on K and V prefetch. %k_prefetch_tile = xegpu.create_nd_tdesc %K , shape: [%size_x, %BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref -> !xegpu.tensor_desc<64x64xf16, #k_prefetch> - xegpu.prefetch_nd %k_prefetch_tile[%wg_q_x_offset, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<64x64xf16, #k_prefetch> + xegpu.prefetch_nd %k_prefetch_tile[%wg_q_x_offset, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, layout = #k_prefetch} : !xegpu.tensor_desc<64x64xf16, #k_prefetch> %wg_q_x_offset_plus_BLOCK_N = arith.addi %wg_q_x_offset, %BLOCK_N : index - xegpu.prefetch_nd %k_prefetch_tile[%wg_q_x_offset_plus_BLOCK_N, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<64x64xf16, #k_prefetch> + xegpu.prefetch_nd %k_prefetch_tile[%wg_q_x_offset_plus_BLOCK_N, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, layout = #k_prefetch} : !xegpu.tensor_desc<64x64xf16, #k_prefetch> %wg_q_x_offset_plus_2_BLOCK_N = arith.addi %wg_q_x_offset_plus_BLOCK_N, %BLOCK_N : index - xegpu.prefetch_nd %k_prefetch_tile[%wg_q_x_offset_plus_2_BLOCK_N, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<64x64xf16, #k_prefetch> + 
xegpu.prefetch_nd %k_prefetch_tile[%wg_q_x_offset_plus_2_BLOCK_N, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, layout = #k_prefetch} : !xegpu.tensor_desc<64x64xf16, #k_prefetch> // V prefetch is similar to K %v_prefetch_tile = xegpu.create_nd_tdesc %V , shape: [%size_x, %BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref -> !xegpu.tensor_desc<64x64xf16, #v_prefetch> - xegpu.prefetch_nd %v_prefetch_tile[%wg_q_x_offset, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<64x64xf16, #v_prefetch> - xegpu.prefetch_nd %v_prefetch_tile[%wg_q_x_offset_plus_BLOCK_N, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<64x64xf16, #v_prefetch> - xegpu.prefetch_nd %v_prefetch_tile[%wg_q_x_offset_plus_2_BLOCK_N, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<64x64xf16, #v_prefetch> + xegpu.prefetch_nd %v_prefetch_tile[%wg_q_x_offset, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, layout = #v_prefetch} : !xegpu.tensor_desc<64x64xf16, #v_prefetch> + xegpu.prefetch_nd %v_prefetch_tile[%wg_q_x_offset_plus_BLOCK_N, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, layout = #v_prefetch} : !xegpu.tensor_desc<64x64xf16, #v_prefetch> + xegpu.prefetch_nd %v_prefetch_tile[%wg_q_x_offset_plus_2_BLOCK_N, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, layout = #v_prefetch} : !xegpu.tensor_desc<64x64xf16, #v_prefetch> %BLOCK_N_3_t = arith.addi %BLOCK_N, %BLOCK_N : index %BLOCK_N_3 = arith.addi %BLOCK_N_3_t, %BLOCK_N : index @@ -98,7 +98,7 @@ module @flash_attention attributes {gpu.container_module} { // Load Q tile. Each WG loads 128x64xf16 tile of Q. 
- %q_value = xegpu.load_nd %q_tile[%wg_q_x_offset, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<128x64xf16, #q> -> vector<128x64xf16> + %q_value = xegpu.load_nd %q_tile[%wg_q_x_offset, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, layout = #q} : !xegpu.tensor_desc<128x64xf16, #q> -> vector<128x64xf16> // Inner loop. This loop iterate over K and V tiles and update the accumulator by computing softmax(q*k^T)*v // K and V tiles are accessed in 64x64xf16 blocks (BLOCK_N=64). However, we load them in 16x64xf16 slices. @@ -116,42 +116,42 @@ module @flash_attention attributes {gpu.container_module} { // K prefetch %prefetch_offset_x_running_t = arith.addi %BLOCK_N_3, %k : index %prefetch_offset_x_running = arith.addi %wg_q_x_offset, %prefetch_offset_x_running_t : index - xegpu.prefetch_nd %k_prefetch_tile[%prefetch_offset_x_running, %c0] : !xegpu.tensor_desc<64x64xf16, #k_prefetch> + xegpu.prefetch_nd %k_prefetch_tile[%prefetch_offset_x_running, %c0] {layout = #k_prefetch}: !xegpu.tensor_desc<64x64xf16, #k_prefetch> // V prefetch - xegpu.prefetch_nd %v_prefetch_tile[%prefetch_offset_x_running, %c0] : !xegpu.tensor_desc<64x64xf16, #v_prefetch> + xegpu.prefetch_nd %v_prefetch_tile[%prefetch_offset_x_running, %c0] {layout = #v_prefetch}: !xegpu.tensor_desc<64x64xf16, #v_prefetch> // Load first 16x64xf16 K slice. K is in column major layout, so we need to transpose after loading. 
%wg_x_offset_running = arith.addi %wg_x_offset, %k : index - %k_value_slice_0_t0 = xegpu.load_nd %k_tile_slice[%wg_x_offset_running, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x64xf16, #k> -> vector<16x64xf16> + %k_value_slice_0_t0 = xegpu.load_nd %k_tile_slice[%wg_x_offset_running, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, layout = #k} : !xegpu.tensor_desc<16x64xf16, #k> -> vector<16x64xf16> %k_value_slice_0 = vector.transpose %k_value_slice_0_t0, [1, 0] {layout_result_0 = #kt} : vector<16x64xf16> to vector<64x16xf16> // Compute first 128x16 of Q * K^T using DPAS. - %qk_out_0 = xegpu.dpas %q_value, %k_value_slice_0, %zero_dpas_128x16 {layout_result_0 = #layout_128x16} : vector<128x64xf16>, vector<64x16xf16>, vector<128x16xf32> -> vector<128x16xf32> + %qk_out_0 = xegpu.dpas %q_value, %k_value_slice_0, %zero_dpas_128x16 {layout_a = #q, layout_b = #kt, layout_cd = #layout_128x16} : vector<128x64xf16>, vector<64x16xf16>, vector<128x16xf32> -> vector<128x16xf32> // Load second 16x64xf16 K slice. 
%wg_x_offset_running_plus_16 = arith.addi %wg_x_offset_running, %c16 : index - %k_value_slice_1_t0 = xegpu.load_nd %k_tile_slice[%wg_x_offset_running_plus_16, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x64xf16, #k> -> vector<16x64xf16> + %k_value_slice_1_t0 = xegpu.load_nd %k_tile_slice[%wg_x_offset_running_plus_16, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, layout = #k} : !xegpu.tensor_desc<16x64xf16, #k> -> vector<16x64xf16> %k_value_slice_1 = vector.transpose %k_value_slice_1_t0, [1, 0] {layout_result_0 = #kt} : vector<16x64xf16> to vector<64x16xf16> // Compute second 128x16 of Q * K^T using DPAS - %qk_out_1 = xegpu.dpas %q_value, %k_value_slice_1, %zero_dpas_128x16 {layout_result_0 = #layout_128x16} : vector<128x64xf16>, vector<64x16xf16>, vector<128x16xf32> -> vector<128x16xf32> + %qk_out_1 = xegpu.dpas %q_value, %k_value_slice_1, %zero_dpas_128x16 {layout_a = #q, layout_b = #kt, layout_cd = #layout_128x16} : vector<128x64xf16>, vector<64x16xf16>, vector<128x16xf32> -> vector<128x16xf32> // Load third 16x64xf16 K slice %wg_x_offset_running_plus_32 = arith.addi %wg_x_offset_running_plus_16, %c16 : index - %k_value_slice_2_t0 = xegpu.load_nd %k_tile_slice[%wg_x_offset_running_plus_32, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x64xf16, #k> -> vector<16x64xf16> + %k_value_slice_2_t0 = xegpu.load_nd %k_tile_slice[%wg_x_offset_running_plus_32, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, layout = #k} : !xegpu.tensor_desc<16x64xf16, #k> -> vector<16x64xf16> %k_value_slice_2 = vector.transpose %k_value_slice_2_t0, [1, 0] {layout_result_0 = #kt} : vector<16x64xf16> to vector<64x16xf16> // Compute third 128x16 of Q * K^T using DPAS - %qk_out_2 = xegpu.dpas %q_value, %k_value_slice_2, %zero_dpas_128x16 {layout_result_0 = 
#layout_128x16} : vector<128x64xf16>, vector<64x16xf16>, vector<128x16xf32> -> vector<128x16xf32> + %qk_out_2 = xegpu.dpas %q_value, %k_value_slice_2, %zero_dpas_128x16 {layout_a = #q, layout_b = #kt, layout_cd = #layout_128x16} : vector<128x64xf16>, vector<64x16xf16>, vector<128x16xf32> -> vector<128x16xf32> // Load forth 16x64 K slice %wg_x_offset_running_plus_48 = arith.addi %wg_x_offset_running_plus_32, %c16 : index - %k_value_slice_3_t0 = xegpu.load_nd %k_tile_slice[%wg_x_offset_running_plus_48, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x64xf16, #k> -> vector<16x64xf16> + %k_value_slice_3_t0 = xegpu.load_nd %k_tile_slice[%wg_x_offset_running_plus_48, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, layout = #k} : !xegpu.tensor_desc<16x64xf16, #k> -> vector<16x64xf16> %k_value_slice_3 = vector.transpose %k_value_slice_3_t0, [1, 0] {layout_result_0 = #kt} : vector<16x64xf16> to vector<64x16xf16> // Compute forth 128x16 of Q * K^T using DPAS - %qk_out_3 = xegpu.dpas %q_value, %k_value_slice_3, %zero_dpas_128x16 {layout_result_0 = #layout_128x16} : vector<128x64xf16>, vector<64x16xf16>, vector<128x16xf32> -> vector<128x16xf32> + %qk_out_3 = xegpu.dpas %q_value, %k_value_slice_3, %zero_dpas_128x16 {layout_a = #q, layout_b = #kt, layout_cd = #layout_128x16} : vector<128x64xf16>, vector<64x16xf16>, vector<128x16xf32> -> vector<128x16xf32> // Softmax computation on QK_out tile // Do max reduction on qk_out @@ -209,24 +209,24 @@ module @flash_attention attributes {gpu.container_module} { %qk_out_3_f16 = arith.truncf %qk_out_3_exp {layout_result_0 = #layout_128x16} : vector<128x16xf32> to vector<128x16xf16> // Load first 16x64 V slice. 
- %v_val_slice_0 = xegpu.load_nd %v_tile_slice[%wg_x_offset_running, %c0] { l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x64xf16, #v> -> vector<16x64xf16> + %v_val_slice_0 = xegpu.load_nd %v_tile_slice[%wg_x_offset_running, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, layout = #v} : !xegpu.tensor_desc<16x64xf16, #v> -> vector<16x64xf16> // Compute first iteration update of 128x64 of P * V - %pv_out_iter0 = xegpu.dpas %qk_out_0_f16, %v_val_slice_0, %acc_in_updated {layout_result_0 = #out} : vector<128x16xf16>, vector<16x64xf16>, vector<128x64xf32> -> vector<128x64xf32> + %pv_out_iter0 = xegpu.dpas %qk_out_0_f16, %v_val_slice_0, %acc_in_updated {layout_a = #q, layout_b = #v, layout_cd = #out} : vector<128x16xf16>, vector<16x64xf16>, vector<128x64xf32> -> vector<128x64xf32> // Load second 16x64 V slice. - %v_val_slice_1 = xegpu.load_nd %v_tile_slice[%wg_x_offset_running_plus_16, %c0] { l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x64xf16, #v> -> vector<16x64xf16> + %v_val_slice_1 = xegpu.load_nd %v_tile_slice[%wg_x_offset_running_plus_16, %c0] { l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, layout = #v} : !xegpu.tensor_desc<16x64xf16, #v> -> vector<16x64xf16> // Compute second iteration update of 128x64 of P * V - %pv_out_iter1 = xegpu.dpas %qk_out_1_f16, %v_val_slice_1, %pv_out_iter0 {layout_result_0 = #out} : vector<128x16xf16>, vector<16x64xf16>, vector<128x64xf32> -> vector<128x64xf32> + %pv_out_iter1 = xegpu.dpas %qk_out_1_f16, %v_val_slice_1, %pv_out_iter0 {layout_a = #q, layout_b = #v, layout_cd = #out} : vector<128x16xf16>, vector<16x64xf16>, vector<128x64xf32> -> vector<128x64xf32> // Load third 16x64 V slice. 
- %v_val_slice_2 = xegpu.load_nd %v_tile_slice[%wg_x_offset_running_plus_32, %c0] { l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x64xf16, #v> -> vector<16x64xf16> + %v_val_slice_2 = xegpu.load_nd %v_tile_slice[%wg_x_offset_running_plus_32, %c0] { l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, layout = #v} : !xegpu.tensor_desc<16x64xf16, #v> -> vector<16x64xf16> // Compute third iteration update of 128x64 of P * V - %pv_out_iter2 = xegpu.dpas %qk_out_2_f16, %v_val_slice_2, %pv_out_iter1 {layout_result_0 = #out} : vector<128x16xf16>, vector<16x64xf16>, vector<128x64xf32> -> vector<128x64xf32> + %pv_out_iter2 = xegpu.dpas %qk_out_2_f16, %v_val_slice_2, %pv_out_iter1 {layout_a = #q, layout_b = #v, layout_cd = #out} : vector<128x16xf16>, vector<16x64xf16>, vector<128x64xf32> -> vector<128x64xf32> // Load forth 16x64 V slice. - %v_val_slice_3 = xegpu.load_nd %v_tile_slice[%wg_x_offset_running_plus_48, %c0] { l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x64xf16, #v> -> vector<16x64xf16> + %v_val_slice_3 = xegpu.load_nd %v_tile_slice[%wg_x_offset_running_plus_48, %c0] { l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, layout = #v} : !xegpu.tensor_desc<16x64xf16, #v> -> vector<16x64xf16> // Compute forth iteration update of 128x64 of P * V - %pv_out_iter3 = xegpu.dpas %qk_out_3_f16, %v_val_slice_3, %pv_out_iter2 {layout_result_0 = #out} : vector<128x16xf16>, vector<16x64xf16>, vector<128x64xf32> -> vector<128x64xf32> + %pv_out_iter3 = xegpu.dpas %qk_out_3_f16, %v_val_slice_3, %pv_out_iter2 {layout_a = #q, layout_b = #v, layout_cd = #out} : vector<128x16xf16>, vector<16x64xf16>, vector<128x64xf32> -> vector<128x64xf32> scf.yield %pv_out_iter3, %m_ij_row, %l_i_row_new : vector<128x64xf32>, vector<128x1xf32>, vector<128x1xf32> } {layout_result_0 = #out, layout_result_1 = 
#layout_128x1, layout_result_2 = #layout_128x1}// end of inner loop diff --git a/test/Integration/Dialect/XeGPU/WG/gemm_4kx4kx4k_f16_f16_f32.mlir b/test/Integration/Dialect/XeGPU/WG/gemm_4kx4kx4k_f16_f16_f32.mlir index 4c25250d3..6bda55571 100644 --- a/test/Integration/Dialect/XeGPU/WG/gemm_4kx4kx4k_f16_f16_f32.mlir +++ b/test/Integration/Dialect/XeGPU/WG/gemm_4kx4kx4k_f16_f16_f32.mlir @@ -48,34 +48,34 @@ module @gemm attributes {gpu.container_module} { %m = arith.muli %block_id_x, %c256 : index %n = arith.muli %block_id_y, %c256 : index %c_tdesc = xegpu.create_nd_tdesc %C : memref<4096x4096xf32> -> !xegpu.tensor_desc<256x256xf32, #c> - %c_init_value = xegpu.load_nd %c_tdesc[%m, %n] : !xegpu.tensor_desc<256x256xf32, #c> -> vector<256x256xf32> + %c_init_value = xegpu.load_nd %c_tdesc[%m, %n] {layout = #c}: !xegpu.tensor_desc<256x256xf32, #c> -> vector<256x256xf32> %a_tdesc = xegpu.create_nd_tdesc %A : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #a> %b_tdesc = xegpu.create_nd_tdesc %B : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16, #b> // Prefetch A 3 times. %a_prefetch_tdesc = xegpu.create_nd_tdesc %A : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #a_prefetch> - xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c0] : !xegpu.tensor_desc<256x32xf16, #a_prefetch> - xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c32] : !xegpu.tensor_desc<256x32xf16, #a_prefetch> - xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c64] : !xegpu.tensor_desc<256x32xf16, #a_prefetch> + xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c0] {layout = #a_prefetch}: !xegpu.tensor_desc<256x32xf16, #a_prefetch> + xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c32] {layout = #a_prefetch}: !xegpu.tensor_desc<256x32xf16, #a_prefetch> + xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c64] {layout = #a_prefetch}: !xegpu.tensor_desc<256x32xf16, #a_prefetch> // Prefetch B 3 times. 
%b_prefetch_tdesc = xegpu.create_nd_tdesc %B : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16, #b_prefetch> - xegpu.prefetch_nd %b_prefetch_tdesc[%c0, %n] : !xegpu.tensor_desc<32x256xf16, #b_prefetch> - xegpu.prefetch_nd %b_prefetch_tdesc[%c32, %n] : !xegpu.tensor_desc<32x256xf16, #b_prefetch> - xegpu.prefetch_nd %b_prefetch_tdesc[%c64, %n] : !xegpu.tensor_desc<32x256xf16, #b_prefetch> + xegpu.prefetch_nd %b_prefetch_tdesc[%c0, %n] {layout = #b_prefetch}: !xegpu.tensor_desc<32x256xf16, #b_prefetch> + xegpu.prefetch_nd %b_prefetch_tdesc[%c32, %n] {layout = #b_prefetch}: !xegpu.tensor_desc<32x256xf16, #b_prefetch> + xegpu.prefetch_nd %b_prefetch_tdesc[%c64, %n] {layout = #b_prefetch}: !xegpu.tensor_desc<32x256xf16, #b_prefetch> %out = scf.for %k = %c0 to %c4096 step %c32 iter_args(%c_value = %c_init_value) -> (vector<256x256xf32>) { - %a_value = xegpu.load_nd %a_tdesc[%m, %k] : !xegpu.tensor_desc<256x32xf16, #a> -> vector<256x32xf16> - %b_value = xegpu.load_nd %b_tdesc[%k, %n] : !xegpu.tensor_desc<32x256xf16, #b> -> vector<32x256xf16> + %a_value = xegpu.load_nd %a_tdesc[%m, %k] {layout = #a}: !xegpu.tensor_desc<256x32xf16, #a> -> vector<256x32xf16> + %b_value = xegpu.load_nd %b_tdesc[%k, %n] {layout = #b}: !xegpu.tensor_desc<32x256xf16, #b> -> vector<32x256xf16> // Prefetch next tiles. 
%prefetch_offset = arith.addi %k, %c96 : index - xegpu.prefetch_nd %a_prefetch_tdesc[%m, %prefetch_offset] : !xegpu.tensor_desc<256x32xf16, #a_prefetch> - xegpu.prefetch_nd %b_prefetch_tdesc[%prefetch_offset, %n] : !xegpu.tensor_desc<32x256xf16, #b_prefetch> - %c_new_value = xegpu.dpas %a_value, %b_value, %c_value {layout_result_0 = #c} + xegpu.prefetch_nd %a_prefetch_tdesc[%m, %prefetch_offset] {layout = #a_prefetch}: !xegpu.tensor_desc<256x32xf16, #a_prefetch> + xegpu.prefetch_nd %b_prefetch_tdesc[%prefetch_offset, %n] {layout = #b_prefetch}: !xegpu.tensor_desc<32x256xf16, #b_prefetch> + %c_new_value = xegpu.dpas %a_value, %b_value, %c_value {layout_a = #a, layout_b = #b, layout_cd = #c} : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf32> -> vector<256x256xf32> scf.yield %c_new_value : vector<256x256xf32> } - xegpu.store_nd %out, %c_tdesc[%m, %n] : vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #c> + xegpu.store_nd %out, %c_tdesc[%m, %n] {layout = #c}: vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #c> gpu.return } } diff --git a/test/Integration/Dialect/XeGPU/WG/gemm_4kx4kx4k_transpose_b.mlir b/test/Integration/Dialect/XeGPU/WG/gemm_4kx4kx4k_transpose_b.mlir index f5cc4ae90..805fefb40 100644 --- a/test/Integration/Dialect/XeGPU/WG/gemm_4kx4kx4k_transpose_b.mlir +++ b/test/Integration/Dialect/XeGPU/WG/gemm_4kx4kx4k_transpose_b.mlir @@ -3,8 +3,8 @@ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck #a = #xegpu.layout -#b = #xegpu.layout -#bt = #xegpu.layout +#b = #xegpu.layout +#bt = #xegpu.layout #c = #xegpu.layout #a_prefetch = #xegpu.layout #b_prefetch = #xegpu.layout @@ -49,35 +49,35 @@ module @gemm attributes {gpu.container_module} { %m = arith.muli %block_id_x, %c256 : index %n = arith.muli %block_id_y, %c256 : index %c_tdesc = xegpu.create_nd_tdesc %C : memref<4096x4096xf32> -> !xegpu.tensor_desc<256x256xf32, #c> - %c_init_value = 
xegpu.load_nd %c_tdesc[%m, %n] : !xegpu.tensor_desc<256x256xf32, #c> -> vector<256x256xf32> + %c_init_value = xegpu.load_nd %c_tdesc[%m, %n] {layout = #c}: !xegpu.tensor_desc<256x256xf32, #c> -> vector<256x256xf32> %a_tdesc = xegpu.create_nd_tdesc %A : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #a> %b_tdesc = xegpu.create_nd_tdesc %B : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #bt> // Prefetch A 3 times. %a_prefetch_tdesc = xegpu.create_nd_tdesc %A : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #a_prefetch> - xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c0] : !xegpu.tensor_desc<256x32xf16, #a_prefetch> - xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c32] : !xegpu.tensor_desc<256x32xf16, #a_prefetch> - xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c64] : !xegpu.tensor_desc<256x32xf16, #a_prefetch> + xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c0] {layout = #a_prefetch}: !xegpu.tensor_desc<256x32xf16, #a_prefetch> + xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c32] {layout = #a_prefetch}: !xegpu.tensor_desc<256x32xf16, #a_prefetch> + xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c64] {layout = #a_prefetch}: !xegpu.tensor_desc<256x32xf16, #a_prefetch> // Prefetch B 3 times. 
%b_prefetch_tdesc = xegpu.create_nd_tdesc %B : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #b_prefetch> - xegpu.prefetch_nd %b_prefetch_tdesc[%n, %c0] : !xegpu.tensor_desc<256x32xf16, #b_prefetch> - xegpu.prefetch_nd %b_prefetch_tdesc[%n, %c32] : !xegpu.tensor_desc<256x32xf16, #b_prefetch> - xegpu.prefetch_nd %b_prefetch_tdesc[%n, %c64] : !xegpu.tensor_desc<256x32xf16, #b_prefetch> + xegpu.prefetch_nd %b_prefetch_tdesc[%n, %c0] {layout = #b_prefetch}: !xegpu.tensor_desc<256x32xf16, #b_prefetch> + xegpu.prefetch_nd %b_prefetch_tdesc[%n, %c32] {layout = #b_prefetch}: !xegpu.tensor_desc<256x32xf16, #b_prefetch> + xegpu.prefetch_nd %b_prefetch_tdesc[%n, %c64] {layout = #b_prefetch}: !xegpu.tensor_desc<256x32xf16, #b_prefetch> %out = scf.for %k = %c0 to %c4096 step %c32 iter_args(%c_value = %c_init_value) -> (vector<256x256xf32>) { - %a_value = xegpu.load_nd %a_tdesc[%m, %k] : !xegpu.tensor_desc<256x32xf16, #a> -> vector<256x32xf16> - %b_value = xegpu.load_nd %b_tdesc[%n, %k] : !xegpu.tensor_desc<256x32xf16, #bt> -> vector<256x32xf16> + %a_value = xegpu.load_nd %a_tdesc[%m, %k] {layout = #a}: !xegpu.tensor_desc<256x32xf16, #a> -> vector<256x32xf16> + %b_value = xegpu.load_nd %b_tdesc[%n, %k] {layout = #bt}: !xegpu.tensor_desc<256x32xf16, #bt> -> vector<256x32xf16> // Prefetch next tiles. 
%prefetch_offset = arith.addi %k, %c96 : index - xegpu.prefetch_nd %a_prefetch_tdesc[%m, %prefetch_offset] : !xegpu.tensor_desc<256x32xf16, #a_prefetch> - xegpu.prefetch_nd %b_prefetch_tdesc[%n, %prefetch_offset] : !xegpu.tensor_desc<256x32xf16, #b_prefetch> + xegpu.prefetch_nd %a_prefetch_tdesc[%m, %prefetch_offset] {layout = #a_prefetch}: !xegpu.tensor_desc<256x32xf16, #a_prefetch> + xegpu.prefetch_nd %b_prefetch_tdesc[%n, %prefetch_offset] {layout = #b_prefetch}: !xegpu.tensor_desc<256x32xf16, #b_prefetch> %b_transposed = vector.transpose %b_value, [1, 0] {layout_result_0 = #b} : vector<256x32xf16> to vector<32x256xf16> - %c_new_value = xegpu.dpas %a_value, %b_transposed, %c_value {layout_result_0 = #c} + %c_new_value = xegpu.dpas %a_value, %b_transposed, %c_value {layout_a = #a, layout_b = #b, layout_cd = #c} : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf32> -> vector<256x256xf32> scf.yield %c_new_value : vector<256x256xf32> } - xegpu.store_nd %out, %c_tdesc[%m, %n] : vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #c> + xegpu.store_nd %out, %c_tdesc[%m, %n] {layout = #c}: vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #c> gpu.return } } diff --git a/test/Integration/Dialect/XeGPU/WG/math_exp.mlir b/test/Integration/Dialect/XeGPU/WG/math_exp.mlir index dd79ab505..592d18195 100644 --- a/test/Integration/Dialect/XeGPU/WG/math_exp.mlir +++ b/test/Integration/Dialect/XeGPU/WG/math_exp.mlir @@ -44,10 +44,10 @@ module @gemm attributes {gpu.container_module} { %m = arith.muli %block_id_x, %c256 : index %n = arith.muli %block_id_y, %c256 : index %input_tdesc = xegpu.create_nd_tdesc %input_gpu : memref<256x256xf32> -> !xegpu.tensor_desc<256x256xf32, #map> - %input_val = xegpu.load_nd %input_tdesc[%m, %n] : !xegpu.tensor_desc<256x256xf32, #map> -> vector<256x256xf32> - %result_val = math.exp %input_val {layout_result_0 = #map} : vector<256x256xf32> + %input_val = xegpu.load_nd %input_tdesc[%m, %n] {layout = #map} : !xegpu.tensor_desc<256x256xf32, 
#map> -> vector<256x256xf32> + %result_val = math.exp %input_val {layout_result_0 = #map}: vector<256x256xf32> %result_tdesc = xegpu.create_nd_tdesc %result_gpu : memref<256x256xf32> -> !xegpu.tensor_desc<256x256xf32, #map> - xegpu.store_nd %result_val, %result_tdesc[%m, %n] : vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #map> + xegpu.store_nd %result_val, %result_tdesc[%m, %n] {layout = #map} : vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #map> gpu.return } @@ -63,10 +63,10 @@ module @gemm attributes {gpu.container_module} { %m = arith.muli %block_id_x, %c256 : index %n = arith.muli %block_id_y, %c256 : index %input_tdesc = xegpu.create_nd_tdesc %input_gpu_with_fast_math : memref<256x256xf32> -> !xegpu.tensor_desc<256x256xf32, #map> - %input_val = xegpu.load_nd %input_tdesc[%m, %n] : !xegpu.tensor_desc<256x256xf32, #map> -> vector<256x256xf32> - %result_val = math.exp %input_val fastmath {layout_result_0 = #map} : vector<256x256xf32> + %input_val = xegpu.load_nd %input_tdesc[%m, %n] {layout = #map} : !xegpu.tensor_desc<256x256xf32, #map> -> vector<256x256xf32> + %result_val = math.exp %input_val fastmath {layout_result_0 = #map}: vector<256x256xf32> %result_tdesc = xegpu.create_nd_tdesc %result_gpu_with_fastmath : memref<256x256xf32> -> !xegpu.tensor_desc<256x256xf32, #map> - xegpu.store_nd %result_val, %result_tdesc[%m, %n] : vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #map> + xegpu.store_nd %result_val, %result_tdesc[%m, %n] {layout = #map} : vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #map> gpu.return } } diff --git a/test/Integration/Dialect/XeGPU/WG/simple_gemm.mlir b/test/Integration/Dialect/XeGPU/WG/simple_gemm.mlir index d46d6caa2..e3b011c21 100644 --- a/test/Integration/Dialect/XeGPU/WG/simple_gemm.mlir +++ b/test/Integration/Dialect/XeGPU/WG/simple_gemm.mlir @@ -51,34 +51,34 @@ module @gemm attributes {gpu.container_module} { %m = arith.muli %block_id_x, %c256 : index %n = arith.muli %block_id_y, %c256 : index %c_tdesc = 
xegpu.create_nd_tdesc %C : memref<256x256xf32> -> !xegpu.tensor_desc<256x256xf32, #c> - %c_init_value = xegpu.load_nd %c_tdesc[%m, %n] : !xegpu.tensor_desc<256x256xf32, #c> -> vector<256x256xf32> + %c_init_value = xegpu.load_nd %c_tdesc[%m, %n] {layout = #c}: !xegpu.tensor_desc<256x256xf32, #c> -> vector<256x256xf32> %a_tdesc = xegpu.create_nd_tdesc %A : memref<256x256xf16> -> !xegpu.tensor_desc<256x32xf16, #a> %b_tdesc = xegpu.create_nd_tdesc %B : memref<256x256xf16> -> !xegpu.tensor_desc<32x256xf16, #b> // Prefetch A 3 times. %a_prefetch_tdesc = xegpu.create_nd_tdesc %A : memref<256x256xf16> -> !xegpu.tensor_desc<256x32xf16, #a_prefetch> - xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c0] : !xegpu.tensor_desc<256x32xf16, #a_prefetch> - xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c32] : !xegpu.tensor_desc<256x32xf16, #a_prefetch> - xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c64] : !xegpu.tensor_desc<256x32xf16, #a_prefetch> + xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c0] {layout = #a_prefetch}: !xegpu.tensor_desc<256x32xf16, #a_prefetch> + xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c32] {layout = #a_prefetch}: !xegpu.tensor_desc<256x32xf16, #a_prefetch> + xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c64] {layout = #a_prefetch}: !xegpu.tensor_desc<256x32xf16, #a_prefetch> // Prefetch B 3 times.
%b_prefetch_tdesc = xegpu.create_nd_tdesc %B : memref<256x256xf16> -> !xegpu.tensor_desc<32x256xf16, #b_prefetch> - xegpu.prefetch_nd %b_prefetch_tdesc[%c0, %n] : !xegpu.tensor_desc<32x256xf16, #b_prefetch> - xegpu.prefetch_nd %b_prefetch_tdesc[%c32, %n] : !xegpu.tensor_desc<32x256xf16, #b_prefetch> - xegpu.prefetch_nd %b_prefetch_tdesc[%c64, %n] : !xegpu.tensor_desc<32x256xf16, #b_prefetch> + xegpu.prefetch_nd %b_prefetch_tdesc[%c0, %n] {layout = #b_prefetch}: !xegpu.tensor_desc<32x256xf16, #b_prefetch> + xegpu.prefetch_nd %b_prefetch_tdesc[%c32, %n] {layout = #b_prefetch}: !xegpu.tensor_desc<32x256xf16, #b_prefetch> + xegpu.prefetch_nd %b_prefetch_tdesc[%c64, %n] {layout = #b_prefetch}: !xegpu.tensor_desc<32x256xf16, #b_prefetch> %out = scf.for %k = %c0 to %c256 step %c32 iter_args(%c_value = %c_init_value) -> (vector<256x256xf32>) { - %a_value = xegpu.load_nd %a_tdesc[%m, %k] : !xegpu.tensor_desc<256x32xf16, #a> -> vector<256x32xf16> - %b_value = xegpu.load_nd %b_tdesc[%k, %n] : !xegpu.tensor_desc<32x256xf16, #b> -> vector<32x256xf16> + %a_value = xegpu.load_nd %a_tdesc[%m, %k] {layout = #a}: !xegpu.tensor_desc<256x32xf16, #a> -> vector<256x32xf16> + %b_value = xegpu.load_nd %b_tdesc[%k, %n] {layout = #b}: !xegpu.tensor_desc<32x256xf16, #b> -> vector<32x256xf16> // Prefetch next tiles. 
%prefetch_offset = arith.addi %k, %c96 : index - xegpu.prefetch_nd %a_prefetch_tdesc[%m, %prefetch_offset] : !xegpu.tensor_desc<256x32xf16, #a_prefetch> - xegpu.prefetch_nd %b_prefetch_tdesc[%prefetch_offset, %n] : !xegpu.tensor_desc<32x256xf16, #b_prefetch> - %c_new_value = xegpu.dpas %a_value, %b_value, %c_value {layout_result_0 = #c} + xegpu.prefetch_nd %a_prefetch_tdesc[%m, %prefetch_offset] {layout = #a_prefetch}: !xegpu.tensor_desc<256x32xf16, #a_prefetch> + xegpu.prefetch_nd %b_prefetch_tdesc[%prefetch_offset, %n] {layout = #b_prefetch}: !xegpu.tensor_desc<32x256xf16, #b_prefetch> + %c_new_value = xegpu.dpas %a_value, %b_value, %c_value {layout_a = #a, layout_b = #b, layout_cd = #c} : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf32> -> vector<256x256xf32> scf.yield %c_new_value : vector<256x256xf32> } - xegpu.store_nd %out, %c_tdesc[%m, %n] : vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #c> + xegpu.store_nd %out, %c_tdesc[%m, %n] {layout = #c}: vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #c> gpu.return } }