From 304a3455d6cd0ee2d9c7ff99377547917040d071 Mon Sep 17 00:00:00 2001 From: Piotr Sobczak Date: Wed, 7 Jan 2026 14:03:22 +0100 Subject: [PATCH] Add optional type name override Add llvmTypeNameOverride property to DialectType to allow overriding the LLVM IR type name generated for dialect types. This is useful when the natural dialect+mnemonic name does not match backend conventions. --- include/llvm-dialects/Dialect/Dialect.td | 4 + include/llvm-dialects/TableGen/DialectType.h | 4 + lib/TableGen/DialectType.cpp | 29 ++++- test/example/generated/ExampleDialect.cpp.inc | 2 +- test/unit/dialect/TestDialect.td | 53 ++++++++ test/unit/interface/CMakeLists.txt | 3 +- test/unit/interface/DialectTypeTests.cpp | 114 ++++++++++++++++++ 7 files changed, 204 insertions(+), 5 deletions(-) create mode 100644 test/unit/interface/DialectTypeTests.cpp diff --git a/include/llvm-dialects/Dialect/Dialect.td b/include/llvm-dialects/Dialect/Dialect.td index 8d1d623..7d51932 100644 --- a/include/llvm-dialects/Dialect/Dialect.td +++ b/include/llvm-dialects/Dialect/Dialect.td @@ -239,6 +239,10 @@ class DialectType : Type, Predicate { /// discriminant; the discriminant should be a type that cannot naturally /// appear elsewhere, e.g. (repr_struct (IntegerType 41)) dag representation = (repr_targetext); + + /// Optional custom name for the LLVM IR type. If not set, defaults to + /// ".". + string llvmTypeNameOverride = ""; } def and; diff --git a/include/llvm-dialects/TableGen/DialectType.h b/include/llvm-dialects/TableGen/DialectType.h index 8625fab..a7eec36 100644 --- a/include/llvm-dialects/TableGen/DialectType.h +++ b/include/llvm-dialects/TableGen/DialectType.h @@ -46,6 +46,9 @@ class DialectType : public BaseCppPredicate { RecordTy *getDialectRec() const { return m_dialectRec; } llvm::StringRef getName() const { return m_name; } llvm::StringRef getMnemonic() const { return m_mnemonic; } + llvm::StringRef getLlvmTypeNameOverride() const { + return m_llvmTypeNameOverride; + } bool defaultGetterHasExplicitContextArgument() const { return m_defaultGetterHasExplicitContextArgument; } @@ -64,6 +67,7 @@ class DialectType : public BaseCppPredicate { RecordTy *m_dialectRec = nullptr; std::string m_name; std::string m_mnemonic; + std::string m_llvmTypeNameOverride; bool m_defaultGetterHasExplicitContextArgument = false; std::string m_summary; std::string m_description; diff --git a/lib/TableGen/DialectType.cpp b/lib/TableGen/DialectType.cpp index fc0b53c..855d031 100644 --- a/lib/TableGen/DialectType.cpp +++ b/lib/TableGen/DialectType.cpp @@ -43,6 +43,7 @@ bool DialectType::init(raw_ostream &errs, GenDialectsContext &context, m_mnemonic = record->getValueAsString("mnemonic"); m_summary = record->getValueAsString("summary"); m_description = record->getValueAsString("description"); + m_llvmTypeNameOverride = record->getValueAsString("llvmTypeNameOverride"); if (auto *dag = cast(record->getValue("representation")->getValue())) { @@ -170,12 +171,25 @@ void DialectType::emitDeclaration(raw_ostream &out, GenDialect *dialect) const { fmt.addSubst("_type", getName()); fmt.addSubst("mnemonic", getMnemonic()); + std::string typeName; + if (!m_llvmTypeNameOverride.empty()) { + typeName = m_llvmTypeNameOverride; + } else { + typeName = (dialect->name + "." + getMnemonic()).str(); + } + // For struct-backed types, add trailing dot for the prefix + if (m_structBacked) { + typeName += "."; + } + + fmt.addSubst("typeName", typeName); + if (m_structBacked) { out << tgfmt(R"( class $_type : public ::llvm::StructType { using ::llvm::StructType::StructType; public: - static constexpr ::llvm::StringLiteral s_prefix{"$dialect.$mnemonic."}; + static constexpr ::llvm::StringLiteral s_prefix{"$typeName"}; using ::llvm::StructType::getElementType; @@ -225,10 +239,11 @@ void DialectType::emitDeclaration(raw_ostream &out, GenDialect *dialect) const { } else { // TargetExtType + fmt.addSubst("typeName", typeName); out << tgfmt(R"( class $_type : public ::llvm::TargetExtType { - static constexpr ::llvm::StringLiteral s_name{"$dialect.$mnemonic"}; + static constexpr ::llvm::StringLiteral s_name{"$typeName"}; public: static bool classof(const ::llvm::TargetExtType *t) { @@ -276,6 +291,14 @@ void DialectType::emitDefinition(raw_ostream &out, GenDialect *dialect) const { fmt.addSubst("fields", symbols.chooseName("fields")); fmt.addSubst("st", symbols.chooseName("st")); + // Compute the type name prefix (without trailing dot for struct-backed types) + std::string typeNamePrefix; + if (!m_llvmTypeNameOverride.empty()) { + typeNamePrefix = m_llvmTypeNameOverride; + } else { + typeNamePrefix = (dialect->name + "." + getMnemonic()).str(); + } + if (m_structBacked) { out << tgfmt("$_type* $_type::get(", &fmt); bool contextPresent = @@ -305,7 +328,7 @@ void DialectType::emitDefinition(raw_ostream &out, GenDialect *dialect) const { out << tgfmt( " std::string $name; ::llvm::raw_string_ostream $os($name);\n", &fmt); - out << tgfmt(" $os << \"$0\";\n", &fmt, m_mnemonic); + out << tgfmt(" $os << \"$0\";\n", &fmt, typeNamePrefix); for (const auto &getterArg : getterArgs) out << tgfmt(" $os << '.' << (uint64_t)$0;\n", &fmt, getterArg.name); diff --git a/test/example/generated/ExampleDialect.cpp.inc b/test/example/generated/ExampleDialect.cpp.inc index 902dc84..3b2b8fd 100644 --- a/test/example/generated/ExampleDialect.cpp.inc +++ b/test/example/generated/ExampleDialect.cpp.inc @@ -263,7 +263,7 @@ StructBackedType* StructBackedType::get(::llvm::LLVMContext & ctx, uint32_t fiel static_assert(sizeof(field2) <= sizeof(unsigned)); std::string name; ::llvm::raw_string_ostream os(name); - os << "struct.backed"; + os << "xd.ir.struct.backed"; os << '.' << (uint64_t)field0; os << '.' << (uint64_t)field1; os << '.' << (uint64_t)field2; diff --git a/test/unit/dialect/TestDialect.td b/test/unit/dialect/TestDialect.td index 1dfecd9..af3ea6f 100644 --- a/test/unit/dialect/TestDialect.td +++ b/test/unit/dialect/TestDialect.td @@ -82,3 +82,56 @@ def InstNameConflictWithExplRetTestOp : Op { + let typeArguments = (args + AttrI32:$param1, + AttrI32:$param2 + ); + + let defaultGetterHasExplicitContextArgument = true; + let llvmTypeNameOverride = "custom.renamed.target"; + + let summary = "TargetExtType with custom name"; + let description = [{ + Tests that llvmTypeNameOverride works for TargetExtType. + Type name should be "custom.renamed.target" instead of "test.custom.target". + }]; +} + +// Test struct-backed type with custom name prefix +def CustomNameStructType : DialectType { + let typeArguments = (args + AttrI32:$rows, + AttrI32:$cols + ); + + let defaultGetterHasExplicitContextArgument = true; + let representation = (repr_struct (IntegerType 73)); + let llvmTypeNameOverride = "custom.renamed.struct"; + + let summary = "Struct-backed type with custom name prefix"; + let description = [{ + Tests that llvmTypeNameOverride works for struct-backed types. + Type name prefix should be "custom.renamed.struct." instead of "test.custom.struct.". + The trailing dot is automatically added for struct types. + }]; +} + +// Test TargetExtType without custom name (default behavior) +def DefaultNameTargetExtType : DialectType { + let typeArguments = (args + AttrI32:$value + ); + + let defaultGetterHasExplicitContextArgument = true; + + let summary = "TargetExtType with default name"; + let description = [{ + Tests default naming behavior for TargetExtType. + Type name should be "test.default.target". + }]; +} diff --git a/test/unit/interface/CMakeLists.txt b/test/unit/interface/CMakeLists.txt index 282a192..b35ed3e 100644 --- a/test/unit/interface/CMakeLists.txt +++ b/test/unit/interface/CMakeLists.txt @@ -20,6 +20,7 @@ add_dialects_unit_test(DialectsADTTests OpSetTests.cpp OpMapTests.cpp - OpMapIRTests.cpp) + OpMapIRTests.cpp + DialectTypeTests.cpp) add_dependencies(DialectsADTTests TestDialectTableGen) diff --git a/test/unit/interface/DialectTypeTests.cpp b/test/unit/interface/DialectTypeTests.cpp new file mode 100644 index 0000000..c8d86b5 --- /dev/null +++ b/test/unit/interface/DialectTypeTests.cpp @@ -0,0 +1,114 @@ +/* + *********************************************************************************************************************** + * Copyright (c) 2026 Advanced Micro Devices, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *********************************************************************************************************************** + */ + +#include "TestDialect.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Type.h" +#include "gtest/gtest.h" + +using namespace llvm; +using namespace test; + +// Test custom llvmTypeNameOverride for TargetExtType +TEST(DialectTypeTest, CustomNameTargetExtType) { + LLVMContext ctx; + + auto *type = CustomNameTargetExtType::get(ctx, 10, 20); + ASSERT_NE(type, nullptr); + + // Verify it's recognized as the correct type via classof + EXPECT_TRUE(isa(type)); + EXPECT_TRUE(CustomNameTargetExtType::classof(type)); + + // Verify the type name is the custom one, not the default + auto *targetExtType = cast(type); + EXPECT_EQ(targetExtType->getName(), "custom.renamed.target"); + + // Verify parameters are accessible + EXPECT_EQ(type->getParam1(), 10u); + EXPECT_EQ(type->getParam2(), 20u); +} + +// Test custom llvmTypeNameOverride for struct-backed type +TEST(DialectTypeTest, CustomNameStructType) { + LLVMContext ctx; + + auto *type = CustomNameStructType::get(ctx, 16, 16); + ASSERT_NE(type, nullptr); + + // Verify it's recognized as the correct type via classof + EXPECT_TRUE(isa(type)); + EXPECT_TRUE(CustomNameStructType::classof(type)); + + // Verify the struct type name has the custom prefix + auto *structType = cast(type); + EXPECT_TRUE(structType->getName().starts_with("custom.renamed.struct.")); + + // Verify parameters are accessible + EXPECT_EQ(type->getRows(), 16u); + EXPECT_EQ(type->getCols(), 16u); +} + +// Test default naming behavior (no llvmTypeNameOverride specified) +TEST(DialectTypeTest, DefaultNameTargetExtType) { + LLVMContext ctx; + + auto *type = DefaultNameTargetExtType::get(ctx, 42); + ASSERT_NE(type, nullptr); + + // Verify it's recognized as the correct type + EXPECT_TRUE(isa(type)); + EXPECT_TRUE(DefaultNameTargetExtType::classof(type)); + + // Verify the type name uses the default dialect.mnemonic format + auto *targetExtType = cast(type); + EXPECT_EQ(targetExtType->getName(), "test.default.target"); + + // Verify parameter is accessible + EXPECT_EQ(type->getValue(), 42u); +} + +// Test that different custom types are distinguishable +TEST(DialectTypeTest, TypeDistinction) { + LLVMContext ctx; + + auto *customTarget = CustomNameTargetExtType::get(ctx, 1, 2); + auto *defaultTarget = DefaultNameTargetExtType::get(ctx, 3); + auto *customStruct = CustomNameStructType::get(ctx, 4, 5); + + // Types should be distinct + EXPECT_NE(static_cast(customTarget), + static_cast(defaultTarget)); + EXPECT_NE(static_cast(customTarget), + static_cast(customStruct)); + EXPECT_NE(static_cast(defaultTarget), + static_cast(customStruct)); + + // classof should work correctly + EXPECT_TRUE(CustomNameTargetExtType::classof(customTarget)); + EXPECT_FALSE(CustomNameTargetExtType::classof(defaultTarget)); + EXPECT_FALSE(CustomNameTargetExtType::classof(customStruct)); + + EXPECT_FALSE(DefaultNameTargetExtType::classof(customTarget)); + EXPECT_TRUE(DefaultNameTargetExtType::classof(defaultTarget)); + EXPECT_FALSE(DefaultNameTargetExtType::classof(customStruct)); + + EXPECT_FALSE(CustomNameStructType::classof(customTarget)); + EXPECT_FALSE(CustomNameStructType::classof(defaultTarget)); + EXPECT_TRUE(CustomNameStructType::classof(customStruct)); +}