From 382ed95ce88345cb154c8b6d69e49c89c39405b6 Mon Sep 17 00:00:00 2001 From: Garra1980 Date: Wed, 14 Jan 2026 21:50:50 +0100 Subject: [PATCH 1/2] [LLVM Pulldown] Bump to rev 50351218b3a4687079fe79f932aac0e00d5d990f --- build_tools/llvm_version.txt | 2 +- ...upport-for-VectorAnyINTEL-capability.patch | 119 +++++++++--------- ...dd-SPIRV_ExecutionModeAttributesAttr.patch | 38 +++--- ...8-length-vector-as-supported-vectors.patch | 98 +++++++-------- .../patches/relaxing_xegpu-propagation.patch | 33 ----- build_tools/patches/wg_fa_support.patch | 48 ++++--- .../XeTile/Transforms/BlockingAnalysis.cpp | 4 + lib/Transforms/VnniTransformation.cpp | 4 + .../Dialect/XeGPU/WG/arith_maximumf.mlir | 12 +- .../Dialect/XeGPU/WG/flash_attention_fwd.mlir | 50 ++++---- .../XeGPU/WG/gemm_4kx4kx4k_f16_f16_f32.mlir | 26 ++-- .../XeGPU/WG/gemm_4kx4kx4k_transpose_b.mlir | 34 ++--- .../Dialect/XeGPU/WG/math_exp.mlir | 12 +- .../Dialect/XeGPU/WG/simple_gemm.mlir | 26 ++-- 14 files changed, 248 insertions(+), 258 deletions(-) delete mode 100644 build_tools/patches/relaxing_xegpu-propagation.patch diff --git a/build_tools/llvm_version.txt b/build_tools/llvm_version.txt index c93d17e96..e1bb882e1 100644 --- a/build_tools/llvm_version.txt +++ b/build_tools/llvm_version.txt @@ -1 +1 @@ -19b28074618c92fa4c4281eeee67c715abebbfd7 +50351218b3a4687079fe79f932aac0e00d5d990f diff --git a/build_tools/patches/0001-Add-support-for-VectorAnyINTEL-capability.patch b/build_tools/patches/0001-Add-support-for-VectorAnyINTEL-capability.patch index 1102c888d..921285e5f 100644 --- a/build_tools/patches/0001-Add-support-for-VectorAnyINTEL-capability.patch +++ b/build_tools/patches/0001-Add-support-for-VectorAnyINTEL-capability.patch @@ -1,7 +1,7 @@ -From 887ae8599b205921bc4fd34da3c1de767f5568ae Mon Sep 17 00:00:00 2001 +From 1dd1130e8647473448ca9ce808e7bc6efa424675 Mon Sep 17 00:00:00 2001 From: Garra1980 -Date: Tue, 23 Sep 2025 21:22:18 +0200 -Subject: [PATCH] Add support for VectorAnyINTEL capability +Date: 
Mon, 5 Jan 2026 17:46:15 +0100 +Subject: [PATCH] Add-support-for-VectorAnyINTEL-capability --- .../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 11 +- @@ -24,7 +24,7 @@ Subject: [PATCH] Add support for VectorAnyINTEL capability 17 files changed, 324 insertions(+), 68 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td -index 0e42d08cdb1f..f821b0d2e59b 100644 +index ecbbf39a534e..d6a72472bd1b 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4240,7 +4240,14 @@ def SPIRV_BFloat16KHR : TypeAlias; @@ -45,21 +45,21 @@ index 0e42d08cdb1f..f821b0d2e59b 100644 // dialect-specific types so we use "Any" here. @@ -4293,7 +4300,7 @@ class SPIRV_MatrixOfType allowedTypes> : "Matrix">; - + class SPIRV_VectorOf : - FixedVectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>; + FixedVectorOfLengthRangeAndType<[2, 0xFFFFFFFF], [type]>; - + class SPIRV_ScalarOrVectorOf : AnyTypeOf<[type, SPIRV_VectorOf]>; diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td -index 6b4e3dd60319..987b33c055e9 100644 +index 0fb4837e528b..a33b18e8c868 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td -@@ -654,6 +654,92 @@ class ScalableVectorOfRankAndLengthAndType allowedRanks, +@@ -696,6 +696,92 @@ class ScalableVectorOfRankAndLengthAndType allowedRanks, ScalableVectorOfLength.summary, "::mlir::VectorType">; - + +// Whether the number of elements of a vector is from the given +// `allowedRanges` list, the list has two values, start and end of the range (inclusive) +class IsVectorOfLengthRangePred allowedRanges> : @@ -150,7 +150,7 @@ index 6b4e3dd60319..987b33c055e9 100644 // Negative values for `n` index in reverse. 
class ShapedTypeWithNthDimOfSize allowedSizes> : Type< diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp -index c8efdf009422..5236dc299f81 100644 +index 22b57d6c0821..8f02dd856d4e 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -186,9 +186,12 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect, @@ -169,25 +169,25 @@ index c8efdf009422..5236dc299f81 100644 return Type(); } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp -index 7e9a80e7d73a..1db6233cf73f 100644 +index 53a48abe5ad0..4c39a7c83281 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -186,9 +186,10 @@ bool CompositeType::classof(Type type) { } - + bool CompositeType::isValid(VectorType type) { - return type.getRank() == 1 && - llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) && -- llvm::isa(type.getElementType()); +- isa(type.getElementType()); + // Number of elements should be between [2 to 2^32 - 1]. 
-+ return type.getRank() == 1 && mlir::isa(type.getElementType()) && ++ return type.getRank() == 1 && isa(type.getElementType()) && + type.getNumElements() >= 2 && + type.getNumElements() <= std::numeric_limits::max(); } - + Type CompositeType::getElementType(unsigned index) const { -@@ -221,8 +222,23 @@ void TypeCapabilityVisitor::addConcrete(VectorType type) { - +@@ -218,8 +219,23 @@ void TypeCapabilityVisitor::addConcrete(VectorType type) { + int64_t vecSize = type.getNumElements(); if (vecSize == 8 || vecSize == 16) { - static constexpr auto cap = Capability::Vector16; @@ -211,9 +211,9 @@ index 7e9a80e7d73a..1db6233cf73f 100644 + capabilities.push_back(ref); } } - + diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp -index 122f61e0a66a..c6f37e9345ed 100644 +index 816226749463..31c590efcda3 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -84,9 +84,13 @@ static std::optional> getTargetShape(VectorType vecType) { @@ -230,7 +230,7 @@ index 122f61e0a66a..c6f37e9345ed 100644 + return llvm::is_contained(ors, elidedExt); + })) continue; - + LLVM_DEBUG({ @@ -112,9 +116,13 @@ static LogicalResult checkExtensionRequirements( template @@ -246,12 +246,12 @@ index 122f61e0a66a..c6f37e9345ed 100644 + return llvm::is_contained(ors, elidedCap); + })) continue; - + LLVM_DEBUG({ @@ -131,6 +139,55 @@ static LogicalResult checkCapabilityRequirements( return success(); } - + +/// Check capabilities and extensions requirements, +/// this function also checks for capability infered extension requirements, +/// the check is based on capabilities that are passed to the targetEnv. 
@@ -307,20 +307,20 @@ index 122f61e0a66a..c6f37e9345ed 100644 @@ -284,11 +341,14 @@ convertScalarType(const spirv::TargetEnv &targetEnv, return nullptr; } - + + // Convert to 32-bit float and remove floatType related capability + // restriction if (auto floatType = dyn_cast(type)) { LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); return Builder(targetEnv.getContext()).getF32Type(); } - + + // Convert to 32-bit int and remove intType related capability restriction auto intType = cast(type); LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); return IntegerType::get(targetEnv.getContext(), /*width=*/32, @@ -402,10 +462,13 @@ convertVectorType(const spirv::TargetEnv &targetEnv, - + if (type.getRank() <= 1 && type.getNumElements() == 1) return elementType; - @@ -336,11 +336,11 @@ index 122f61e0a66a..c6f37e9345ed 100644 + "between [2 - 2^32 -1]\n"); return nullptr; } - + @@ -427,16 +490,40 @@ convertVectorType(const spirv::TargetEnv &targetEnv, cast(type).getExtensions(extensions, storageClass); cast(type).getCapabilities(capabilities, storageClass); - + - // If all requirements are met, then we can accept this type as-is. 
- if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) && - succeeded(checkExtensionRequirements(type, targetEnv, extensions))) @@ -383,9 +383,9 @@ index 122f61e0a66a..c6f37e9345ed 100644 + return nullptr; + } } - + static Type -@@ -1693,16 +1780,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) { +@@ -1694,16 +1781,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) { SmallVector, 4> typeExtensions; SmallVector, 8> typeCapabilities; for (Type valueType : valueTypes) { @@ -411,7 +411,7 @@ index 122f61e0a66a..c6f37e9345ed 100644 + op->getName(), this->targetEnv, typeCapabilities, typeExtensions))) return false; } - + diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir index 9d7ab2be096e..3aa22e261f7c 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir @@ -419,7 +419,7 @@ index 9d7ab2be096e..3aa22e261f7c 100644 @@ -28,9 +28,9 @@ module attributes { #spirv.vce, #spirv.resource_limits<>> } { - + -func.func @unsupported_5elem_vector(%arg0: vector<5xi32>) { +func.func @unsupported_5elem_vector(%arg0: vector<5xi32>, %arg1: vector<5xi32>) { // expected-error@+1 {{failed to legalize operation 'arith.subi'}} @@ -427,7 +427,7 @@ index 9d7ab2be096e..3aa22e261f7c 100644 + %1 = arith.subi %arg0, %arg1: vector<5xi32> return } - + diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir index 3cb529459899..e881d512bf2e 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir @@ -472,7 +472,7 @@ index 3cb529459899..e881d512bf2e 100644 + %1 = arith.mulf %arg1, %arg1: vector<8xf64> return } - + diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir index 0c77c8833457..d6628afb7329 100644 --- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir @@ -480,7 +480,7 @@ index 
0c77c8833457..d6628afb7329 100644 @@ -347,8 +347,21 @@ module attributes { spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { - + -// CHECK-NOT: spirv.func @large_vector -func.func @large_vector(%arg0: vector<1024xi32>) { return } +// CHECK-NOT: spirv.func @large_vector_unsupported @@ -498,15 +498,15 @@ index 0c77c8833457..d6628afb7329 100644 + +// CHECK: spirv.func @large_any_vector +func.func @large_any_vector(%arg0: vector<1024xi32>) { return } - + } // end module - + diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir index c703274bda57..670edc9deb91 100644 --- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir @@ -349,6 +349,7 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 { - + func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 { // expected-error @+1 {{'spirv.Dot' op operand #0 must be fixed-length vector of 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16}} + // expected-error @+1 {{'spirv.Dot' op operand #0 must be fixed-length vector of 16/32/64-bit float or BFloat16 values of length 2-4294967295}} @@ -519,7 +519,7 @@ index 4bdac198a1e8..dee8c7f9a65e 100644 +++ b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir @@ -137,7 +137,7 @@ func.func @bitwise_or_all_ones_vector(%arg: vector<3xi8>) -> vector<3xi8> { // ----- - + func.func @bitwise_or_float(%arg0: f16, %arg1: f16) -> f16 { - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4}} + // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2-4294967295}} @@ -528,7 +528,7 @@ index 4bdac198a1e8..dee8c7f9a65e 100644 } @@ -165,7 +165,7 @@ func.func @bitwise_xor_vector(%arg: vector<4xi32>) -> vector<4xi32> { // ----- - + func.func @bitwise_xor_float(%arg0: f16, %arg1: f16) -> f16 
{ - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4}} + // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2-9223372036854775807}} @@ -537,7 +537,7 @@ index 4bdac198a1e8..dee8c7f9a65e 100644 } @@ -274,7 +274,7 @@ func.func @bitwise_and_zext_vector(%arg: vector<2xi8>) -> vector<2xi32> { // ----- - + func.func @bitwise_and_float(%arg0: f16, %arg1: f16) -> f16 { - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4}} + // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2-9223372036854775807}} @@ -550,7 +550,7 @@ index fd8a2ffbbddf..011759101a74 100644 +++ b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir @@ -27,7 +27,7 @@ func.func @exp(%arg0 : i32) -> () { // ----- - + func.func @exp(%arg0 : vector<5xf32>) -> () { - // expected-error @+1 {{op operand #0 must be 16/32-bit float or fixed-length vector of 16/32-bit float values of length 2/3/4}} + // CHECK: spirv.GL.Exp {{%.*}} : vector<5xf32 @@ -563,7 +563,7 @@ index 2e2fb1a9df32..ad8a66e16745 100644 +++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir @@ -21,7 +21,7 @@ spirv.func @f32_to_bf16_vec(%arg0 : vector<2xf32>) "None" { // ----- - + spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" { - // expected-error @+1 {{operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got}} + // expected-error @+1 {{op operand #0 must be Float32 or fixed-length vector of Float32 values of length 2-4294967295, but got 'f64'}} @@ -571,7 +571,7 @@ index 2e2fb1a9df32..ad8a66e16745 100644 spirv.Return } @@ -58,6 +58,7 @@ spirv.func @bf16_to_f32_vec(%arg0 : vector<2xi16>) "None" { - + spirv.func @bf16_to_f32_unsupported(%arg0 : i16) "None" { // expected-error @+1 {{result 
#0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got}} + // expected-error @+1 {{op result #0 must be Float32 or fixed-length vector of Float32 values of length 2-4294967295, but got 'f16'}} @@ -583,7 +583,7 @@ index d7f4ed05969a..3acd5b88e42a 100644 --- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir @@ -184,7 +184,7 @@ func.func @logicalUnary(%arg0 : i1) - + func.func @logicalUnary(%arg0 : i32) { - // expected-error @+1 {{'operand' must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}} @@ -597,7 +597,7 @@ index bdb2abde8d8e..7b9b5d9a4688 100644 +++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir @@ -511,7 +511,7 @@ func.func @group_non_uniform_bitwise_and(%val: i32) -> i32 { // ----- - + func.func @group_non_uniform_bitwise_and(%val: i1) -> i1 { - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}} + // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2-4294967295, but got 'i1'}} @@ -606,7 +606,7 @@ index bdb2abde8d8e..7b9b5d9a4688 100644 } @@ -532,7 +532,7 @@ func.func @group_non_uniform_bitwise_or(%val: i32) -> i32 { // ----- - + func.func @group_non_uniform_bitwise_or(%val: i1) -> i1 { - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}} + // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2-4294967295, but got 'i1'}} @@ -615,7 +615,7 @@ index bdb2abde8d8e..7b9b5d9a4688 100644 } @@ -553,7 +553,7 @@ func.func @group_non_uniform_bitwise_xor(%val: i32) -> i32 { // ----- - + func.func @group_non_uniform_bitwise_xor(%val: i1) -> i1 { - // expected-error @+1 {{operand #0 
must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}} + // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2-4294967295, but got 'i1'}} @@ -624,7 +624,7 @@ index bdb2abde8d8e..7b9b5d9a4688 100644 } @@ -574,7 +574,7 @@ func.func @group_non_uniform_logical_and(%val: i1) -> i1 { // ----- - + func.func @group_non_uniform_logical_and(%val: i32) -> i32 { - // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}} + // expected-error @+1 {{op operand #0 must be bool or fixed-length vector of bool values of length 2-4294967295, but got 'i32'}} @@ -633,7 +633,7 @@ index bdb2abde8d8e..7b9b5d9a4688 100644 } @@ -595,7 +595,7 @@ func.func @group_non_uniform_logical_or(%val: i1) -> i1 { // ----- - + func.func @group_non_uniform_logical_or(%val: i32) -> i32 { - // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}} + // expected-error @+1 {{op operand #0 must be bool or fixed-length vector of bool values of length 2-4294967295, but got 'i32'}} @@ -642,7 +642,7 @@ index bdb2abde8d8e..7b9b5d9a4688 100644 } @@ -616,7 +616,7 @@ func.func @group_non_uniform_logical_xor(%val: i1) -> i1 { // ----- - + func.func @group_non_uniform_logical_xor(%val: i32) -> i32 { - // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}} + // expected-error @+1 {{op operand #0 must be bool or fixed-length vector of bool values of length 2-4294967295, but got 'i32'}} @@ -655,7 +655,7 @@ index 6aaaa6012fef..60ef7afeeeed 100644 +++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir @@ -27,7 +27,7 @@ func.func @exp(%arg0 : i32) -> () { // ----- - + func.func @exp(%arg0 : vector<5xf32>) -> () { - // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or 
fixed-length vector of 16/32/64-bit float values of length 2/3/4}} + // CHECK: spirv.CL.exp {{%.*}} : vector<5xf32> @@ -665,7 +665,7 @@ index 6aaaa6012fef..60ef7afeeeed 100644 @@ -66,6 +66,14 @@ func.func @fabsvec(%arg0 : vector<3xf16>) -> () { return } - + +// ----- + +func.func @fabs_any_vec(%arg0 : vector<5xf32>) -> () { @@ -678,9 +678,9 @@ index 6aaaa6012fef..60ef7afeeeed 100644 // CHECK: spirv.CL.fabs {{%.*}} : f64 %2 = spirv.CL.fabs %arg0 : f64 @@ -82,14 +90,6 @@ func.func @fabs(%arg0 : i32) -> () { - + // ----- - + -func.func @fabs(%arg0 : vector<5xf32>) -> () { - // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4}} - %2 = spirv.CL.fabs %arg0 : vector<5xf32> @@ -695,7 +695,7 @@ index 6aaaa6012fef..60ef7afeeeed 100644 @@ -122,6 +122,14 @@ func.func @sabsvec(%arg0 : vector<3xi16>) -> () { return } - + +// ----- + +func.func @sabs_any_vec(%arg0 : vector<5xi32>) -> () { @@ -708,9 +708,9 @@ index 6aaaa6012fef..60ef7afeeeed 100644 // CHECK: spirv.CL.s_abs {{%.*}} : i64 %2 = spirv.CL.s_abs %arg0 : i64 @@ -144,14 +152,6 @@ func.func @sabs(%arg0 : f32) -> () { - + // ----- - + -func.func @sabs(%arg0 : vector<5xi32>) -> () { - // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4}} - %2 = spirv.CL.s_abs %arg0 : vector<5xi32> @@ -746,7 +746,7 @@ index 17accd93e824..ed9a9976e89b 100644 @@ -44,6 +44,12 @@ spirv.module Physical64 OpenCL requires #spirv.vce) "None" { + // CHECK: {{%.*}} = spirv.CL.fabs {{%.*}} : vector<5000xf32> + %0 = spirv.CL.fabs %arg0 : vector<5000xf32> @@ -756,5 +756,6 @@ index 17accd93e824..ed9a9976e89b 100644 spirv.func @fma(%arg0 : f32, %arg1 : f32, %arg2 : f32) "None" { // CHECK: spirv.CL.fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32 %13 = spirv.CL.fma %arg0, %arg1, %arg2 : f32 --- +-- 2.34.1 + diff --git a/build_tools/patches/0006-Add-SPIRV_ExecutionModeAttributesAttr.patch 
b/build_tools/patches/0006-Add-SPIRV_ExecutionModeAttributesAttr.patch index 2625040d9..08eb7615e 100644 --- a/build_tools/patches/0006-Add-SPIRV_ExecutionModeAttributesAttr.patch +++ b/build_tools/patches/0006-Add-SPIRV_ExecutionModeAttributesAttr.patch @@ -1,9 +1,8 @@ -From a18280b49a33d421477db322d642aff187b029f8 Mon Sep 17 00:00:00 2001 -From: Dimple Prajapati -Date: Tue, 7 May 2024 23:26:34 +0000 -Subject: [PATCH] Add SPIRV_ExecutionModeAttributesAttr +From ab51a97f5358afe72bfce6483a3416a50dc40018 Mon Sep 17 00:00:00 2001 +From: Garra1980 +Date: Mon, 5 Jan 2026 17:51:39 +0100 +Subject: [PATCH] Add-SPIRV_ExecutionModeAttributesAttr -add spirv.ExecutionMode Op during GPUToSPIRV Pass lowering --- .../mlir/Dialect/SPIRV/IR/SPIRVAttributes.td | 11 +++++++++++ .../mlir/Dialect/SPIRV/IR/TargetAndABI.h | 8 ++++++++ @@ -13,13 +12,13 @@ add spirv.ExecutionMode Op during GPUToSPIRV Pass lowering 5 files changed, 55 insertions(+) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td -index 3a11284da051..1267ecd251ae 100644 +index 1bc3c63646fd..cf5b5ffa451d 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td @@ -56,6 +56,17 @@ def SPIRV_LinkageAttributesAttr : SPIRV_Attr<"LinkageAttributes", "linkage_attri let assemblyFormat = "`<` struct(params) `>`"; } - + +// This attribute specifies SPIR-V execution mode information via GPU functions +// 1) Execution mode attribute. +// 2) [optional] Execution mode value. @@ -41,7 +40,7 @@ index 24574bfaf619..f64f99294038 100644 @@ -137,6 +137,14 @@ FailureOr getExecutionModel(TargetEnvAttr targetAttr); /// Returns failure if it cannot be selected. FailureOr getMemoryModel(TargetEnvAttr targetAttr); - + +/// Returns the attribute name for specifying execution mode attribute +/// information. 
+StringRef getExecutionModeFuncAttrName(); @@ -52,12 +51,12 @@ index 24574bfaf619..f64f99294038 100644 + } // namespace spirv } // namespace mlir - + diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp -index d7885e035959..5195035f088f 100644 +index c33a903d0339..a6578465abac 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp -@@ -343,6 +343,20 @@ LogicalResult GPUFuncOpConversion::matchAndRewrite( +@@ -346,6 +346,20 @@ LogicalResult GPUFuncOpConversion::matchAndRewrite( return failure(); newFuncOp->removeAttr( rewriter.getStringAttr(gpu::GPUDialect::getKernelFuncAttrName())); @@ -77,27 +76,27 @@ index d7885e035959..5195035f088f 100644 + return success(); } - + diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp -index 65aaafa55386..b9f906ada3ee 100644 +index 8f02dd856d4e..dfb13c173596 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp -@@ -948,6 +948,10 @@ LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op, +@@ -1026,6 +1026,10 @@ LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op, } else if (symbol == spirv::getTargetEnvAttrName()) { - if (!llvm::isa(attr)) + if (!isa(attr)) return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr"; + } else if (symbol == spirv::getExecutionModeFuncAttrName()) { -+ if (!llvm::isa(attr)) ++ if (!isa(attr)) + return op->emitError("'") + << symbol << "' must be a spirv::ExecutionModeFuncAttributeAttr"; } else { return op->emitError("found unsupported '") << symbol << "' attribute on operation"; diff --git a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp -index bbc318e17300..56251a8f3990 100644 +index 8c52ba8b8583..864fc75c53b1 100644 --- a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp 
-@@ -242,3 +242,21 @@ spirv::getMemoryModel(spirv::TargetEnvAttr targetAttr) { +@@ -238,3 +238,21 @@ spirv::getMemoryModel(spirv::TargetEnvAttr targetAttr) { } return failure(); } @@ -119,5 +118,6 @@ index bbc318e17300..56251a8f3990 100644 + + return {}; +} --- +-- 2.34.1 + diff --git a/build_tools/patches/0014-Add-32-64-and-128-length-vector-as-supported-vectors.patch b/build_tools/patches/0014-Add-32-64-and-128-length-vector-as-supported-vectors.patch index 19316b6a7..a2dbdbf01 100644 --- a/build_tools/patches/0014-Add-32-64-and-128-length-vector-as-supported-vectors.patch +++ b/build_tools/patches/0014-Add-32-64-and-128-length-vector-as-supported-vectors.patch @@ -1,24 +1,20 @@ -From 27bd5d19a0f122d9acded716fe936d07416e1308 Mon Sep 17 00:00:00 2001 -From: "Shahneous Bari, Md Abdullah" -Date: Wed, 17 Dec 2025 14:21:48 +0000 -Subject: [PATCH] Add 32, 64 and 128 length vector as supported vectors. +From 0f49dfa51b379d1197b290e810ce2d3d178b5dde Mon Sep 17 00:00:00 2001 +From: Garra1980 +Date: Mon, 5 Jan 2026 15:48:16 +0100 +Subject: [PATCH] Add-32-64-and-128-length-vector-as-supported-vectors -This is needed to support large loads/stores using OpenCL intrinsics in -MLIR workflow. -THIS IS A HACK and temporary solution, -need to re-visit this with better solution later. 
--- - llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp | 81 ++++++++++++++------ - 1 file changed, 59 insertions(+), 22 deletions(-) + llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp | 82 +++++++++++++++----- + 1 file changed, 61 insertions(+), 21 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp -index 30703ee40be0..2f3baa1b6c7e 100644 +index 590182731b00..0d9f4df374f9 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp @@ -50,6 +50,28 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { const LLT s64 = LLT::scalar(64); const LLT s128 = LLT::scalar(128); - + + // @IMEX, Add 32, 64 and 128 length vector as supported vectors. + // This is needed to support large loads/stores using OpenCL intrinsics in + // MLIR workflow. @@ -44,20 +40,15 @@ index 30703ee40be0..2f3baa1b6c7e 100644 const LLT v16s64 = LLT::fixed_vector(16, 64); const LLT v16s32 = LLT::fixed_vector(16, 32); const LLT v16s16 = LLT::fixed_vector(16, 16); -@@ -99,41 +121,53 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { - +@@ -100,16 +122,22 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { + // TODO: remove copy-pasting here by using concatenation in some way. 
auto allPtrsScalarsAndVectors = { -- p0, p1, p2, p3, p4, p5, p6, p7, p8, -- p9, p10, p11, p12, s1, s8, s16, s32, s64, -- v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, v3s32, -- v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8, v8s16, -- v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64}; -- -- auto allVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, -- v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, -- v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, -- v16s8, v16s16, v16s32, v16s64}; +- p0, p1, p2, p3, p4, p5, p6, p7, p8, +- p9, p10, p11, p12, p13, s1, s8, s16, s32, +- s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, +- v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8, +- v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64}; + p0, p1, p2, p3, p4, p5, p6, p7, p8, + p9, p10, p11, p12, s1, s8, s16, s32, s64, + s128, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, @@ -73,11 +64,17 @@ index 30703ee40be0..2f3baa1b6c7e 100644 + v16s8, v16s16, v16s32, v16s64, v32s1, v32s8, v32s16, + v32s32, v32s64, v64s1, v64s8, v64s16, v64s32, v64s64, + v128s1, v128s8, v128s16, v128s32, v128s64}; - + +- auto allVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, +- v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, +- v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, +- v16s8, v16s16, v16s32, v16s64}; + auto allShaderVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, v3s32, v3s64, - v4s1, v4s8, v4s16, v4s32, v4s64}; - +@@ -118,25 +146,32 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { + auto allScalars = {s1, s8, s16, s32, s64}; + auto allScalarsAndVectors = { - s1, s8, s16, s32, s64, s128, v2s1, v2s8, - v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, v3s32, v3s64, @@ -89,7 +86,7 @@ index 30703ee40be0..2f3baa1b6c7e 100644 + v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64, v32s1, + v32s8, v32s16, v32s32, v32s64, v64s1, v64s8, v64s16, v64s32, + v64s64, v128s1, v128s8, v128s16, v128s32, v128s64}; - + auto 
allIntScalarsAndVectors = { - s8, s16, s32, s64, s128, v2s8, v2s16, v2s32, v2s64, - v3s8, v3s16, v3s32, v3s64, v4s8, v4s16, v4s32, v4s64, v8s8, @@ -99,15 +96,15 @@ index 30703ee40be0..2f3baa1b6c7e 100644 + v8s8, v8s16, v8s32, v8s64, v16s8, v16s16, v16s32, v16s64, + v32s1, v32s8, v32s16, v32s32, v32s64, v64s1, v64s8, v64s16, + v64s32, v64s64, v128s1, v128s8, v128s16, v128s32, v128s64}; - + - auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1}; + auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, + v8s1, v16s1, v32s1, v64s1}; - + auto allIntScalars = {s8, s16, s32, s64, s128}; - + auto allFloatScalarsAndF16Vector2AndVector4s = {s16, s32, s64, v2s16, v4s16}; - + auto allFloatScalarsAndVectors = { - s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64, - v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64}; @@ -115,10 +112,10 @@ index 30703ee40be0..2f3baa1b6c7e 100644 + v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64, + v32s1, v32s8, v32s16, v32s32, v32s64, v64s1, v64s8, v64s16, v64s32, + v64s64, v128s1, v128s8, v128s16, v128s32, v128s64}; - - auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1, p2, p3, p4, - p5, p6, p7, p8, p9, p10, p11, p12}; -@@ -170,7 +204,9 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { + + auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1, + p2, p3, p4, p5, p6, p7, +@@ -174,7 +209,9 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { // shader execution models, vector sizes are strictly limited to 4. In // non-shader contexts, vector sizes of 8 and 16 are also permitted, but // arbitrary sizes (e.g., 6 or 11) are not. @@ -126,18 +123,21 @@ index 30703ee40be0..2f3baa1b6c7e 100644 + + // @IMEX, make the max vector size to be 128 for now. + uint32_t MaxVectorSize = ST.isShader() ? 
4 : 128; - + LLVM_DEBUG(dbgs() << "MaxVectorSize: " << MaxVectorSize << "\n"); + for (auto Opc : getTypeFoldingSupportedOpcodes()) { - if (Opc != G_EXTRACT_VECTOR_ELT) -@@ -531,7 +567,8 @@ bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper, - LLT DstTy = MRI.getType(DstReg); - LLT SrcTy = MRI.getType(SrcReg); - -- int32_t MaxVectorSize = ST.isShader() ? 4 : 16; -+ // @IMEX make the max vector size to be 128 -+ int32_t MaxVectorSize = ST.isShader() ? 4 : 128; +@@ -579,7 +616,10 @@ static bool needsVectorLegalization(const LLT &Ty, const SPIRVSubtarget &ST) { + if (!Ty.isVector()) + return false; + unsigned NumElements = Ty.getNumElements(); +- unsigned MaxVectorSize = ST.isShader() ? 4 : 16; ++ ++ // @IMEX make the max vector size to be 128 ++ int32_t MaxVectorSize = ST.isShader() ? 4 : 128; ++ + return (NumElements > 4 && !isPowerOf2_32(NumElements)) || + NumElements > MaxVectorSize; + } +-- +2.34.1 - bool DstNeedsLegalization = false; - bool SrcNeedsLegalization = false; --- -2.43.0 diff --git a/build_tools/patches/relaxing_xegpu-propagation.patch b/build_tools/patches/relaxing_xegpu-propagation.patch deleted file mode 100644 index 16fac6c32..000000000 --- a/build_tools/patches/relaxing_xegpu-propagation.patch +++ /dev/null @@ -1,33 +0,0 @@ -From 8f85a84fdad1c1ee6c39a3b04e74e0d8513a6abf Mon Sep 17 00:00:00 2001 -From: Garra1980 -Date: Tue, 25 Nov 2025 22:45:25 +0100 -Subject: [PATCH] relaxing_xegpu propagation - ---- - mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp | 4 ---- - 1 file changed, 4 deletions(-) - -diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp -index 6b3ba5a5981c..ecd6a3358f4a 100644 ---- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp -+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp -@@ -127,8 +127,6 @@ public: - SmallVector LayoutInfo::getLaneLayout() const { - if (!isAssigned()) - return {}; -- 
assert(storage.getEffectiveLaneLayoutAsInt().size() && -- "Expected lane layout to be assigned"); - return llvm::map_to_vector(storage.getEffectiveLaneLayoutAsInt(), - [](int64_t val) { return static_cast(val); }); - } -@@ -136,8 +134,6 @@ SmallVector LayoutInfo::getLaneLayout() const { - SmallVector LayoutInfo::getLaneData() const { - if (!isAssigned()) - return {}; -- assert(storage.getEffectiveLaneDataAsInt().size() && -- "Expected lane data to be assigned"); - return llvm::map_to_vector(storage.getEffectiveLaneDataAsInt(), - [](int64_t val) { return static_cast(val); }); - } --- -2.34.1 diff --git a/build_tools/patches/wg_fa_support.patch b/build_tools/patches/wg_fa_support.patch index c703fc0e7..9e9c8e5ab 100644 --- a/build_tools/patches/wg_fa_support.patch +++ b/build_tools/patches/wg_fa_support.patch @@ -1,54 +1,65 @@ +From 4342fab14630d1ee774ec78dea66fd78f200a3ad Mon Sep 17 00:00:00 2001 +From: Garra1980 +Date: Mon, 22 Dec 2025 16:40:20 +0100 +Subject: [PATCH] wg_fa_support + +--- + .../XeGPU/Transforms/XeGPUBlocking.cpp | 6 +- + .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 59 +++++++++++++++++++ + .../Transforms/XeGPUWgToSgDistribute.cpp | 12 ++-- + 3 files changed, 68 insertions(+), 9 deletions(-) + diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp -index ec5feb8bc8c4..c8b9f0eb6a06 100644 +index ba2753f517ce..dc4f05c0d914 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp @@ -228,7 +228,7 @@ XeGPUBlockingPass::getTileShape(Operation *op) const { if (isa(op)) return getTileShape(op->getOpOperand(0)); - + - if (isa(op)) + if (isa(op)) return getTileShape(op->getOpResult(0)); - + return std::nullopt; -@@ -413,14 +413,14 @@ void XeGPUBlockingPass::runOnOperation() { +@@ -415,14 +415,14 @@ void XeGPUBlockingPass::runOnOperation() { // Remove the layout attributes cached per operands. 
for (OpOperand &opr : op->getOpOperands()) { - std::string name = xegpu::getLayoutName(opr); + std::string name = xegpu::getTemporaryLayoutName(opr); - if (op->hasAttrOfType(name)) + if (op->hasAttrOfType(name)) op->removeAttr(name); } - + // Update the layout attributes per result. for (OpResult result : op->getOpResults()) { - std::string name = xegpu::getLayoutName(result); + std::string name = xegpu::getTemporaryLayoutName(result); - if (auto layout = op->getAttrOfType(name)) { + if (auto layout = op->getAttrOfType(name)) { op->removeAttr(name); if (!isa(op)) xegpu::setDistributeLayoutAttr(result, layout.dropInstData()); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp -index 574fe29f4bab..64b4436d7016 100644 +index 7fc75e7294ea..94f31e653511 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp -@@ -1273,6 +1273,60 @@ static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder, +@@ -1336,6 +1336,60 @@ static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder, return success(); } - + +static LogicalResult resolveConflicts(Operation *op) { + auto r = op->walk([&](xegpu::LoadNdOp loadNdOp) -> WalkResult { + // Load op has a conflict if tensor desc layout is different from the its + // result layout. 
+ auto getResultLayout = [](OpResult result) { -+ auto resultLayoutName = xegpu::getLayoutName(result); ++ auto resultLayoutName = xegpu::getTemporaryLayoutName(result); + return result.getOwner()->getAttrOfType( + resultLayoutName); + }; + auto hasConflict = [&getResultLayout](xegpu::LoadNdOp loadNdOp) -> bool { + auto tdescType = loadNdOp.getTensorDescType(); + auto tdescLayout = tdescType.getLayout(); -+ // auto resultLayoutName = xegpu::getLayoutName(loadNdOp->getOpResult(0)); ++ // auto resultLayoutName = xegpu::getTemporaryLayoutName(loadNdOp->getOpResult(0)); + auto resultLayout = getResultLayout(loadNdOp->getOpResult(0)); + return tdescLayout && resultLayout && tdescLayout != resultLayout; + }; @@ -93,7 +104,7 @@ index 574fe29f4bab..64b4436d7016 100644 namespace { struct XeGPUPropagateLayoutPass final : public xegpu::impl::XeGPUPropagateLayoutBase { -@@ -1346,4 +1400,9 @@ void XeGPUPropagateLayoutPass::runOnOperation() { +@@ -1411,4 +1465,9 @@ void XeGPUPropagateLayoutPass::runOnOperation() { signalPassFailure(); return; } @@ -104,10 +115,10 @@ index 574fe29f4bab..64b4436d7016 100644 + } } diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp -index 48bd0662b03f..76ed73fdfaef 100644 +index 07572a495076..448d78f4dc4f 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp -@@ -1238,13 +1238,13 @@ struct WgToSgVectorTransposeOp +@@ -1270,13 +1270,13 @@ struct WgToSgVectorTransposeOp SmallVector sourceSgLayout = sourceLayout.getEffectiveSgLayoutAsInt(); SmallVector resultSgLayout = layout.getEffectiveSgLayoutAsInt(); @@ -115,7 +126,7 @@ index 48bd0662b03f..76ed73fdfaef 100644 - DenseI32ArrayAttr resultOrder = layout.getOrder(); + // DenseI32ArrayAttr sourceOrder = sourceLayout.getOrder(); + // DenseI32ArrayAttr resultOrder = layout.getOrder(); - + - if (!sourceOrder || !resultOrder) { - return 
rewriter.notifyMatchFailure( - op, "Both source and result must have order attributes"); @@ -124,6 +135,9 @@ index 48bd0662b03f..76ed73fdfaef 100644 + // return rewriter.notifyMatchFailure( + // op, "Both source and result must have order attributes"); + // } - + ArrayRef permutation = op.getPermutation(); size_t permutationSize = permutation.size(); +-- +2.34.1 + diff --git a/lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp b/lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp index 9b48a54a2..600cc484c 100644 --- a/lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp +++ b/lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp @@ -282,6 +282,10 @@ class BlockingAnalysisImpl void visitBranchOperand(mlir::OpOperand &operand) override {} + void + visitNonControlFlowArguments(mlir::RegionSuccessor &successor, + mlir::ArrayRef arguments) override {}; + void visitCallOperand(mlir::OpOperand &operand) override {} void setToExitState(BlockingLattice *lattice) override {} diff --git a/lib/Transforms/VnniTransformation.cpp b/lib/Transforms/VnniTransformation.cpp index 0b17ade66..2b48f9cbd 100644 --- a/lib/Transforms/VnniTransformation.cpp +++ b/lib/Transforms/VnniTransformation.cpp @@ -231,6 +231,10 @@ class LayoutAnalysisImpl void visitCallOperand(mlir::OpOperand &operand) override {} + void + visitNonControlFlowArguments(mlir::RegionSuccessor &successor, + mlir::ArrayRef arguments) override {}; + void setToExitState(LayoutLattice *lattice) override { (void)lattice->meet(false); } diff --git a/test/Integration/Dialect/XeGPU/WG/arith_maximumf.mlir b/test/Integration/Dialect/XeGPU/WG/arith_maximumf.mlir index 48624dc0c..a748d8e6b 100644 --- a/test/Integration/Dialect/XeGPU/WG/arith_maximumf.mlir +++ b/test/Integration/Dialect/XeGPU/WG/arith_maximumf.mlir @@ -45,12 +45,12 @@ module @gemm attributes {gpu.container_module} { %m = arith.muli %block_id_x, %c256 : index %n = arith.muli %block_id_y, %c256 : index %input_tdesc_1 = xegpu.create_nd_tdesc %input1_gpu : memref<256x256xf32> 
-> !xegpu.tensor_desc<256x256xf32, #map> - %input_val_1 = xegpu.load_nd %input_tdesc_1[%m, %n] : !xegpu.tensor_desc<256x256xf32, #map> -> vector<256x256xf32> + %input_val_1 = xegpu.load_nd %input_tdesc_1[%m, %n] {layout = #map}: !xegpu.tensor_desc<256x256xf32, #map> -> vector<256x256xf32> %input_tdesc_2 = xegpu.create_nd_tdesc %input2_gpu : memref<256x256xf32> -> !xegpu.tensor_desc<256x256xf32, #map> - %input_val_2 = xegpu.load_nd %input_tdesc_2[%m, %n] : !xegpu.tensor_desc<256x256xf32, #map> -> vector<256x256xf32> + %input_val_2 = xegpu.load_nd %input_tdesc_2[%m, %n] {layout = #map}: !xegpu.tensor_desc<256x256xf32, #map> -> vector<256x256xf32> %result_val = arith.maximumf %input_val_1, %input_val_2 {layout_result_0 = #map} : vector<256x256xf32> %result_tdesc = xegpu.create_nd_tdesc %result_gpu : memref<256x256xf32> -> !xegpu.tensor_desc<256x256xf32, #map> - xegpu.store_nd %result_val, %result_tdesc[%m, %n] : vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #map> + xegpu.store_nd %result_val, %result_tdesc[%m, %n] {layout = #map}: vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #map> gpu.return } } @@ -64,12 +64,12 @@ module @gemm attributes {gpu.container_module} { %m = arith.muli %block_id_x, %c256 : index %n = arith.muli %block_id_y, %c256 : index %input_tdesc_1 = xegpu.create_nd_tdesc %input1_gpu : memref<256x256xf32> -> !xegpu.tensor_desc<256x256xf32, #map> - %input_val_1 = xegpu.load_nd %input_tdesc_1[%m, %n] : !xegpu.tensor_desc<256x256xf32, #map> -> vector<256x256xf32> + %input_val_1 = xegpu.load_nd %input_tdesc_1[%m, %n] {layout = #map}: !xegpu.tensor_desc<256x256xf32, #map> -> vector<256x256xf32> %input_tdesc_2 = xegpu.create_nd_tdesc %input2_gpu : memref<256x256xf32> -> !xegpu.tensor_desc<256x256xf32, #map> - %input_val_2 = xegpu.load_nd %input_tdesc_2[%m, %n] : !xegpu.tensor_desc<256x256xf32, #map> -> vector<256x256xf32> + %input_val_2 = xegpu.load_nd %input_tdesc_2[%m, %n] {layout = #map}: !xegpu.tensor_desc<256x256xf32, #map> -> 
vector<256x256xf32> %result_val = arith.maximumf %input_val_1, %input_val_2 fastmath {layout_result_0 = #map} : vector<256x256xf32> %result_tdesc = xegpu.create_nd_tdesc %result_gpu : memref<256x256xf32> -> !xegpu.tensor_desc<256x256xf32, #map> - xegpu.store_nd %result_val, %result_tdesc[%m, %n] : vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #map> + xegpu.store_nd %result_val, %result_tdesc[%m, %n] {layout = #map}: vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #map> gpu.return } } diff --git a/test/Integration/Dialect/XeGPU/WG/flash_attention_fwd.mlir b/test/Integration/Dialect/XeGPU/WG/flash_attention_fwd.mlir index f8480b6a3..3886a8e7e 100644 --- a/test/Integration/Dialect/XeGPU/WG/flash_attention_fwd.mlir +++ b/test/Integration/Dialect/XeGPU/WG/flash_attention_fwd.mlir @@ -68,17 +68,17 @@ module @flash_attention attributes {gpu.container_module} { // For prefetch SG layout is 4x2. Each SG prefetch 16x32xf16 tile. // Note that prefetch x offset is same as Q x offset. This is because WGs in same batch colloborate on K and V prefetch. 
%k_prefetch_tile = xegpu.create_nd_tdesc %K , shape: [%size_x, %BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref -> !xegpu.tensor_desc<64x64xf16, #k_prefetch> - xegpu.prefetch_nd %k_prefetch_tile[%wg_q_x_offset, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<64x64xf16, #k_prefetch> + xegpu.prefetch_nd %k_prefetch_tile[%wg_q_x_offset, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, layout = #k_prefetch} : !xegpu.tensor_desc<64x64xf16, #k_prefetch> %wg_q_x_offset_plus_BLOCK_N = arith.addi %wg_q_x_offset, %BLOCK_N : index - xegpu.prefetch_nd %k_prefetch_tile[%wg_q_x_offset_plus_BLOCK_N, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<64x64xf16, #k_prefetch> + xegpu.prefetch_nd %k_prefetch_tile[%wg_q_x_offset_plus_BLOCK_N, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, layout = #k_prefetch} : !xegpu.tensor_desc<64x64xf16, #k_prefetch> %wg_q_x_offset_plus_2_BLOCK_N = arith.addi %wg_q_x_offset_plus_BLOCK_N, %BLOCK_N : index - xegpu.prefetch_nd %k_prefetch_tile[%wg_q_x_offset_plus_2_BLOCK_N, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<64x64xf16, #k_prefetch> + xegpu.prefetch_nd %k_prefetch_tile[%wg_q_x_offset_plus_2_BLOCK_N, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, layout = #k_prefetch} : !xegpu.tensor_desc<64x64xf16, #k_prefetch> // V prefetch is similar to K %v_prefetch_tile = xegpu.create_nd_tdesc %V , shape: [%size_x, %BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref -> !xegpu.tensor_desc<64x64xf16, #v_prefetch> - xegpu.prefetch_nd %v_prefetch_tile[%wg_q_x_offset, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<64x64xf16, #v_prefetch> - xegpu.prefetch_nd 
%v_prefetch_tile[%wg_q_x_offset_plus_BLOCK_N, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<64x64xf16, #v_prefetch> - xegpu.prefetch_nd %v_prefetch_tile[%wg_q_x_offset_plus_2_BLOCK_N, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<64x64xf16, #v_prefetch> + xegpu.prefetch_nd %v_prefetch_tile[%wg_q_x_offset, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, layout = #v_prefetch} : !xegpu.tensor_desc<64x64xf16, #v_prefetch> + xegpu.prefetch_nd %v_prefetch_tile[%wg_q_x_offset_plus_BLOCK_N, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, layout = #v_prefetch} : !xegpu.tensor_desc<64x64xf16, #v_prefetch> + xegpu.prefetch_nd %v_prefetch_tile[%wg_q_x_offset_plus_2_BLOCK_N, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, layout = #v_prefetch} : !xegpu.tensor_desc<64x64xf16, #v_prefetch> %BLOCK_N_3_t = arith.addi %BLOCK_N, %BLOCK_N : index %BLOCK_N_3 = arith.addi %BLOCK_N_3_t, %BLOCK_N : index @@ -98,7 +98,7 @@ module @flash_attention attributes {gpu.container_module} { // Load Q tile. Each WG loads 128x64xf16 tile of Q. - %q_value = xegpu.load_nd %q_tile[%wg_q_x_offset, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<128x64xf16, #q> -> vector<128x64xf16> + %q_value = xegpu.load_nd %q_tile[%wg_q_x_offset, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, layout = #q} : !xegpu.tensor_desc<128x64xf16, #q> -> vector<128x64xf16> // Inner loop. This loop iterate over K and V tiles and update the accumulator by computing softmax(q*k^T)*v // K and V tiles are accessed in 64x64xf16 blocks (BLOCK_N=64). However, we load them in 16x64xf16 slices. 
@@ -116,42 +116,42 @@ module @flash_attention attributes {gpu.container_module} { // K prefetch %prefetch_offset_x_running_t = arith.addi %BLOCK_N_3, %k : index %prefetch_offset_x_running = arith.addi %wg_q_x_offset, %prefetch_offset_x_running_t : index - xegpu.prefetch_nd %k_prefetch_tile[%prefetch_offset_x_running, %c0] : !xegpu.tensor_desc<64x64xf16, #k_prefetch> + xegpu.prefetch_nd %k_prefetch_tile[%prefetch_offset_x_running, %c0] {layout = #k_prefetch}: !xegpu.tensor_desc<64x64xf16, #k_prefetch> // V prefetch - xegpu.prefetch_nd %v_prefetch_tile[%prefetch_offset_x_running, %c0] : !xegpu.tensor_desc<64x64xf16, #v_prefetch> + xegpu.prefetch_nd %v_prefetch_tile[%prefetch_offset_x_running, %c0] {layout = #v_prefetch}: !xegpu.tensor_desc<64x64xf16, #v_prefetch> // Load first 16x64xf16 K slice. K is in column major layout, so we need to transpose after loading. %wg_x_offset_running = arith.addi %wg_x_offset, %k : index - %k_value_slice_0_t0 = xegpu.load_nd %k_tile_slice[%wg_x_offset_running, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x64xf16, #k> -> vector<16x64xf16> + %k_value_slice_0_t0 = xegpu.load_nd %k_tile_slice[%wg_x_offset_running, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, layout = #k} : !xegpu.tensor_desc<16x64xf16, #k> -> vector<16x64xf16> %k_value_slice_0 = vector.transpose %k_value_slice_0_t0, [1, 0] {layout_result_0 = #kt} : vector<16x64xf16> to vector<64x16xf16> // Compute first 128x16 of Q * K^T using DPAS. 
- %qk_out_0 = xegpu.dpas %q_value, %k_value_slice_0, %zero_dpas_128x16 {layout_result_0 = #layout_128x16} : vector<128x64xf16>, vector<64x16xf16>, vector<128x16xf32> -> vector<128x16xf32> + %qk_out_0 = xegpu.dpas %q_value, %k_value_slice_0, %zero_dpas_128x16 {layout_a = #q, layout_b = #kt, layout_cd = #layout_128x16} : vector<128x64xf16>, vector<64x16xf16>, vector<128x16xf32> -> vector<128x16xf32> // Load second 16x64xf16 K slice. %wg_x_offset_running_plus_16 = arith.addi %wg_x_offset_running, %c16 : index - %k_value_slice_1_t0 = xegpu.load_nd %k_tile_slice[%wg_x_offset_running_plus_16, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x64xf16, #k> -> vector<16x64xf16> + %k_value_slice_1_t0 = xegpu.load_nd %k_tile_slice[%wg_x_offset_running_plus_16, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, layout = #k} : !xegpu.tensor_desc<16x64xf16, #k> -> vector<16x64xf16> %k_value_slice_1 = vector.transpose %k_value_slice_1_t0, [1, 0] {layout_result_0 = #kt} : vector<16x64xf16> to vector<64x16xf16> // Compute second 128x16 of Q * K^T using DPAS - %qk_out_1 = xegpu.dpas %q_value, %k_value_slice_1, %zero_dpas_128x16 {layout_result_0 = #layout_128x16} : vector<128x64xf16>, vector<64x16xf16>, vector<128x16xf32> -> vector<128x16xf32> + %qk_out_1 = xegpu.dpas %q_value, %k_value_slice_1, %zero_dpas_128x16 {layout_a = #q, layout_b = #kt, layout_cd = #layout_128x16} : vector<128x64xf16>, vector<64x16xf16>, vector<128x16xf32> -> vector<128x16xf32> // Load third 16x64xf16 K slice %wg_x_offset_running_plus_32 = arith.addi %wg_x_offset_running_plus_16, %c16 : index - %k_value_slice_2_t0 = xegpu.load_nd %k_tile_slice[%wg_x_offset_running_plus_32, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x64xf16, #k> -> vector<16x64xf16> + %k_value_slice_2_t0 = xegpu.load_nd %k_tile_slice[%wg_x_offset_running_plus_32, 
%c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, layout = #k} : !xegpu.tensor_desc<16x64xf16, #k> -> vector<16x64xf16> %k_value_slice_2 = vector.transpose %k_value_slice_2_t0, [1, 0] {layout_result_0 = #kt} : vector<16x64xf16> to vector<64x16xf16> // Compute third 128x16 of Q * K^T using DPAS - %qk_out_2 = xegpu.dpas %q_value, %k_value_slice_2, %zero_dpas_128x16 {layout_result_0 = #layout_128x16} : vector<128x64xf16>, vector<64x16xf16>, vector<128x16xf32> -> vector<128x16xf32> + %qk_out_2 = xegpu.dpas %q_value, %k_value_slice_2, %zero_dpas_128x16 {layout_a = #q, layout_b = #kt, layout_cd = #layout_128x16} : vector<128x64xf16>, vector<64x16xf16>, vector<128x16xf32> -> vector<128x16xf32> // Load forth 16x64 K slice %wg_x_offset_running_plus_48 = arith.addi %wg_x_offset_running_plus_32, %c16 : index - %k_value_slice_3_t0 = xegpu.load_nd %k_tile_slice[%wg_x_offset_running_plus_48, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x64xf16, #k> -> vector<16x64xf16> + %k_value_slice_3_t0 = xegpu.load_nd %k_tile_slice[%wg_x_offset_running_plus_48, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, layout = #k} : !xegpu.tensor_desc<16x64xf16, #k> -> vector<16x64xf16> %k_value_slice_3 = vector.transpose %k_value_slice_3_t0, [1, 0] {layout_result_0 = #kt} : vector<16x64xf16> to vector<64x16xf16> // Compute forth 128x16 of Q * K^T using DPAS - %qk_out_3 = xegpu.dpas %q_value, %k_value_slice_3, %zero_dpas_128x16 {layout_result_0 = #layout_128x16} : vector<128x64xf16>, vector<64x16xf16>, vector<128x16xf32> -> vector<128x16xf32> + %qk_out_3 = xegpu.dpas %q_value, %k_value_slice_3, %zero_dpas_128x16 {layout_a = #q, layout_b = #kt, layout_cd = #layout_128x16} : vector<128x64xf16>, vector<64x16xf16>, vector<128x16xf32> -> vector<128x16xf32> // Softmax computation on QK_out tile // Do max reduction on qk_out @@ -209,24 +209,24 @@ 
module @flash_attention attributes {gpu.container_module} { %qk_out_3_f16 = arith.truncf %qk_out_3_exp {layout_result_0 = #layout_128x16} : vector<128x16xf32> to vector<128x16xf16> // Load first 16x64 V slice. - %v_val_slice_0 = xegpu.load_nd %v_tile_slice[%wg_x_offset_running, %c0] { l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x64xf16, #v> -> vector<16x64xf16> + %v_val_slice_0 = xegpu.load_nd %v_tile_slice[%wg_x_offset_running, %c0] {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, layout = #v} : !xegpu.tensor_desc<16x64xf16, #v> -> vector<16x64xf16> // Compute first iteration update of 128x64 of P * V - %pv_out_iter0 = xegpu.dpas %qk_out_0_f16, %v_val_slice_0, %acc_in_updated {layout_result_0 = #out} : vector<128x16xf16>, vector<16x64xf16>, vector<128x64xf32> -> vector<128x64xf32> + %pv_out_iter0 = xegpu.dpas %qk_out_0_f16, %v_val_slice_0, %acc_in_updated {layout_a = #q, layout_b = #v, layout_cd = #out} : vector<128x16xf16>, vector<16x64xf16>, vector<128x64xf32> -> vector<128x64xf32> // Load second 16x64 V slice. 
- %v_val_slice_1 = xegpu.load_nd %v_tile_slice[%wg_x_offset_running_plus_16, %c0] { l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x64xf16, #v> -> vector<16x64xf16> + %v_val_slice_1 = xegpu.load_nd %v_tile_slice[%wg_x_offset_running_plus_16, %c0] { l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, layout = #v} : !xegpu.tensor_desc<16x64xf16, #v> -> vector<16x64xf16> // Compute second iteration update of 128x64 of P * V - %pv_out_iter1 = xegpu.dpas %qk_out_1_f16, %v_val_slice_1, %pv_out_iter0 {layout_result_0 = #out} : vector<128x16xf16>, vector<16x64xf16>, vector<128x64xf32> -> vector<128x64xf32> + %pv_out_iter1 = xegpu.dpas %qk_out_1_f16, %v_val_slice_1, %pv_out_iter0 {layout_a = #q, layout_b = #v, layout_cd = #out} : vector<128x16xf16>, vector<16x64xf16>, vector<128x64xf32> -> vector<128x64xf32> // Load third 16x64 V slice. - %v_val_slice_2 = xegpu.load_nd %v_tile_slice[%wg_x_offset_running_plus_32, %c0] { l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x64xf16, #v> -> vector<16x64xf16> + %v_val_slice_2 = xegpu.load_nd %v_tile_slice[%wg_x_offset_running_plus_32, %c0] { l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, layout = #v} : !xegpu.tensor_desc<16x64xf16, #v> -> vector<16x64xf16> // Compute third iteration update of 128x64 of P * V - %pv_out_iter2 = xegpu.dpas %qk_out_2_f16, %v_val_slice_2, %pv_out_iter1 {layout_result_0 = #out} : vector<128x16xf16>, vector<16x64xf16>, vector<128x64xf32> -> vector<128x64xf32> + %pv_out_iter2 = xegpu.dpas %qk_out_2_f16, %v_val_slice_2, %pv_out_iter1 {layout_a = #q, layout_b = #v, layout_cd = #out} : vector<128x16xf16>, vector<16x64xf16>, vector<128x64xf32> -> vector<128x64xf32> // Load forth 16x64 V slice. 
- %v_val_slice_3 = xegpu.load_nd %v_tile_slice[%wg_x_offset_running_plus_48, %c0] { l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x64xf16, #v> -> vector<16x64xf16> + %v_val_slice_3 = xegpu.load_nd %v_tile_slice[%wg_x_offset_running_plus_48, %c0] { l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, layout = #v} : !xegpu.tensor_desc<16x64xf16, #v> -> vector<16x64xf16> // Compute forth iteration update of 128x64 of P * V - %pv_out_iter3 = xegpu.dpas %qk_out_3_f16, %v_val_slice_3, %pv_out_iter2 {layout_result_0 = #out} : vector<128x16xf16>, vector<16x64xf16>, vector<128x64xf32> -> vector<128x64xf32> + %pv_out_iter3 = xegpu.dpas %qk_out_3_f16, %v_val_slice_3, %pv_out_iter2 {layout_a = #q, layout_b = #v, layout_cd = #out} : vector<128x16xf16>, vector<16x64xf16>, vector<128x64xf32> -> vector<128x64xf32> scf.yield %pv_out_iter3, %m_ij_row, %l_i_row_new : vector<128x64xf32>, vector<128x1xf32>, vector<128x1xf32> } {layout_result_0 = #out, layout_result_1 = #layout_128x1, layout_result_2 = #layout_128x1}// end of inner loop diff --git a/test/Integration/Dialect/XeGPU/WG/gemm_4kx4kx4k_f16_f16_f32.mlir b/test/Integration/Dialect/XeGPU/WG/gemm_4kx4kx4k_f16_f16_f32.mlir index 4c25250d3..6bda55571 100644 --- a/test/Integration/Dialect/XeGPU/WG/gemm_4kx4kx4k_f16_f16_f32.mlir +++ b/test/Integration/Dialect/XeGPU/WG/gemm_4kx4kx4k_f16_f16_f32.mlir @@ -48,34 +48,34 @@ module @gemm attributes {gpu.container_module} { %m = arith.muli %block_id_x, %c256 : index %n = arith.muli %block_id_y, %c256 : index %c_tdesc = xegpu.create_nd_tdesc %C : memref<4096x4096xf32> -> !xegpu.tensor_desc<256x256xf32, #c> - %c_init_value = xegpu.load_nd %c_tdesc[%m, %n] : !xegpu.tensor_desc<256x256xf32, #c> -> vector<256x256xf32> + %c_init_value = xegpu.load_nd %c_tdesc[%m, %n] {layout = #c}: !xegpu.tensor_desc<256x256xf32, #c> -> vector<256x256xf32> %a_tdesc = xegpu.create_nd_tdesc %A : memref<4096x4096xf16> -> 
!xegpu.tensor_desc<256x32xf16, #a> %b_tdesc = xegpu.create_nd_tdesc %B : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16, #b> // Prefetch A 3 times. %a_prefetch_tdesc = xegpu.create_nd_tdesc %A : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #a_prefetch> - xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c0] : !xegpu.tensor_desc<256x32xf16, #a_prefetch> - xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c32] : !xegpu.tensor_desc<256x32xf16, #a_prefetch> - xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c64] : !xegpu.tensor_desc<256x32xf16, #a_prefetch> + xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c0] {layout = #a_prefetch}: !xegpu.tensor_desc<256x32xf16, #a_prefetch> + xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c32] {layout = #a_prefetch}: !xegpu.tensor_desc<256x32xf16, #a_prefetch> + xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c64] {layout = #a_prefetch}: !xegpu.tensor_desc<256x32xf16, #a_prefetch> // Prefetch B 3 times. %b_prefetch_tdesc = xegpu.create_nd_tdesc %B : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16, #b_prefetch> - xegpu.prefetch_nd %b_prefetch_tdesc[%c0, %n] : !xegpu.tensor_desc<32x256xf16, #b_prefetch> - xegpu.prefetch_nd %b_prefetch_tdesc[%c32, %n] : !xegpu.tensor_desc<32x256xf16, #b_prefetch> - xegpu.prefetch_nd %b_prefetch_tdesc[%c64, %n] : !xegpu.tensor_desc<32x256xf16, #b_prefetch> + xegpu.prefetch_nd %b_prefetch_tdesc[%c0, %n] {layout = #b_prefetch}: !xegpu.tensor_desc<32x256xf16, #b_prefetch> + xegpu.prefetch_nd %b_prefetch_tdesc[%c32, %n] {layout = #b_prefetch}: !xegpu.tensor_desc<32x256xf16, #b_prefetch> + xegpu.prefetch_nd %b_prefetch_tdesc[%c64, %n] {layout = #b_prefetch}: !xegpu.tensor_desc<32x256xf16, #b_prefetch> %out = scf.for %k = %c0 to %c4096 step %c32 iter_args(%c_value = %c_init_value) -> (vector<256x256xf32>) { - %a_value = xegpu.load_nd %a_tdesc[%m, %k] : !xegpu.tensor_desc<256x32xf16, #a> -> vector<256x32xf16> - %b_value = xegpu.load_nd %b_tdesc[%k, %n] : !xegpu.tensor_desc<32x256xf16, #b> -> vector<32x256xf16> + %a_value = 
xegpu.load_nd %a_tdesc[%m, %k] {layout = #a}: !xegpu.tensor_desc<256x32xf16, #a> -> vector<256x32xf16> + %b_value = xegpu.load_nd %b_tdesc[%k, %n] {layout = #b}: !xegpu.tensor_desc<32x256xf16, #b> -> vector<32x256xf16> // Prefetch next tiles. %prefetch_offset = arith.addi %k, %c96 : index - xegpu.prefetch_nd %a_prefetch_tdesc[%m, %prefetch_offset] : !xegpu.tensor_desc<256x32xf16, #a_prefetch> - xegpu.prefetch_nd %b_prefetch_tdesc[%prefetch_offset, %n] : !xegpu.tensor_desc<32x256xf16, #b_prefetch> - %c_new_value = xegpu.dpas %a_value, %b_value, %c_value {layout_result_0 = #c} + xegpu.prefetch_nd %a_prefetch_tdesc[%m, %prefetch_offset] {layout = #a_prefetch}: !xegpu.tensor_desc<256x32xf16, #a_prefetch> + xegpu.prefetch_nd %b_prefetch_tdesc[%prefetch_offset, %n] {layout = #b_prefetch}: !xegpu.tensor_desc<32x256xf16, #b_prefetch> + %c_new_value = xegpu.dpas %a_value, %b_value, %c_value {layout_a = #a, layout_b = #b, layout_cd = #c} : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf32> -> vector<256x256xf32> scf.yield %c_new_value : vector<256x256xf32> } - xegpu.store_nd %out, %c_tdesc[%m, %n] : vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #c> + xegpu.store_nd %out, %c_tdesc[%m, %n] {layout = #c}: vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #c> gpu.return } } diff --git a/test/Integration/Dialect/XeGPU/WG/gemm_4kx4kx4k_transpose_b.mlir b/test/Integration/Dialect/XeGPU/WG/gemm_4kx4kx4k_transpose_b.mlir index f5cc4ae90..5f90499af 100644 --- a/test/Integration/Dialect/XeGPU/WG/gemm_4kx4kx4k_transpose_b.mlir +++ b/test/Integration/Dialect/XeGPU/WG/gemm_4kx4kx4k_transpose_b.mlir @@ -3,8 +3,8 @@ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck #a = #xegpu.layout -#b = #xegpu.layout -#bt = #xegpu.layout +#b = #xegpu.layout +#bt = #xegpu.layout #c = #xegpu.layout #a_prefetch = #xegpu.layout #b_prefetch = #xegpu.layout @@ -49,35 +49,35 @@ module @gemm 
attributes {gpu.container_module} { %m = arith.muli %block_id_x, %c256 : index %n = arith.muli %block_id_y, %c256 : index %c_tdesc = xegpu.create_nd_tdesc %C : memref<4096x4096xf32> -> !xegpu.tensor_desc<256x256xf32, #c> - %c_init_value = xegpu.load_nd %c_tdesc[%m, %n] : !xegpu.tensor_desc<256x256xf32, #c> -> vector<256x256xf32> + %c_init_value = xegpu.load_nd %c_tdesc[%m, %n] {layout = #c}: !xegpu.tensor_desc<256x256xf32, #c> -> vector<256x256xf32> %a_tdesc = xegpu.create_nd_tdesc %A : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #a> %b_tdesc = xegpu.create_nd_tdesc %B : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #bt> // Prefetch A 3 times. %a_prefetch_tdesc = xegpu.create_nd_tdesc %A : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #a_prefetch> - xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c0] : !xegpu.tensor_desc<256x32xf16, #a_prefetch> - xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c32] : !xegpu.tensor_desc<256x32xf16, #a_prefetch> - xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c64] : !xegpu.tensor_desc<256x32xf16, #a_prefetch> + xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c0] {layout = #a_prefetch}: !xegpu.tensor_desc<256x32xf16, #a_prefetch> + xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c32] {layout = #a_prefetch}: !xegpu.tensor_desc<256x32xf16, #a_prefetch> + xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c64] {layout = #a_prefetch}: !xegpu.tensor_desc<256x32xf16, #a_prefetch> // Prefetch B 3 times. 
%b_prefetch_tdesc = xegpu.create_nd_tdesc %B : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #b_prefetch> - xegpu.prefetch_nd %b_prefetch_tdesc[%n, %c0] : !xegpu.tensor_desc<256x32xf16, #b_prefetch> - xegpu.prefetch_nd %b_prefetch_tdesc[%n, %c32] : !xegpu.tensor_desc<256x32xf16, #b_prefetch> - xegpu.prefetch_nd %b_prefetch_tdesc[%n, %c64] : !xegpu.tensor_desc<256x32xf16, #b_prefetch> + xegpu.prefetch_nd %b_prefetch_tdesc[%n, %c0] {layout = #b_prefetch}: !xegpu.tensor_desc<256x32xf16, #b_prefetch> + xegpu.prefetch_nd %b_prefetch_tdesc[%n, %c32] {layout = #b_prefetch}: !xegpu.tensor_desc<256x32xf16, #b_prefetch> + xegpu.prefetch_nd %b_prefetch_tdesc[%n, %c64] {layout = #b_prefetch}: !xegpu.tensor_desc<256x32xf16, #b_prefetch> %out = scf.for %k = %c0 to %c4096 step %c32 iter_args(%c_value = %c_init_value) -> (vector<256x256xf32>) { - %a_value = xegpu.load_nd %a_tdesc[%m, %k] : !xegpu.tensor_desc<256x32xf16, #a> -> vector<256x32xf16> - %b_value = xegpu.load_nd %b_tdesc[%n, %k] : !xegpu.tensor_desc<256x32xf16, #bt> -> vector<256x32xf16> + %a_value = xegpu.load_nd %a_tdesc[%m, %k] {layout = #a}: !xegpu.tensor_desc<256x32xf16, #a> -> vector<256x32xf16> + %b_value = xegpu.load_nd %b_tdesc[%n, %k] {layout = #bt}: !xegpu.tensor_desc<256x32xf16, #bt> -> vector<256x32xf16> // Prefetch next tiles. 
%prefetch_offset = arith.addi %k, %c96 : index - xegpu.prefetch_nd %a_prefetch_tdesc[%m, %prefetch_offset] : !xegpu.tensor_desc<256x32xf16, #a_prefetch> - xegpu.prefetch_nd %b_prefetch_tdesc[%n, %prefetch_offset] : !xegpu.tensor_desc<256x32xf16, #b_prefetch> + xegpu.prefetch_nd %a_prefetch_tdesc[%m, %prefetch_offset] {layout = #a_prefetch}: !xegpu.tensor_desc<256x32xf16, #a_prefetch> + xegpu.prefetch_nd %b_prefetch_tdesc[%n, %prefetch_offset] {layout = #b_prefetch}: !xegpu.tensor_desc<256x32xf16, #b_prefetch> %b_transposed = vector.transpose %b_value, [1, 0] {layout_result_0 = #b} : vector<256x32xf16> to vector<32x256xf16> - %c_new_value = xegpu.dpas %a_value, %b_transposed, %c_value {layout_result_0 = #c} + %c_new_value = xegpu.dpas %a_value, %b_transposed, %c_value {layout_a = #a, layout_b = #b, layout_cd = #c} : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf32> -> vector<256x256xf32> - scf.yield %c_new_value : vector<256x256xf32> - } - xegpu.store_nd %out, %c_tdesc[%m, %n] : vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #c> + scf.yield %c_new_value : vector<256x256xf32> + } + xegpu.store_nd %out, %c_tdesc[%m, %n] {layout = #c}: vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #c> gpu.return } } diff --git a/test/Integration/Dialect/XeGPU/WG/math_exp.mlir b/test/Integration/Dialect/XeGPU/WG/math_exp.mlir index dd79ab505..592d18195 100644 --- a/test/Integration/Dialect/XeGPU/WG/math_exp.mlir +++ b/test/Integration/Dialect/XeGPU/WG/math_exp.mlir @@ -44,10 +44,10 @@ module @gemm attributes {gpu.container_module} { %m = arith.muli %block_id_x, %c256 : index %n = arith.muli %block_id_y, %c256 : index %input_tdesc = xegpu.create_nd_tdesc %input_gpu : memref<256x256xf32> -> !xegpu.tensor_desc<256x256xf32, #map> - %input_val = xegpu.load_nd %input_tdesc[%m, %n] : !xegpu.tensor_desc<256x256xf32, #map> -> vector<256x256xf32> - %result_val = math.exp %input_val {layout_result_0 = #map} : vector<256x256xf32> + %input_val = xegpu.load_nd %input_tdesc[%m, 
%n] {layout = #map} : !xegpu.tensor_desc<256x256xf32, #map> -> vector<256x256xf32> + %result_val = math.exp %input_val {layout_result_0 = #map}: vector<256x256xf32> %result_tdesc = xegpu.create_nd_tdesc %result_gpu : memref<256x256xf32> -> !xegpu.tensor_desc<256x256xf32, #map> - xegpu.store_nd %result_val, %result_tdesc[%m, %n] : vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #map> + xegpu.store_nd %result_val, %result_tdesc[%m, %n] {layout = #map} : vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #map> gpu.return } @@ -63,10 +63,10 @@ module @gemm attributes {gpu.container_module} { %m = arith.muli %block_id_x, %c256 : index %n = arith.muli %block_id_y, %c256 : index %input_tdesc = xegpu.create_nd_tdesc %input_gpu_with_fast_math : memref<256x256xf32> -> !xegpu.tensor_desc<256x256xf32, #map> - %input_val = xegpu.load_nd %input_tdesc[%m, %n] : !xegpu.tensor_desc<256x256xf32, #map> -> vector<256x256xf32> - %result_val = math.exp %input_val fastmath {layout_result_0 = #map} : vector<256x256xf32> + %input_val = xegpu.load_nd %input_tdesc[%m, %n] {layout = #map} : !xegpu.tensor_desc<256x256xf32, #map> -> vector<256x256xf32> + %result_val = math.exp %input_val fastmath {layout_result_0 = #map}: vector<256x256xf32> %result_tdesc = xegpu.create_nd_tdesc %result_gpu_with_fastmath : memref<256x256xf32> -> !xegpu.tensor_desc<256x256xf32, #map> - xegpu.store_nd %result_val, %result_tdesc[%m, %n] : vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #map> + xegpu.store_nd %result_val, %result_tdesc[%m, %n] {layout = #map} : vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #map> gpu.return } } diff --git a/test/Integration/Dialect/XeGPU/WG/simple_gemm.mlir b/test/Integration/Dialect/XeGPU/WG/simple_gemm.mlir index d46d6caa2..e3b011c21 100644 --- a/test/Integration/Dialect/XeGPU/WG/simple_gemm.mlir +++ b/test/Integration/Dialect/XeGPU/WG/simple_gemm.mlir @@ -51,34 +51,34 @@ module @gemm attributes {gpu.container_module} { %m = arith.muli %block_id_x, %c256 : index 
%n = arith.muli %block_id_y, %c256 : index %c_tdesc = xegpu.create_nd_tdesc %C : memref<256x256xf32> -> !xegpu.tensor_desc<256x256xf32, #c> - %c_init_value = xegpu.load_nd %c_tdesc[%m, %n] : !xegpu.tensor_desc<256x256xf32, #c> -> vector<256x256xf32> + %c_init_value = xegpu.load_nd %c_tdesc[%m, %n] {layout = #c}: !xegpu.tensor_desc<256x256xf32, #c> -> vector<256x256xf32> %a_tdesc = xegpu.create_nd_tdesc %A : memref<256x256xf16> -> !xegpu.tensor_desc<256x32xf16, #a> %b_tdesc = xegpu.create_nd_tdesc %B : memref<256x256xf16> -> !xegpu.tensor_desc<32x256xf16, #b> // Prefetch A 3 times. %a_prefetch_tdesc = xegpu.create_nd_tdesc %A : memref<256x256xf16> -> !xegpu.tensor_desc<256x32xf16, #a_prefetch> - xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c0] : !xegpu.tensor_desc<256x32xf16, #a_prefetch> - xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c32] : !xegpu.tensor_desc<256x32xf16, #a_prefetch> - xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c64] : !xegpu.tensor_desc<256x32xf16, #a_prefetch> + xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c0] {layout = #a_prefetch}: !xegpu.tensor_desc<256x32xf16, #a_prefetch> + xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c32] {layout = #a_prefetch}: !xegpu.tensor_desc<256x32xf16, #a_prefetch> + xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c64] {layout = #a_prefetch}: !xegpu.tensor_desc<256x32xf16, #a_prefetch> // Prefetch B 3 times. 
%b_prefetch_tdesc = xegpu.create_nd_tdesc %B : memref<256x256xf16> -> !xegpu.tensor_desc<32x256xf16, #b_prefetch> - xegpu.prefetch_nd %b_prefetch_tdesc[%c0, %n] : !xegpu.tensor_desc<32x256xf16, #b_prefetch> - xegpu.prefetch_nd %b_prefetch_tdesc[%c32, %n] : !xegpu.tensor_desc<32x256xf16, #b_prefetch> - xegpu.prefetch_nd %b_prefetch_tdesc[%c64, %n] : !xegpu.tensor_desc<32x256xf16, #b_prefetch> + xegpu.prefetch_nd %b_prefetch_tdesc[%c0, %n] {layout = #b_prefetch}: !xegpu.tensor_desc<32x256xf16, #b_prefetch> + xegpu.prefetch_nd %b_prefetch_tdesc[%c32, %n] {layout = #b_prefetch}: !xegpu.tensor_desc<32x256xf16, #b_prefetch> + xegpu.prefetch_nd %b_prefetch_tdesc[%c64, %n] {layout = #b_prefetch}: !xegpu.tensor_desc<32x256xf16, #b_prefetch> %out = scf.for %k = %c0 to %c256 step %c32 iter_args(%c_value = %c_init_value) -> (vector<256x256xf32>) { - %a_value = xegpu.load_nd %a_tdesc[%m, %k] : !xegpu.tensor_desc<256x32xf16, #a> -> vector<256x32xf16> - %b_value = xegpu.load_nd %b_tdesc[%k, %n] : !xegpu.tensor_desc<32x256xf16, #b> -> vector<32x256xf16> + %a_value = xegpu.load_nd %a_tdesc[%m, %k] {layout = #a}: !xegpu.tensor_desc<256x32xf16, #a> -> vector<256x32xf16> + %b_value = xegpu.load_nd %b_tdesc[%k, %n] {layout = #b}: !xegpu.tensor_desc<32x256xf16, #b> -> vector<32x256xf16> // Prefetch next tiles. 
%prefetch_offset = arith.addi %k, %c96 : index - xegpu.prefetch_nd %a_prefetch_tdesc[%m, %prefetch_offset] : !xegpu.tensor_desc<256x32xf16, #a_prefetch> - xegpu.prefetch_nd %b_prefetch_tdesc[%prefetch_offset, %n] : !xegpu.tensor_desc<32x256xf16, #b_prefetch> - %c_new_value = xegpu.dpas %a_value, %b_value, %c_value {layout_result_0 = #c} + xegpu.prefetch_nd %a_prefetch_tdesc[%m, %prefetch_offset] {layout = #a_prefetch}: !xegpu.tensor_desc<256x32xf16, #a_prefetch> + xegpu.prefetch_nd %b_prefetch_tdesc[%prefetch_offset, %n] {layout = #b_prefetch}: !xegpu.tensor_desc<32x256xf16, #b_prefetch> + %c_new_value = xegpu.dpas %a_value, %b_value, %c_value {layout_a = #a, layout_b = #b, layout_cd = #c} : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf32> -> vector<256x256xf32> scf.yield %c_new_value : vector<256x256xf32> } - xegpu.store_nd %out, %c_tdesc[%m, %n] : vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #c> + xegpu.store_nd %out, %c_tdesc[%m, %n] {layout = #c}: vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #c> gpu.return } } From 194f232f76a99fe40af4db77c500e35a27a2e02f Mon Sep 17 00:00:00 2001 From: Garra1980 Date: Wed, 14 Jan 2026 22:21:52 +0100 Subject: [PATCH 2/2] fix pre-commit --- ...upport-for-VectorAnyINTEL-capability.patch | 93 +++++++++---------- ...dd-SPIRV_ExecutionModeAttributesAttr.patch | 11 +-- ...8-length-vector-as-supported-vectors.patch | 27 +++--- build_tools/patches/wg_fa_support.patch | 15 ++- .../XeTile/Transforms/BlockingAnalysis.cpp | 6 +- lib/Transforms/VnniTransformation.cpp | 6 +- .../XeGPU/WG/gemm_4kx4kx4k_transpose_b.mlir | 4 +- 7 files changed, 79 insertions(+), 83 deletions(-) diff --git a/build_tools/patches/0001-Add-support-for-VectorAnyINTEL-capability.patch b/build_tools/patches/0001-Add-support-for-VectorAnyINTEL-capability.patch index 921285e5f..3b5d4eb98 100644 --- a/build_tools/patches/0001-Add-support-for-VectorAnyINTEL-capability.patch +++ 
b/build_tools/patches/0001-Add-support-for-VectorAnyINTEL-capability.patch @@ -45,11 +45,11 @@ index ecbbf39a534e..d6a72472bd1b 100644 // dialect-specific types so we use "Any" here. @@ -4293,7 +4300,7 @@ class SPIRV_MatrixOfType allowedTypes> : "Matrix">; - + class SPIRV_VectorOf : - FixedVectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>; + FixedVectorOfLengthRangeAndType<[2, 0xFFFFFFFF], [type]>; - + class SPIRV_ScalarOrVectorOf : AnyTypeOf<[type, SPIRV_VectorOf]>; diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -59,7 +59,7 @@ index 0fb4837e528b..a33b18e8c868 100644 @@ -696,6 +696,92 @@ class ScalableVectorOfRankAndLengthAndType allowedRanks, ScalableVectorOfLength.summary, "::mlir::VectorType">; - + +// Whether the number of elements of a vector is from the given +// `allowedRanges` list, the list has two values, start and end of the range (inclusive) +class IsVectorOfLengthRangePred allowedRanges> : @@ -174,7 +174,7 @@ index 53a48abe5ad0..4c39a7c83281 100644 +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -186,9 +186,10 @@ bool CompositeType::classof(Type type) { } - + bool CompositeType::isValid(VectorType type) { - return type.getRank() == 1 && - llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) && @@ -184,10 +184,10 @@ index 53a48abe5ad0..4c39a7c83281 100644 + type.getNumElements() >= 2 && + type.getNumElements() <= std::numeric_limits::max(); } - + Type CompositeType::getElementType(unsigned index) const { @@ -218,8 +219,23 @@ void TypeCapabilityVisitor::addConcrete(VectorType type) { - + int64_t vecSize = type.getNumElements(); if (vecSize == 8 || vecSize == 16) { - static constexpr auto cap = Capability::Vector16; @@ -211,7 +211,7 @@ index 53a48abe5ad0..4c39a7c83281 100644 + capabilities.push_back(ref); } } - + diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 816226749463..31c590efcda3 100644 --- 
a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -230,7 +230,7 @@ index 816226749463..31c590efcda3 100644 + return llvm::is_contained(ors, elidedExt); + })) continue; - + LLVM_DEBUG({ @@ -112,9 +116,13 @@ static LogicalResult checkExtensionRequirements( template @@ -246,12 +246,12 @@ index 816226749463..31c590efcda3 100644 + return llvm::is_contained(ors, elidedCap); + })) continue; - + LLVM_DEBUG({ @@ -131,6 +139,55 @@ static LogicalResult checkCapabilityRequirements( return success(); } - + +/// Check capabilities and extensions requirements, +/// this function also checks for capability infered extension requirements, +/// the check is based on capabilities that are passed to the targetEnv. @@ -307,20 +307,20 @@ index 816226749463..31c590efcda3 100644 @@ -284,11 +341,14 @@ convertScalarType(const spirv::TargetEnv &targetEnv, return nullptr; } - + + // Convert to 32-bit float and remove floatType related capability + // restriction if (auto floatType = dyn_cast(type)) { LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); return Builder(targetEnv.getContext()).getF32Type(); } - + + // Convert to 32-bit int and remove intType related capability restriction auto intType = cast(type); LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); return IntegerType::get(targetEnv.getContext(), /*width=*/32, @@ -402,10 +462,13 @@ convertVectorType(const spirv::TargetEnv &targetEnv, - + if (type.getRank() <= 1 && type.getNumElements() == 1) return elementType; - @@ -336,11 +336,11 @@ index 816226749463..31c590efcda3 100644 + "between [2 - 2^32 -1]\n"); return nullptr; } - + @@ -427,16 +490,40 @@ convertVectorType(const spirv::TargetEnv &targetEnv, cast(type).getExtensions(extensions, storageClass); cast(type).getCapabilities(capabilities, storageClass); - + - // If all requirements are met, then we can accept this type as-is. 
- if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) && - succeeded(checkExtensionRequirements(type, targetEnv, extensions))) @@ -383,7 +383,7 @@ index 816226749463..31c590efcda3 100644 + return nullptr; + } } - + static Type @@ -1694,16 +1781,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) { SmallVector, 4> typeExtensions; @@ -411,7 +411,7 @@ index 816226749463..31c590efcda3 100644 + op->getName(), this->targetEnv, typeCapabilities, typeExtensions))) return false; } - + diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir index 9d7ab2be096e..3aa22e261f7c 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir @@ -419,7 +419,7 @@ index 9d7ab2be096e..3aa22e261f7c 100644 @@ -28,9 +28,9 @@ module attributes { #spirv.vce, #spirv.resource_limits<>> } { - + -func.func @unsupported_5elem_vector(%arg0: vector<5xi32>) { +func.func @unsupported_5elem_vector(%arg0: vector<5xi32>, %arg1: vector<5xi32>) { // expected-error@+1 {{failed to legalize operation 'arith.subi'}} @@ -427,7 +427,7 @@ index 9d7ab2be096e..3aa22e261f7c 100644 + %1 = arith.subi %arg0, %arg1: vector<5xi32> return } - + diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir index 3cb529459899..e881d512bf2e 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir @@ -472,7 +472,7 @@ index 3cb529459899..e881d512bf2e 100644 + %1 = arith.mulf %arg1, %arg1: vector<8xf64> return } - + diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir index 0c77c8833457..d6628afb7329 100644 --- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir @@ -480,7 +480,7 @@ index 0c77c8833457..d6628afb7329 100644 @@ -347,8 +347,21 @@ module attributes { spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { - + -// 
CHECK-NOT: spirv.func @large_vector -func.func @large_vector(%arg0: vector<1024xi32>) { return } +// CHECK-NOT: spirv.func @large_vector_unsupported @@ -498,15 +498,15 @@ index 0c77c8833457..d6628afb7329 100644 + +// CHECK: spirv.func @large_any_vector +func.func @large_any_vector(%arg0: vector<1024xi32>) { return } - + } // end module - + diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir index c703274bda57..670edc9deb91 100644 --- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir @@ -349,6 +349,7 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 { - + func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 { // expected-error @+1 {{'spirv.Dot' op operand #0 must be fixed-length vector of 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16}} + // expected-error @+1 {{'spirv.Dot' op operand #0 must be fixed-length vector of 16/32/64-bit float or BFloat16 values of length 2-4294967295}} @@ -519,7 +519,7 @@ index 4bdac198a1e8..dee8c7f9a65e 100644 +++ b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir @@ -137,7 +137,7 @@ func.func @bitwise_or_all_ones_vector(%arg: vector<3xi8>) -> vector<3xi8> { // ----- - + func.func @bitwise_or_float(%arg0: f16, %arg1: f16) -> f16 { - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4}} + // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2-4294967295}} @@ -528,7 +528,7 @@ index 4bdac198a1e8..dee8c7f9a65e 100644 } @@ -165,7 +165,7 @@ func.func @bitwise_xor_vector(%arg: vector<4xi32>) -> vector<4xi32> { // ----- - + func.func @bitwise_xor_float(%arg0: f16, %arg1: f16) -> f16 { - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4}} + // 
expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2-9223372036854775807}} @@ -537,7 +537,7 @@ index 4bdac198a1e8..dee8c7f9a65e 100644 } @@ -274,7 +274,7 @@ func.func @bitwise_and_zext_vector(%arg: vector<2xi8>) -> vector<2xi32> { // ----- - + func.func @bitwise_and_float(%arg0: f16, %arg1: f16) -> f16 { - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4}} + // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2-9223372036854775807}} @@ -550,7 +550,7 @@ index fd8a2ffbbddf..011759101a74 100644 +++ b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir @@ -27,7 +27,7 @@ func.func @exp(%arg0 : i32) -> () { // ----- - + func.func @exp(%arg0 : vector<5xf32>) -> () { - // expected-error @+1 {{op operand #0 must be 16/32-bit float or fixed-length vector of 16/32-bit float values of length 2/3/4}} + // CHECK: spirv.GL.Exp {{%.*}} : vector<5xf32 @@ -563,7 +563,7 @@ index 2e2fb1a9df32..ad8a66e16745 100644 +++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir @@ -21,7 +21,7 @@ spirv.func @f32_to_bf16_vec(%arg0 : vector<2xf32>) "None" { // ----- - + spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" { - // expected-error @+1 {{operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got}} + // expected-error @+1 {{op operand #0 must be Float32 or fixed-length vector of Float32 values of length 2-4294967295, but got 'f64'}} @@ -571,7 +571,7 @@ index 2e2fb1a9df32..ad8a66e16745 100644 spirv.Return } @@ -58,6 +58,7 @@ spirv.func @bf16_to_f32_vec(%arg0 : vector<2xi16>) "None" { - + spirv.func @bf16_to_f32_unsupported(%arg0 : i16) "None" { // expected-error @+1 {{result #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got}} + // expected-error @+1 {{op result #0 must be Float32 
or fixed-length vector of Float32 values of length 2-4294967295, but got 'f16'}} @@ -583,7 +583,7 @@ index d7f4ed05969a..3acd5b88e42a 100644 --- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir @@ -184,7 +184,7 @@ func.func @logicalUnary(%arg0 : i1) - + func.func @logicalUnary(%arg0 : i32) { - // expected-error @+1 {{'operand' must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}} @@ -597,7 +597,7 @@ index bdb2abde8d8e..7b9b5d9a4688 100644 +++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir @@ -511,7 +511,7 @@ func.func @group_non_uniform_bitwise_and(%val: i32) -> i32 { // ----- - + func.func @group_non_uniform_bitwise_and(%val: i1) -> i1 { - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}} + // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2-4294967295, but got 'i1'}} @@ -606,7 +606,7 @@ index bdb2abde8d8e..7b9b5d9a4688 100644 } @@ -532,7 +532,7 @@ func.func @group_non_uniform_bitwise_or(%val: i32) -> i32 { // ----- - + func.func @group_non_uniform_bitwise_or(%val: i1) -> i1 { - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}} + // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2-4294967295, but got 'i1'}} @@ -615,7 +615,7 @@ index bdb2abde8d8e..7b9b5d9a4688 100644 } @@ -553,7 +553,7 @@ func.func @group_non_uniform_bitwise_xor(%val: i32) -> i32 { // ----- - + func.func @group_non_uniform_bitwise_xor(%val: i1) -> i1 { - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}} + // expected-error @+1 
{{op operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2-4294967295, but got 'i1'}} @@ -624,7 +624,7 @@ index bdb2abde8d8e..7b9b5d9a4688 100644 } @@ -574,7 +574,7 @@ func.func @group_non_uniform_logical_and(%val: i1) -> i1 { // ----- - + func.func @group_non_uniform_logical_and(%val: i32) -> i32 { - // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}} + // expected-error @+1 {{op operand #0 must be bool or fixed-length vector of bool values of length 2-4294967295, but got 'i32'}} @@ -633,7 +633,7 @@ index bdb2abde8d8e..7b9b5d9a4688 100644 } @@ -595,7 +595,7 @@ func.func @group_non_uniform_logical_or(%val: i1) -> i1 { // ----- - + func.func @group_non_uniform_logical_or(%val: i32) -> i32 { - // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}} + // expected-error @+1 {{op operand #0 must be bool or fixed-length vector of bool values of length 2-4294967295, but got 'i32'}} @@ -642,7 +642,7 @@ index bdb2abde8d8e..7b9b5d9a4688 100644 } @@ -616,7 +616,7 @@ func.func @group_non_uniform_logical_xor(%val: i1) -> i1 { // ----- - + func.func @group_non_uniform_logical_xor(%val: i32) -> i32 { - // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}} + // expected-error @+1 {{op operand #0 must be bool or fixed-length vector of bool values of length 2-4294967295, but got 'i32'}} @@ -655,7 +655,7 @@ index 6aaaa6012fef..60ef7afeeeed 100644 +++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir @@ -27,7 +27,7 @@ func.func @exp(%arg0 : i32) -> () { // ----- - + func.func @exp(%arg0 : vector<5xf32>) -> () { - // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4}} + // CHECK: spirv.CL.exp {{%.*}} : vector<5xf32> @@ -665,7 +665,7 @@ index 
6aaaa6012fef..60ef7afeeeed 100644 @@ -66,6 +66,14 @@ func.func @fabsvec(%arg0 : vector<3xf16>) -> () { return } - + +// ----- + +func.func @fabs_any_vec(%arg0 : vector<5xf32>) -> () { @@ -678,9 +678,9 @@ index 6aaaa6012fef..60ef7afeeeed 100644 // CHECK: spirv.CL.fabs {{%.*}} : f64 %2 = spirv.CL.fabs %arg0 : f64 @@ -82,14 +90,6 @@ func.func @fabs(%arg0 : i32) -> () { - + // ----- - + -func.func @fabs(%arg0 : vector<5xf32>) -> () { - // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4}} - %2 = spirv.CL.fabs %arg0 : vector<5xf32> @@ -695,7 +695,7 @@ index 6aaaa6012fef..60ef7afeeeed 100644 @@ -122,6 +122,14 @@ func.func @sabsvec(%arg0 : vector<3xi16>) -> () { return } - + +// ----- + +func.func @sabs_any_vec(%arg0 : vector<5xi32>) -> () { @@ -708,9 +708,9 @@ index 6aaaa6012fef..60ef7afeeeed 100644 // CHECK: spirv.CL.s_abs {{%.*}} : i64 %2 = spirv.CL.s_abs %arg0 : i64 @@ -144,14 +152,6 @@ func.func @sabs(%arg0 : f32) -> () { - + // ----- - + -func.func @sabs(%arg0 : vector<5xi32>) -> () { - // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4}} - %2 = spirv.CL.s_abs %arg0 : vector<5xi32> @@ -746,7 +746,7 @@ index 17accd93e824..ed9a9976e89b 100644 @@ -44,6 +44,12 @@ spirv.module Physical64 OpenCL requires #spirv.vce) "None" { + // CHECK: {{%.*}} = spirv.CL.fabs {{%.*}} : vector<5000xf32> + %0 = spirv.CL.fabs %arg0 : vector<5000xf32> @@ -756,6 +756,5 @@ index 17accd93e824..ed9a9976e89b 100644 spirv.func @fma(%arg0 : f32, %arg1 : f32, %arg2 : f32) "None" { // CHECK: spirv.CL.fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32 %13 = spirv.CL.fma %arg0, %arg1, %arg2 : f32 --- +-- 2.34.1 - diff --git a/build_tools/patches/0006-Add-SPIRV_ExecutionModeAttributesAttr.patch b/build_tools/patches/0006-Add-SPIRV_ExecutionModeAttributesAttr.patch index 08eb7615e..025296757 100644 --- 
a/build_tools/patches/0006-Add-SPIRV_ExecutionModeAttributesAttr.patch +++ b/build_tools/patches/0006-Add-SPIRV_ExecutionModeAttributesAttr.patch @@ -18,7 +18,7 @@ index 1bc3c63646fd..cf5b5ffa451d 100644 @@ -56,6 +56,17 @@ def SPIRV_LinkageAttributesAttr : SPIRV_Attr<"LinkageAttributes", "linkage_attri let assemblyFormat = "`<` struct(params) `>`"; } - + +// This attribute specifies SPIR-V execution mode information via GPU functions +// 1) Execution mode attribute. +// 2) [optional] Execution mode value. @@ -40,7 +40,7 @@ index 24574bfaf619..f64f99294038 100644 @@ -137,6 +137,14 @@ FailureOr getExecutionModel(TargetEnvAttr targetAttr); /// Returns failure if it cannot be selected. FailureOr getMemoryModel(TargetEnvAttr targetAttr); - + +/// Returns the attribute name for specifying execution mode attribute +/// information. +StringRef getExecutionModeFuncAttrName(); @@ -51,7 +51,7 @@ index 24574bfaf619..f64f99294038 100644 + } // namespace spirv } // namespace mlir - + diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index c33a903d0339..a6578465abac 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -76,7 +76,7 @@ index c33a903d0339..a6578465abac 100644 + return success(); } - + diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp index 8f02dd856d4e..dfb13c173596 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -118,6 +118,5 @@ index 8c52ba8b8583..864fc75c53b1 100644 + + return {}; +} --- +-- 2.34.1 - diff --git a/build_tools/patches/0014-Add-32-64-and-128-length-vector-as-supported-vectors.patch b/build_tools/patches/0014-Add-32-64-and-128-length-vector-as-supported-vectors.patch index a2dbdbf01..6f3bfa59d 100644 --- a/build_tools/patches/0014-Add-32-64-and-128-length-vector-as-supported-vectors.patch +++ b/build_tools/patches/0014-Add-32-64-and-128-length-vector-as-supported-vectors.patch @@ -14,7 +14,7 @@ index 
590182731b00..0d9f4df374f9 100644 @@ -50,6 +50,28 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { const LLT s64 = LLT::scalar(64); const LLT s128 = LLT::scalar(128); - + + // @IMEX, Add 32, 64 and 128 length vector as supported vectors. + // This is needed to support large loads/stores using OpenCL intrinsics in + // MLIR workflow. @@ -41,7 +41,7 @@ index 590182731b00..0d9f4df374f9 100644 const LLT v16s32 = LLT::fixed_vector(16, 32); const LLT v16s16 = LLT::fixed_vector(16, 16); @@ -100,16 +122,22 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { - + // TODO: remove copy-pasting here by using concatenation in some way. auto allPtrsScalarsAndVectors = { - p0, p1, p2, p3, p4, p5, p6, p7, p8, @@ -64,17 +64,17 @@ index 590182731b00..0d9f4df374f9 100644 + v16s8, v16s16, v16s32, v16s64, v32s1, v32s8, v32s16, + v32s32, v32s64, v64s1, v64s8, v64s16, v64s32, v64s64, + v128s1, v128s8, v128s16, v128s32, v128s64}; - + - auto allVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, - v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, - v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, - v16s8, v16s16, v16s32, v16s64}; - + auto allShaderVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, v3s32, v3s64, @@ -118,25 +146,32 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { auto allScalars = {s1, s8, s16, s32, s64}; - + auto allScalarsAndVectors = { - s1, s8, s16, s32, s64, s128, v2s1, v2s8, - v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, v3s32, v3s64, @@ -86,7 +86,7 @@ index 590182731b00..0d9f4df374f9 100644 + v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64, v32s1, + v32s8, v32s16, v32s32, v32s64, v64s1, v64s8, v64s16, v64s32, + v64s64, v128s1, v128s8, v128s16, v128s32, v128s64}; - + auto allIntScalarsAndVectors = { - s8, s16, s32, s64, s128, v2s8, v2s16, v2s32, v2s64, - v3s8, v3s16, v3s32, v3s64, v4s8, v4s16, v4s32, v4s64, v8s8, @@ -96,15 +96,15 @@ index 590182731b00..0d9f4df374f9 100644 + v8s8, v8s16, v8s32, v8s64, 
v16s8, v16s16, v16s32, v16s64, + v32s1, v32s8, v32s16, v32s32, v32s64, v64s1, v64s8, v64s16, + v64s32, v64s64, v128s1, v128s8, v128s16, v128s32, v128s64}; - + - auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1}; + auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, + v8s1, v16s1, v32s1, v64s1}; - + auto allIntScalars = {s8, s16, s32, s64, s128}; - + auto allFloatScalarsAndF16Vector2AndVector4s = {s16, s32, s64, v2s16, v4s16}; - + auto allFloatScalarsAndVectors = { - s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64, - v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64}; @@ -112,7 +112,7 @@ index 590182731b00..0d9f4df374f9 100644 + v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64, + v32s1, v32s8, v32s16, v32s32, v32s64, v64s1, v64s8, v64s16, v64s32, + v64s64, v128s1, v128s8, v128s16, v128s32, v128s64}; - + auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1, p2, p3, p4, p5, p6, p7, @@ -174,7 +209,9 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { @@ -124,7 +124,7 @@ index 590182731b00..0d9f4df374f9 100644 + // @IMEX, make the max vector size to be 128 for now. + uint32_t MaxVectorSize = ST.isShader() ? 
4 : 128; LLVM_DEBUG(dbgs() << "MaxVectorSize: " << MaxVectorSize << "\n"); - + for (auto Opc : getTypeFoldingSupportedOpcodes()) { @@ -579,7 +616,10 @@ static bool needsVectorLegalization(const LLT &Ty, const SPIRVSubtarget &ST) { if (!Ty.isVector()) @@ -138,6 +138,5 @@ index 590182731b00..0d9f4df374f9 100644 return (NumElements > 4 && !isPowerOf2_32(NumElements)) || NumElements > MaxVectorSize; } --- +-- 2.34.1 - diff --git a/build_tools/patches/wg_fa_support.patch b/build_tools/patches/wg_fa_support.patch index 9e9c8e5ab..39508836a 100644 --- a/build_tools/patches/wg_fa_support.patch +++ b/build_tools/patches/wg_fa_support.patch @@ -16,11 +16,11 @@ index ba2753f517ce..dc4f05c0d914 100644 @@ -228,7 +228,7 @@ XeGPUBlockingPass::getTileShape(Operation *op) const { if (isa(op)) return getTileShape(op->getOpOperand(0)); - + - if (isa(op)) + if (isa(op)) return getTileShape(op->getOpResult(0)); - + return std::nullopt; @@ -415,14 +415,14 @@ void XeGPUBlockingPass::runOnOperation() { // Remove the layout attributes cached per operands. @@ -30,7 +30,7 @@ index ba2753f517ce..dc4f05c0d914 100644 + if (op->hasAttrOfType(name)) op->removeAttr(name); } - + // Update the layout attributes per result. 
for (OpResult result : op->getOpResults()) { std::string name = xegpu::getTemporaryLayoutName(result); @@ -46,7 +46,7 @@ index 7fc75e7294ea..94f31e653511 100644 @@ -1336,6 +1336,60 @@ static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder, return success(); } - + +static LogicalResult resolveConflicts(Operation *op) { + auto r = op->walk([&](xegpu::LoadNdOp loadNdOp) -> WalkResult { + // Load op has a conflict if tensor desc layout is different from the its @@ -126,7 +126,7 @@ index 07572a495076..448d78f4dc4f 100644 - DenseI32ArrayAttr resultOrder = layout.getOrder(); + // DenseI32ArrayAttr sourceOrder = sourceLayout.getOrder(); + // DenseI32ArrayAttr resultOrder = layout.getOrder(); - + - if (!sourceOrder || !resultOrder) { - return rewriter.notifyMatchFailure( - op, "Both source and result must have order attributes"); @@ -135,9 +135,8 @@ index 07572a495076..448d78f4dc4f 100644 + // return rewriter.notifyMatchFailure( + // op, "Both source and result must have order attributes"); + // } - + ArrayRef permutation = op.getPermutation(); size_t permutationSize = permutation.size(); --- +-- 2.34.1 - diff --git a/lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp b/lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp index 600cc484c..90b12ce3c 100644 --- a/lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp +++ b/lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp @@ -282,9 +282,9 @@ class BlockingAnalysisImpl void visitBranchOperand(mlir::OpOperand &operand) override {} - void - visitNonControlFlowArguments(mlir::RegionSuccessor &successor, - mlir::ArrayRef arguments) override {}; + void visitNonControlFlowArguments( + mlir::RegionSuccessor &successor, + mlir::ArrayRef arguments) override{}; void visitCallOperand(mlir::OpOperand &operand) override {} diff --git a/lib/Transforms/VnniTransformation.cpp b/lib/Transforms/VnniTransformation.cpp index 2b48f9cbd..b87e9b6c7 100644 --- a/lib/Transforms/VnniTransformation.cpp +++ 
b/lib/Transforms/VnniTransformation.cpp @@ -231,9 +231,9 @@ class LayoutAnalysisImpl void visitCallOperand(mlir::OpOperand &operand) override {} - void - visitNonControlFlowArguments(mlir::RegionSuccessor &successor, - mlir::ArrayRef arguments) override {}; + void visitNonControlFlowArguments( + mlir::RegionSuccessor &successor, + mlir::ArrayRef arguments) override{}; void setToExitState(LayoutLattice *lattice) override { (void)lattice->meet(false); diff --git a/test/Integration/Dialect/XeGPU/WG/gemm_4kx4kx4k_transpose_b.mlir b/test/Integration/Dialect/XeGPU/WG/gemm_4kx4kx4k_transpose_b.mlir index 5f90499af..805fefb40 100644 --- a/test/Integration/Dialect/XeGPU/WG/gemm_4kx4kx4k_transpose_b.mlir +++ b/test/Integration/Dialect/XeGPU/WG/gemm_4kx4kx4k_transpose_b.mlir @@ -75,8 +75,8 @@ module @gemm attributes {gpu.container_module} { %b_transposed = vector.transpose %b_value, [1, 0] {layout_result_0 = #b} : vector<256x32xf16> to vector<32x256xf16> %c_new_value = xegpu.dpas %a_value, %b_transposed, %c_value {layout_a = #a, layout_b = #b, layout_cd = #c} : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf32> -> vector<256x256xf32> - scf.yield %c_new_value : vector<256x256xf32> - } + scf.yield %c_new_value : vector<256x256xf32> + } xegpu.store_nd %out, %c_tdesc[%m, %n] {layout = #c}: vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #c> gpu.return }