From 3285a60770f2676a81867ff52aa3caaa615ed8aa Mon Sep 17 00:00:00 2001 From: Vladimir Shiryaev Date: Thu, 30 Apr 2026 20:40:20 -0700 Subject: [PATCH] [mlir][dxsa] Add dcl_thread_group instruction Example: dxsa.dcl_thread_group Signed-off-by: Vladimir Shiryaev --- mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td | 24 ++++++++++++ mlir/lib/Dialect/DXSA/IR/DXSA.cpp | 13 +++++++ mlir/lib/Target/DXSA/BinaryParser.cpp | 20 ++++++++++ mlir/test/Target/DXSA/dcl_thread_group.mlir | 8 ++++ .../Target/DXSA/dcl_thread_group_invalid.mlir | 35 ++++++++++++++++++ .../Target/DXSA/inputs/dcl_thread_group.bin | Bin 0 -> 64 bytes 6 files changed, 100 insertions(+) create mode 100644 mlir/test/Target/DXSA/dcl_thread_group.mlir create mode 100644 mlir/test/Target/DXSA/dcl_thread_group_invalid.mlir create mode 100644 mlir/test/Target/DXSA/inputs/dcl_thread_group.bin diff --git a/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td index ed5a295a38db..9bb4cf55ed7f 100644 --- a/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td +++ b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td @@ -719,4 +719,28 @@ def DXSA_DclOutput : DXSA_Op<"dcl_output"> { let assemblyFormat = "$operand attr-dict"; } +def DXSA_DclThreadGroup : DXSA_Op<"dcl_thread_group"> { + let summary = "declares compute shader thread group dimensions"; + let description = [{ + The `dxsa.dcl_thread_group` operation declares the `$x`, `$y` and `$z` + dimensions of a compute shader thread group. + + The product `$x` * `$y` * `$z` must not exceed 1024. + + Example: + + ```mlir + dxsa.dcl_thread_group + ``` + }]; + let arguments = (ins + ConfinedAttr]>:$x, + ConfinedAttr]>:$y, + ConfinedAttr]>:$z); + let assemblyFormat = [{ + `<` `x` `=` $x `,` `y` `=` $y `,` `z` `=` $z `>` attr-dict + }]; + let hasVerifier = 1; +} + #endif // DXSA_OPS diff --git a/mlir/lib/Dialect/DXSA/IR/DXSA.cpp b/mlir/lib/Dialect/DXSA/IR/DXSA.cpp index e5c94051e3a5..f49027a07fb8 100644 --- a/mlir/lib/Dialect/DXSA/IR/DXSA.cpp +++ b/mlir/lib/Dialect/DXSA/IR/DXSA.cpp @@ -34,6 +34,19 @@ void DXSADialect::initialize() { >(); } +//===----------------------------------------------------------------------===// +// DclThreadGroup +//===----------------------------------------------------------------------===// + +LogicalResult DclThreadGroup::verify() { + constexpr int64_t maxTotalThreads = 1024; + if (auto total = static_cast(getX()) * getY() * getZ(); + total > maxTotalThreads) + return emitOpError("thread group size x*y*z must be <= ") + << maxTotalThreads << ", got " << total; + return success(); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/DXSA/BinaryParser.cpp b/mlir/lib/Target/DXSA/BinaryParser.cpp index f1c860c758a2..225a8cd17ffd 100644 --- a/mlir/lib/Target/DXSA/BinaryParser.cpp +++ b/mlir/lib/Target/DXSA/BinaryParser.cpp @@ -635,6 +635,13 @@ class DXBuilder { return dxsa::DclOutput::create(builder, loc, operand); } + Instruction buildDclThreadGroup(uint32_t x, uint32_t y, uint32_t z, + Location loc) { + return dxsa::DclThreadGroup::create( + builder, loc, builder.getI32IntegerAttr(x), + builder.getI32IntegerAttr(y), builder.getI32IntegerAttr(z)); + } + private: MLIRContext *context; ModuleOp module; @@ -1174,6 +1181,16 @@ class Parser { return builder.buildDclOutput(*operand, loc); } + FailureOr parseDclThreadGroup(Location loc) { + auto x = parseToken(); + FAILURE_IF_FAILED(x); + auto y = parseToken(); + FAILURE_IF_FAILED(y); + auto z = parseToken(); + FAILURE_IF_FAILED(z); + return builder.buildDclThreadGroup(*x, *y, *z, loc); + } + OptionalParseResult parseDclInstruction(uint32_t opcodeToken, Location loc, Instruction &out) { FailureOr result; @@ -1220,6 +1237,9 @@ class Parser { case D3D10_SB_OPCODE_DCL_OUTPUT: result = parseDclOutput(loc); break; + case D3D11_SB_OPCODE_DCL_THREAD_GROUP: + result = parseDclThreadGroup(loc); + break; default: return std::nullopt; } diff --git a/mlir/test/Target/DXSA/dcl_thread_group.mlir b/mlir/test/Target/DXSA/dcl_thread_group.mlir new file mode 100644 index 000000000000..ef296ba4130b --- /dev/null +++ b/mlir/test/Target/DXSA/dcl_thread_group.mlir @@ -0,0 +1,8 @@ +// RUN: mlir-translate --import-dxsa-bin %S/inputs/dcl_thread_group.bin | FileCheck %s + +// CHECK: module { +// CHECK-NEXT: dxsa.dcl_thread_group +// CHECK-NEXT: dxsa.dcl_thread_group +// CHECK-NEXT: dxsa.dcl_thread_group +// CHECK-NEXT: dxsa.dcl_thread_group +// CHECK-NEXT: } diff --git a/mlir/test/Target/DXSA/dcl_thread_group_invalid.mlir b/mlir/test/Target/DXSA/dcl_thread_group_invalid.mlir new file mode 100644 index 000000000000..17b1be8587aa --- /dev/null +++ b/mlir/test/Target/DXSA/dcl_thread_group_invalid.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +// expected-error@+1 {{attribute 'x' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 1024}} +dxsa.dcl_thread_group + +// ----- + +// expected-error@+1 {{attribute 'x' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 1024}} +dxsa.dcl_thread_group + +// ----- + +// expected-error@+1 {{attribute 'y' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 1024}} +dxsa.dcl_thread_group + +// ----- + +// expected-error@+1 {{attribute 'y' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 1024}} +dxsa.dcl_thread_group + +// ----- + +// expected-error@+1 {{attribute 'z' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 64}} +dxsa.dcl_thread_group + +// ----- + +// expected-error@+1 {{attribute 'z' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 64}} +dxsa.dcl_thread_group + +// ----- + +// 64 * 8 * 4 == 2048 +// expected-error@+1 {{'dxsa.dcl_thread_group' op thread group size x*y*z must be <= 1024, got 2048}} +dxsa.dcl_thread_group diff --git a/mlir/test/Target/DXSA/inputs/dcl_thread_group.bin b/mlir/test/Target/DXSA/inputs/dcl_thread_group.bin new file mode 100644 index 0000000000000000000000000000000000000000..f478f3ed7a8d6a69f20d026f2f5eaad41604beba GIT binary patch literal 64 hcmbQuz`(-Dz`y{*vw?gdMB#(vU}7M7m|6#*FaTz-0?q&c literal 0 HcmV?d00001