diff --git a/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td index ed5a295a38db..9bb4cf55ed7f 100644 --- a/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td +++ b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td @@ -719,4 +719,28 @@ def DXSA_DclOutput : DXSA_Op<"dcl_output"> { let assemblyFormat = "$operand attr-dict"; } +def DXSA_DclThreadGroup : DXSA_Op<"dcl_thread_group"> { + let summary = "declares compute shader thread group dimensions"; + let description = [{ + The `dxsa.dcl_thread_group` operation declares the `$x`, `$y` and `$z` + dimensions of a compute shader thread group. + + The product `$x` * `$y` * `$z` must not exceed 1024. + + Example: + + ```mlir + dxsa.dcl_thread_group + ``` + }]; + let arguments = (ins + ConfinedAttr]>:$x, + ConfinedAttr]>:$y, + ConfinedAttr]>:$z); + let assemblyFormat = [{ + `<` `x` `=` $x `,` `y` `=` $y `,` `z` `=` $z `>` attr-dict + }]; + let hasVerifier = 1; +} + #endif // DXSA_OPS diff --git a/mlir/lib/Dialect/DXSA/IR/DXSA.cpp b/mlir/lib/Dialect/DXSA/IR/DXSA.cpp index e5c94051e3a5..f49027a07fb8 100644 --- a/mlir/lib/Dialect/DXSA/IR/DXSA.cpp +++ b/mlir/lib/Dialect/DXSA/IR/DXSA.cpp @@ -34,6 +34,19 @@ void DXSADialect::initialize() { >(); } +//===----------------------------------------------------------------------===// +// DclThreadGroup +//===----------------------------------------------------------------------===// + +LogicalResult DclThreadGroup::verify() { + constexpr int64_t maxTotalThreads = 1024; + if (auto total = static_cast(getX()) * getY() * getZ(); + total > maxTotalThreads) + return emitOpError("thread group size x*y*z must be <= ") + << maxTotalThreads << ", got " << total; + return success(); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/DXSA/BinaryParser.cpp b/mlir/lib/Target/DXSA/BinaryParser.cpp index f1c860c758a2..225a8cd17ffd 100644 --- a/mlir/lib/Target/DXSA/BinaryParser.cpp +++ b/mlir/lib/Target/DXSA/BinaryParser.cpp @@ -635,6 +635,13 @@ class DXBuilder { return dxsa::DclOutput::create(builder, loc, operand); } + Instruction buildDclThreadGroup(uint32_t x, uint32_t y, uint32_t z, + Location loc) { + return dxsa::DclThreadGroup::create( + builder, loc, builder.getI32IntegerAttr(x), + builder.getI32IntegerAttr(y), builder.getI32IntegerAttr(z)); + } + private: MLIRContext *context; ModuleOp module; @@ -1174,6 +1181,16 @@ class Parser { return builder.buildDclOutput(*operand, loc); } + FailureOr parseDclThreadGroup(Location loc) { + auto x = parseToken(); + FAILURE_IF_FAILED(x); + auto y = parseToken(); + FAILURE_IF_FAILED(y); + auto z = parseToken(); + FAILURE_IF_FAILED(z); + return builder.buildDclThreadGroup(*x, *y, *z, loc); + } + OptionalParseResult parseDclInstruction(uint32_t opcodeToken, Location loc, Instruction &out) { FailureOr result; @@ -1220,6 +1237,9 @@ class Parser { case D3D10_SB_OPCODE_DCL_OUTPUT: result = parseDclOutput(loc); break; + case D3D11_SB_OPCODE_DCL_THREAD_GROUP: + result = parseDclThreadGroup(loc); + break; default: return std::nullopt; } diff --git a/mlir/test/Target/DXSA/dcl_thread_group.mlir b/mlir/test/Target/DXSA/dcl_thread_group.mlir new file mode 100644 index 000000000000..ef296ba4130b --- /dev/null +++ b/mlir/test/Target/DXSA/dcl_thread_group.mlir @@ -0,0 +1,8 @@ +// RUN: mlir-translate --import-dxsa-bin %S/inputs/dcl_thread_group.bin | FileCheck %s + +// CHECK: module { +// CHECK-NEXT: dxsa.dcl_thread_group +// CHECK-NEXT: dxsa.dcl_thread_group +// CHECK-NEXT: dxsa.dcl_thread_group +// CHECK-NEXT: dxsa.dcl_thread_group +// CHECK-NEXT: } diff --git a/mlir/test/Target/DXSA/dcl_thread_group_invalid.mlir b/mlir/test/Target/DXSA/dcl_thread_group_invalid.mlir new file mode 100644 index 000000000000..17b1be8587aa --- /dev/null +++ b/mlir/test/Target/DXSA/dcl_thread_group_invalid.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +// expected-error@+1 {{attribute 'x' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 1024}} +dxsa.dcl_thread_group + +// ----- + +// expected-error@+1 {{attribute 'x' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 1024}} +dxsa.dcl_thread_group + +// ----- + +// expected-error@+1 {{attribute 'y' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 1024}} +dxsa.dcl_thread_group + +// ----- + +// expected-error@+1 {{attribute 'y' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 1024}} +dxsa.dcl_thread_group + +// ----- + +// expected-error@+1 {{attribute 'z' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 64}} +dxsa.dcl_thread_group + +// ----- + +// expected-error@+1 {{attribute 'z' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 64}} +dxsa.dcl_thread_group + +// ----- + +// 64 * 8 * 4 == 2048 +// expected-error@+1 {{'dxsa.dcl_thread_group' op thread group size x*y*z must be <= 1024, got 2048}} +dxsa.dcl_thread_group diff --git a/mlir/test/Target/DXSA/inputs/dcl_thread_group.bin b/mlir/test/Target/DXSA/inputs/dcl_thread_group.bin new file mode 100644 index 000000000000..f478f3ed7a8d Binary files /dev/null and b/mlir/test/Target/DXSA/inputs/dcl_thread_group.bin differ