Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/llvm-dialects/Dialect/Dialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,10 @@ class DialectType<Dialect dialect_, string mnemonic_> : Type, Predicate {
/// discriminant; the discriminant should be a type that cannot naturally
/// appear elsewhere, e.g. (repr_struct (IntegerType 41))
dag representation = (repr_targetext);

/// Optional custom name for the LLVM IR type. If not set, defaults to
/// "<dialect>.<mnemonic>".
string llvmTypeNameOverride = "";
}

def and;
Expand Down
4 changes: 4 additions & 0 deletions include/llvm-dialects/TableGen/DialectType.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ class DialectType : public BaseCppPredicate {
RecordTy *getDialectRec() const { return m_dialectRec; }
llvm::StringRef getName() const { return m_name; }
llvm::StringRef getMnemonic() const { return m_mnemonic; }
llvm::StringRef getLlvmTypeNameOverride() const {
return m_llvmTypeNameOverride;
}
bool defaultGetterHasExplicitContextArgument() const {
return m_defaultGetterHasExplicitContextArgument;
}
Expand All @@ -64,6 +67,7 @@ class DialectType : public BaseCppPredicate {
RecordTy *m_dialectRec = nullptr;
std::string m_name;
std::string m_mnemonic;
std::string m_llvmTypeNameOverride;
bool m_defaultGetterHasExplicitContextArgument = false;
std::string m_summary;
std::string m_description;
Expand Down
29 changes: 26 additions & 3 deletions lib/TableGen/DialectType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ bool DialectType::init(raw_ostream &errs, GenDialectsContext &context,
m_mnemonic = record->getValueAsString("mnemonic");
m_summary = record->getValueAsString("summary");
m_description = record->getValueAsString("description");
m_llvmTypeNameOverride = record->getValueAsString("llvmTypeNameOverride");

if (auto *dag =
cast<DagInit>(record->getValue("representation")->getValue())) {
Expand Down Expand Up @@ -170,12 +171,25 @@ void DialectType::emitDeclaration(raw_ostream &out, GenDialect *dialect) const {
fmt.addSubst("_type", getName());
fmt.addSubst("mnemonic", getMnemonic());

std::string typeName;
if (!m_llvmTypeNameOverride.empty()) {
typeName = m_llvmTypeNameOverride;
} else {
typeName = (dialect->name + "." + getMnemonic()).str();
}
// For struct-backed types, add trailing dot for the prefix
if (m_structBacked) {
typeName += ".";
}

fmt.addSubst("typeName", typeName);

if (m_structBacked) {
out << tgfmt(R"(
class $_type : public ::llvm::StructType {
using ::llvm::StructType::StructType;
public:
static constexpr ::llvm::StringLiteral s_prefix{"$dialect.$mnemonic."};
static constexpr ::llvm::StringLiteral s_prefix{"$typeName"};

using ::llvm::StructType::getElementType;

Expand Down Expand Up @@ -225,10 +239,11 @@ void DialectType::emitDeclaration(raw_ostream &out, GenDialect *dialect) const {
} else {

// TargetExtType
fmt.addSubst("typeName", typeName);

out << tgfmt(R"(
class $_type : public ::llvm::TargetExtType {
static constexpr ::llvm::StringLiteral s_name{"$dialect.$mnemonic"};
static constexpr ::llvm::StringLiteral s_name{"$typeName"};

public:
static bool classof(const ::llvm::TargetExtType *t) {
Expand Down Expand Up @@ -276,6 +291,14 @@ void DialectType::emitDefinition(raw_ostream &out, GenDialect *dialect) const {
fmt.addSubst("fields", symbols.chooseName("fields"));
fmt.addSubst("st", symbols.chooseName("st"));

// Compute the type name prefix (without trailing dot for struct-backed types)
std::string typeNamePrefix;
if (!m_llvmTypeNameOverride.empty()) {
typeNamePrefix = m_llvmTypeNameOverride;
} else {
typeNamePrefix = (dialect->name + "." + getMnemonic()).str();
}

if (m_structBacked) {
out << tgfmt("$_type* $_type::get(", &fmt);
bool contextPresent =
Expand Down Expand Up @@ -305,7 +328,7 @@ void DialectType::emitDefinition(raw_ostream &out, GenDialect *dialect) const {

out << tgfmt(
" std::string $name; ::llvm::raw_string_ostream $os($name);\n", &fmt);
out << tgfmt(" $os << \"$0\";\n", &fmt, m_mnemonic);
out << tgfmt(" $os << \"$0\";\n", &fmt, typeNamePrefix);
for (const auto &getterArg : getterArgs)
out << tgfmt(" $os << '.' << (uint64_t)$0;\n", &fmt, getterArg.name);

Expand Down
2 changes: 1 addition & 1 deletion test/example/generated/ExampleDialect.cpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ StructBackedType* StructBackedType::get(::llvm::LLVMContext & ctx, uint32_t fiel

static_assert(sizeof(field2) <= sizeof(unsigned));
std::string name; ::llvm::raw_string_ostream os(name);
os << "struct.backed";
os << "xd.ir.struct.backed";
os << '.' << (uint64_t)field0;
os << '.' << (uint64_t)field1;
os << '.' << (uint64_t)field2;
Expand Down
53 changes: 53 additions & 0 deletions test/unit/dialect/TestDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,56 @@ def InstNameConflictWithExplRetTestOp : Op<TestDialect, "try.conflict.type.suffi
value like IRBuilder methods. This op tries to produce a conflict
}];
}

// Test types with custom llvmTypeNameOverride

// Test TargetExtType with custom name
def CustomNameTargetExtType : DialectType<TestDialect, "custom.target"> {
let typeArguments = (args
AttrI32:$param1,
AttrI32:$param2
);

let defaultGetterHasExplicitContextArgument = true;
let llvmTypeNameOverride = "custom.renamed.target";

let summary = "TargetExtType with custom name";
let description = [{
Tests that llvmTypeNameOverride works for TargetExtType.
Type name should be "custom.renamed.target" instead of "test.custom.target".
}];
}

// Test struct-backed type with custom name prefix
def CustomNameStructType : DialectType<TestDialect, "custom.struct"> {
let typeArguments = (args
AttrI32:$rows,
AttrI32:$cols
);

let defaultGetterHasExplicitContextArgument = true;
let representation = (repr_struct (IntegerType 73));
let llvmTypeNameOverride = "custom.renamed.struct";

let summary = "Struct-backed type with custom name prefix";
let description = [{
Tests that llvmTypeNameOverride works for struct-backed types.
Type name prefix should be "custom.renamed.struct." instead of "test.custom.struct.".
The trailing dot is automatically added for struct types.
}];
}

// Test TargetExtType without custom name (default behavior)
def DefaultNameTargetExtType : DialectType<TestDialect, "default.target"> {
let typeArguments = (args
AttrI32:$value
);

let defaultGetterHasExplicitContextArgument = true;

let summary = "TargetExtType with default name";
let description = [{
Tests default naming behavior for TargetExtType.
Type name should be "test.default.target".
}];
}
3 changes: 2 additions & 1 deletion test/unit/interface/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
add_dialects_unit_test(DialectsADTTests
OpSetTests.cpp
OpMapTests.cpp
OpMapIRTests.cpp)
OpMapIRTests.cpp
DialectTypeTests.cpp)

add_dependencies(DialectsADTTests TestDialectTableGen)
114 changes: 114 additions & 0 deletions test/unit/interface/DialectTypeTests.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
***********************************************************************************************************************
* Copyright (c) 2026 Advanced Micro Devices, Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
***********************************************************************************************************************
*/

#include "TestDialect.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Type.h"
#include "gtest/gtest.h"

using namespace llvm;
using namespace test;

// Test custom llvmTypeNameOverride for TargetExtType
TEST(DialectTypeTest, CustomNameTargetExtType) {
LLVMContext ctx;

auto *type = CustomNameTargetExtType::get(ctx, 10, 20);
ASSERT_NE(type, nullptr);

// Verify it's recognized as the correct type via classof
EXPECT_TRUE(isa<CustomNameTargetExtType>(type));
EXPECT_TRUE(CustomNameTargetExtType::classof(type));

// Verify the type name is the custom one, not the default
auto *targetExtType = cast<TargetExtType>(type);
EXPECT_EQ(targetExtType->getName(), "custom.renamed.target");

// Verify parameters are accessible
EXPECT_EQ(type->getParam1(), 10u);
EXPECT_EQ(type->getParam2(), 20u);
}

// Test custom llvmTypeNameOverride for struct-backed type
TEST(DialectTypeTest, CustomNameStructType) {
LLVMContext ctx;

auto *type = CustomNameStructType::get(ctx, 16, 16);
ASSERT_NE(type, nullptr);

// Verify it's recognized as the correct type via classof
EXPECT_TRUE(isa<CustomNameStructType>(type));
EXPECT_TRUE(CustomNameStructType::classof(type));

// Verify the struct type name has the custom prefix
auto *structType = cast<StructType>(type);
EXPECT_TRUE(structType->getName().starts_with("custom.renamed.struct."));

// Verify parameters are accessible
EXPECT_EQ(type->getRows(), 16u);
EXPECT_EQ(type->getCols(), 16u);
}

// Test default naming behavior (no llvmTypeNameOverride specified)
TEST(DialectTypeTest, DefaultNameTargetExtType) {
LLVMContext ctx;

auto *type = DefaultNameTargetExtType::get(ctx, 42);
ASSERT_NE(type, nullptr);

// Verify it's recognized as the correct type
EXPECT_TRUE(isa<DefaultNameTargetExtType>(type));
EXPECT_TRUE(DefaultNameTargetExtType::classof(type));

// Verify the type name uses the default dialect.mnemonic format
auto *targetExtType = cast<TargetExtType>(type);
EXPECT_EQ(targetExtType->getName(), "test.default.target");

// Verify parameter is accessible
EXPECT_EQ(type->getValue(), 42u);
}

// Test that different custom types are distinguishable
TEST(DialectTypeTest, TypeDistinction) {
LLVMContext ctx;

auto *customTarget = CustomNameTargetExtType::get(ctx, 1, 2);
auto *defaultTarget = DefaultNameTargetExtType::get(ctx, 3);
auto *customStruct = CustomNameStructType::get(ctx, 4, 5);

// Types should be distinct
EXPECT_NE(static_cast<llvm::Type *>(customTarget),
static_cast<llvm::Type *>(defaultTarget));
EXPECT_NE(static_cast<llvm::Type *>(customTarget),
static_cast<llvm::Type *>(customStruct));
EXPECT_NE(static_cast<llvm::Type *>(defaultTarget),
static_cast<llvm::Type *>(customStruct));

// classof should work correctly
EXPECT_TRUE(CustomNameTargetExtType::classof(customTarget));
EXPECT_FALSE(CustomNameTargetExtType::classof(defaultTarget));
EXPECT_FALSE(CustomNameTargetExtType::classof(customStruct));

EXPECT_FALSE(DefaultNameTargetExtType::classof(customTarget));
EXPECT_TRUE(DefaultNameTargetExtType::classof(defaultTarget));
EXPECT_FALSE(DefaultNameTargetExtType::classof(customStruct));

EXPECT_FALSE(CustomNameStructType::classof(customTarget));
EXPECT_FALSE(CustomNameStructType::classof(defaultTarget));
EXPECT_TRUE(CustomNameStructType::classof(customStruct));
}