Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -719,4 +719,28 @@ def DXSA_DclOutput : DXSA_Op<"dcl_output"> {
let assemblyFormat = "$operand attr-dict";
}

def DXSA_DclThreadGroup : DXSA_Op<"dcl_thread_group"> {
let summary = "declares compute shader thread group dimensions";
let description = [{
The `dxsa.dcl_thread_group` operation declares the `$x`, `$y` and `$z`
dimensions of a compute shader thread group.

The product `$x` * `$y` * `$z` must not exceed 1024.

Example:

```mlir
dxsa.dcl_thread_group<x = 1, y = 1, z = 1>
```
}];
let arguments = (ins
ConfinedAttr<I32Attr, [IntPositive, IntMaxValue<1024>]>:$x,
ConfinedAttr<I32Attr, [IntPositive, IntMaxValue<1024>]>:$y,
ConfinedAttr<I32Attr, [IntPositive, IntMaxValue<64>]>:$z);
let assemblyFormat = [{
`<` `x` `=` $x `,` `y` `=` $y `,` `z` `=` $z `>` attr-dict
}];
let hasVerifier = 1;
}

#endif // DXSA_OPS
13 changes: 13 additions & 0 deletions mlir/lib/Dialect/DXSA/IR/DXSA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,19 @@ void DXSADialect::initialize() {
>();
}

//===----------------------------------------------------------------------===//
// DclThreadGroup
//===----------------------------------------------------------------------===//

LogicalResult DclThreadGroup::verify() {
constexpr int64_t maxTotalThreads = 1024;
if (auto total = static_cast<int64_t>(getX()) * getY() * getZ();
total > maxTotalThreads)
return emitOpError("thread group size x*y*z must be <= ")
<< maxTotalThreads << ", got " << total;
return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down
20 changes: 20 additions & 0 deletions mlir/lib/Target/DXSA/BinaryParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,13 @@ class DXBuilder {
return dxsa::DclOutput::create(builder, loc, operand);
}

Instruction buildDclThreadGroup(uint32_t x, uint32_t y, uint32_t z,
Location loc) {
return dxsa::DclThreadGroup::create(
builder, loc, builder.getI32IntegerAttr(x),
builder.getI32IntegerAttr(y), builder.getI32IntegerAttr(z));
}

private:
MLIRContext *context;
ModuleOp module;
Expand Down Expand Up @@ -1174,6 +1181,16 @@ class Parser {
return builder.buildDclOutput(*operand, loc);
}

FailureOr<Instruction> parseDclThreadGroup(Location loc) {
auto x = parseToken();
FAILURE_IF_FAILED(x);
auto y = parseToken();
FAILURE_IF_FAILED(y);
auto z = parseToken();
FAILURE_IF_FAILED(z);
return builder.buildDclThreadGroup(*x, *y, *z, loc);
}

OptionalParseResult parseDclInstruction(uint32_t opcodeToken, Location loc,
Instruction &out) {
FailureOr<Instruction> result;
Expand Down Expand Up @@ -1220,6 +1237,9 @@ class Parser {
case D3D10_SB_OPCODE_DCL_OUTPUT:
result = parseDclOutput(loc);
break;
case D3D11_SB_OPCODE_DCL_THREAD_GROUP:
result = parseDclThreadGroup(loc);
break;
default:
return std::nullopt;
}
Expand Down
8 changes: 8 additions & 0 deletions mlir/test/Target/DXSA/dcl_thread_group.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// RUN: mlir-translate --import-dxsa-bin %S/inputs/dcl_thread_group.bin | FileCheck %s

// CHECK: module {
// CHECK-NEXT: dxsa.dcl_thread_group<x = 1, y = 1, z = 1>
// CHECK-NEXT: dxsa.dcl_thread_group<x = 1024, y = 1, z = 1>
// CHECK-NEXT: dxsa.dcl_thread_group<x = 1, y = 1024, z = 1>
// CHECK-NEXT: dxsa.dcl_thread_group<x = 1, y = 1, z = 64>
// CHECK-NEXT: }
35 changes: 35 additions & 0 deletions mlir/test/Target/DXSA/dcl_thread_group_invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics

// expected-error@+1 {{attribute 'x' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 1024}}
dxsa.dcl_thread_group <x = 0, y = 1, z = 1>

// -----

// expected-error@+1 {{attribute 'x' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 1024}}
dxsa.dcl_thread_group <x = 1025, y = 1, z = 1>

// -----

// expected-error@+1 {{attribute 'y' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 1024}}
dxsa.dcl_thread_group <x = 1, y = 0, z = 1>

// -----

// expected-error@+1 {{attribute 'y' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 1024}}
dxsa.dcl_thread_group <x = 1, y = 1025, z = 1>

// -----

// expected-error@+1 {{attribute 'z' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 64}}
dxsa.dcl_thread_group <x = 1, y = 1, z = 0>

// -----

// expected-error@+1 {{attribute 'z' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 64}}
dxsa.dcl_thread_group <x = 1, y = 1, z = 65>

// -----

// 64 * 8 * 4 == 2048
// expected-error@+1 {{'dxsa.dcl_thread_group' op thread group size x*y*z must be <= 1024, got 2048}}
dxsa.dcl_thread_group <x = 64, y = 8, z = 4>
Binary file added mlir/test/Target/DXSA/inputs/dcl_thread_group.bin
Binary file not shown.