Skip to content

Commit 75dab49

Browse files
[SYCL][clang] Fix more free-function kernel integration header cases (#20877)
This commit fixes a number of known issues when the integration header generates prototypes of the free-function kernels. These changes focus on the additional removal of aliasing and proper handling of templated template arguments. This commit also adds disabled test cases for a known issue with unresolved nested templated type aliases. These are cases for future work. --------- Signed-off-by: Larsen, Steffen <steffen.larsen@intel.com> Co-authored-by: Sachkov, Alexey <alexey.sachkov@intel.com>
1 parent 31f3d6a commit 75dab49

File tree

3 files changed

+389
-92
lines changed

3 files changed

+389
-92
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 279 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -6613,6 +6613,254 @@ static void PrintNSClosingBraces(raw_ostream &OS, const DeclContext *DC) {
66136613
[](raw_ostream &, const NamespaceDecl *) {}, OS, DC);
66146614
}
66156615

6616+
/// Dedicated visitor which helps with printing of kernel arguments in forward
6617+
/// declarations of free function kernels which are declared as function
6618+
/// templates.
6619+
///
6620+
/// Based on:
6621+
/// \code
6622+
/// template <typename T1, typename T2>
6623+
/// void foo(T1 a, int b, T2 c);
6624+
/// \endcode
6625+
///
6626+
/// It prints into the output stream "T1, int, T2".
6627+
///
6628+
/// The main complexity (which motivates addition of such visitor) comes from
6629+
/// the fact that there could be type aliases and default template arguments.
6630+
/// For example:
6631+
/// \code
6632+
/// template<typename T>
6633+
/// void kernel(sycl::accessor<T, 1>);
6634+
/// template void kernel(sycl::accessor<int, 1>);
6635+
/// \endcode
6636+
/// sycl::accessor has many template arguments which have default values. If
6637+
/// we iterate over non-canonicalized argument type, we don't get those default
6638+
/// values and we don't get necessary namespace qualifiers for all the template
6639+
/// arguments. If we iterate over canonicalized argument type, then all
6640+
/// references to T will be replaced with something like type-argument-X-Y.
6641+
/// What this visitor does is it iterates over both in sync, picking the right
6642+
/// values from one or another.
6643+
///
6644+
/// The template argument visitor functions take an additional
6645+
/// ArrayRef<TemplateArgument> argument corresponding to the template arguments
6646+
/// of the outermost template. This is used by some of these functions for
6647+
/// mapping dependent template arguments.
6648+
///
6649+
/// Moral of the story: drop integration header ASAP (but that is blocked
6650+
/// by support for 3rd-party host compilers, which is important).
6651+
class FreeFunctionTemplateKernelArgsPrinter
6652+
: public ConstTemplateArgumentVisitor<FreeFunctionTemplateKernelArgsPrinter,
6653+
void, ArrayRef<TemplateArgument>> {
6654+
raw_ostream &O;
6655+
PrintingPolicy &Policy;
6656+
ASTContext &Context;
6657+
6658+
using Base =
6659+
ConstTemplateArgumentVisitor<FreeFunctionTemplateKernelArgsPrinter, void,
6660+
ArrayRef<TemplateArgument>>;
6661+
6662+
// Desugars a template argument. This helps avoid aliases.
6663+
static TemplateArgument DesugarTemplateArgument(const TemplateArgument &Arg) {
6664+
switch (Arg.getKind()) {
6665+
case TemplateArgument::ArgKind::Type: {
6666+
QualType ArgTy = Arg.getAsType();
6667+
return {QualType(ArgTy->getUnqualifiedDesugaredType(),
6668+
ArgTy.getCVRQualifiers())};
6669+
}
6670+
case TemplateArgument::ArgKind::Template: {
6671+
TemplateName TN = Arg.getAsTemplate();
6672+
while (std::optional<TemplateName> DesugaredTN =
6673+
TN.desugar(/*IgnoreDeduced=*/false))
6674+
TN = *DesugaredTN;
6675+
return {TN};
6676+
}
6677+
default:
6678+
return Arg;
6679+
}
6680+
}
6681+
6682+
void PrintDesugared(const TemplateArgument &Arg) {
6683+
DesugarTemplateArgument(Arg).print(Policy, O, /*IncludeType=*/false);
6684+
}
6685+
6686+
public:
6687+
FreeFunctionTemplateKernelArgsPrinter(raw_ostream &O, PrintingPolicy &Policy,
6688+
ASTContext &Context)
6689+
: O(O), Policy(Policy), Context(Context) {}
6690+
6691+
void Visit(const ParmVarDecl *Param) {
6692+
// There are cases when we can't directly use neither the original
6693+
// argument type, nor its canonical version. An example would be:
6694+
// template<typename T>
6695+
// void kernel(sycl::accessor<T, 1>);
6696+
// template void kernel(sycl::accessor<int, 1>);
6697+
// Accessor has multiple non-type template arguments with default values
6698+
// and non-qualified type will not include necessary namespaces for all
6699+
// of them. Qualified type will have that information, but all references
6700+
// to T will be replaced to something like type-argument-0
6701+
// What we do instead is we iterate template arguments of both versions
6702+
// of a type in sync and take elements from one or another to get the best
6703+
// of both: proper references to template arguments of a kernel itself and
6704+
// fully-qualified names for enumerations.
6705+
//
6706+
// Moral of the story: drop integration header ASAP (but that is blocked
6707+
// by support for 3rd-party host compilers, which is important).
6708+
QualType T = Param->getType();
6709+
QualType CT = T.getCanonicalType();
6710+
6711+
const auto *TST = dyn_cast<TemplateSpecializationType>(T.getTypePtr());
6712+
const auto *CTST = dyn_cast<TemplateSpecializationType>(CT.getTypePtr());
6713+
if (!TST || !CTST) {
6714+
O << T.getDesugaredType(Context).getAsString(Policy);
6715+
return;
6716+
}
6717+
6718+
const TemplateSpecializationType *TSTAsNonAlias =
6719+
TST->getAsNonAliasTemplateSpecializationType();
6720+
if (TSTAsNonAlias)
6721+
TST = TSTAsNonAlias;
6722+
6723+
ArrayRef<TemplateArgument> SpecArgs = TST->template_arguments();
6724+
ArrayRef<TemplateArgument> DeclArgs = CTST->template_arguments();
6725+
6726+
const TemplateDecl *TD = CTST->getTemplateName().getAsTemplateDecl();
6727+
if (!TD->getIdentifier())
6728+
TD = TST->getTemplateName().getAsTemplateDecl();
6729+
assert(TD->getIdentifier() &&
6730+
"Either the type or the canonical type should have an identifier.");
6731+
TD->printQualifiedName(O);
6732+
6733+
O << "<";
6734+
for (size_t I = 0, E = std::max(DeclArgs.size(), SpecArgs.size()),
6735+
SE = SpecArgs.size();
6736+
I < E; ++I) {
6737+
if (I != 0)
6738+
O << ", ";
6739+
// If we have a specialized argument, use it. Otherwise fallback to a
6740+
// default argument.
6741+
// We pass specialized arguments in case there are references to them
6742+
// from other types.
6743+
// FIXME: passing SpecArgs here is incorrect. It refers to template
6744+
// arguments of a single function argument, but DeclArgs contain
6745+
// references (in form of depth-index) to template arguments of the
6746+
// function itself which results in incorrect integration header being
6747+
// produced.
6748+
Base::Visit(I < SE ? SpecArgs[I] : DeclArgs[I], SpecArgs);
6749+
}
6750+
O << ">";
6751+
}
6752+
6753+
// Internal version of the function above that is used when template argument
6754+
// is a template by itself
6755+
void Visit(const TemplateSpecializationType *T,
6756+
ArrayRef<TemplateArgument> SpecArgs) {
6757+
const TemplateDecl *TD = T->getTemplateName().getAsTemplateDecl();
6758+
const auto *TTPD = dyn_cast<TemplateTemplateParmDecl>(TD);
6759+
if (TTPD && !TTPD->getIdentifier())
6760+
PrintDesugared(SpecArgs[TTPD->getIndex()]);
6761+
else
6762+
TD->printQualifiedName(O);
6763+
O << "<";
6764+
ArrayRef<const TemplateArgument> DeclArgs = T->template_arguments();
6765+
for (size_t I = 0, E = DeclArgs.size(); I < E; ++I) {
6766+
if (I != 0)
6767+
O << ", ";
6768+
Base::Visit(DeclArgs[I], SpecArgs);
6769+
}
6770+
O << ">";
6771+
}
6772+
6773+
void VisitNullTemplateArgument(const TemplateArgument &,
6774+
ArrayRef<TemplateArgument>) {
6775+
llvm_unreachable("If template argument has not been deduced, then we can't "
6776+
"forward-declare it, something went wrong");
6777+
}
6778+
6779+
void VisitTypeTemplateArgument(const TemplateArgument &Arg,
6780+
ArrayRef<TemplateArgument> SpecArgs) {
6781+
TemplateArgument DesugaredArg = DesugarTemplateArgument(Arg);
6782+
// If we reference an existing template argument without a known identifier,
6783+
// print it instead.
6784+
const auto *TPT = dyn_cast<TemplateTypeParmType>(DesugaredArg.getAsType());
6785+
if (TPT && !TPT->getIdentifier()) {
6786+
PrintDesugared(SpecArgs[TPT->getIndex()]);
6787+
return;
6788+
}
6789+
6790+
const auto *TST =
6791+
dyn_cast<TemplateSpecializationType>(DesugaredArg.getAsType());
6792+
if (TST && Arg.isInstantiationDependent()) {
6793+
// This is an instantiation dependent template specialization, meaning
6794+
// that some of its arguments reference template arguments of the free
6795+
// function kernel itself.
6796+
Visit(TST, SpecArgs);
6797+
return;
6798+
}
6799+
6800+
DesugaredArg.print(Policy, O, /* IncludeType = */ false);
6801+
}
6802+
6803+
void VisitDeclarationTemplateArgument(const TemplateArgument &,
6804+
ArrayRef<TemplateArgument>) {
6805+
llvm_unreachable("Free function kernels cannot have non-type template "
6806+
"arguments which are pointers or references");
6807+
}
6808+
6809+
void VisitNullPtrTemplateArgument(const TemplateArgument &,
6810+
ArrayRef<TemplateArgument>) {
6811+
llvm_unreachable("Free function kernels cannot have non-type template "
6812+
"arguments which are pointers or references");
6813+
}
6814+
6815+
void VisitIntegralTemplateArgument(const TemplateArgument &Arg,
6816+
ArrayRef<TemplateArgument>) {
6817+
PrintDesugared(Arg);
6818+
}
6819+
6820+
void VisitStructuralValueTemplateArgument(const TemplateArgument &Arg,
6821+
ArrayRef<TemplateArgument>) {
6822+
PrintDesugared(Arg);
6823+
}
6824+
6825+
void VisitTemplateTemplateArgument(const TemplateArgument &Arg,
6826+
ArrayRef<TemplateArgument>) {
6827+
PrintDesugared(Arg);
6828+
}
6829+
6830+
void VisitTemplateExpansionTemplateArgument(const TemplateArgument &Arg,
6831+
ArrayRef<TemplateArgument>) {
6832+
PrintDesugared(Arg);
6833+
}
6834+
6835+
void VisitExpressionTemplateArgument(const TemplateArgument &Arg,
6836+
ArrayRef<TemplateArgument>) {
6837+
Expr *E = Arg.getAsExpr();
6838+
assert(E && "Failed to get an Expr for an Expression template arg?");
6839+
6840+
if (Arg.isInstantiationDependent() ||
6841+
E->getType()->isScopedEnumeralType()) {
6842+
// Scoped enumerations can't be implicitly cast from integers, so
6843+
// we don't need to evaluate them.
6844+
// If expression is instantiation-dependent, then we can't evaluate it
6845+
// either, let's fallback to default printing mechanism.
6846+
PrintDesugared(Arg);
6847+
return;
6848+
}
6849+
6850+
Expr::EvalResult Res;
6851+
[[maybe_unused]] bool Success =
6852+
Arg.getAsExpr()->EvaluateAsConstantExpr(Res, Context);
6853+
assert(Success && "invalid non-type template argument?");
6854+
assert(!Res.Val.isAbsent() && "couldn't read the evaulation result?");
6855+
Res.Val.printPretty(O, Policy, Arg.getAsExpr()->getType(), &Context);
6856+
}
6857+
6858+
void VisitPackTemplateArgument(const TemplateArgument &Arg,
6859+
ArrayRef<TemplateArgument>) {
6860+
PrintDesugared(Arg);
6861+
}
6862+
};
6863+
66166864
class FreeFunctionPrinter {
66176865
raw_ostream &O;
66186866
PrintingPolicy &Policy;
@@ -6776,86 +7024,16 @@ class FreeFunctionPrinter {
67767024
llvm::raw_svector_ostream ParmListOstream{ParamList};
67777025
Policy.SuppressTagKeyword = true;
67787026

6779-
for (ParmVarDecl *Param : Parameters) {
7027+
FreeFunctionTemplateKernelArgsPrinter Printer(ParmListOstream, Policy,
7028+
Context);
7029+
7030+
for (const ParmVarDecl *Param : Parameters) {
67807031
if (FirstParam)
67817032
FirstParam = false;
67827033
else
67837034
ParmListOstream << ", ";
67847035

6785-
// There are cases when we can't directly use neither the original
6786-
// argument type, nor its canonical version. An example would be:
6787-
// template<typename T>
6788-
// void kernel(sycl::accessor<T, 1>);
6789-
// template void kernel(sycl::accessor<int, 1>);
6790-
// Accessor has multiple non-type template arguments with default values
6791-
// and non-qualified type will not include necessary namespaces for all
6792-
// of them. Qualified type will have that information, but all references
6793-
// to T will be replaced to something like type-argument-0
6794-
// What we do instead is we iterate template arguments of both versions
6795-
// of a type in sync and take elements from one or another to get the best
6796-
// of both: proper references to template arguments of a kernel itself and
6797-
// fully-qualified names for enumerations.
6798-
//
6799-
// Moral of the story: drop integration header ASAP (but that is blocked
6800-
// by support for 3rd-party host compilers, which is important).
6801-
QualType T = Param->getType();
6802-
QualType CT = T.getCanonicalType();
6803-
6804-
const auto *TST = dyn_cast<TemplateSpecializationType>(T.getTypePtr());
6805-
const auto *CTST = dyn_cast<TemplateSpecializationType>(CT.getTypePtr());
6806-
if (!TST || !CTST) {
6807-
ParmListOstream << T.getAsString(Policy);
6808-
continue;
6809-
}
6810-
6811-
const TemplateSpecializationType *TSTAsNonAlias =
6812-
TST->getAsNonAliasTemplateSpecializationType();
6813-
if (TSTAsNonAlias)
6814-
TST = TSTAsNonAlias;
6815-
6816-
TemplateName CTN = CTST->getTemplateName();
6817-
CTN.getAsTemplateDecl()->printQualifiedName(ParmListOstream);
6818-
ParmListOstream << "<";
6819-
6820-
ArrayRef<TemplateArgument> SpecArgs = TST->template_arguments();
6821-
ArrayRef<TemplateArgument> DeclArgs = CTST->template_arguments();
6822-
6823-
auto TemplateArgPrinter = [&](const TemplateArgument &Arg) {
6824-
if (Arg.getKind() != TemplateArgument::ArgKind::Expression ||
6825-
Arg.isInstantiationDependent()) {
6826-
Arg.print(Policy, ParmListOstream, /* IncludeType = */ false);
6827-
return;
6828-
}
6829-
6830-
Expr *E = Arg.getAsExpr();
6831-
assert(E && "Failed to get an Expr for an Expression template arg?");
6832-
if (E->getType().getTypePtr()->isScopedEnumeralType()) {
6833-
// Scoped enumerations can't be implicitly cast from integers, so
6834-
// we don't need to evaluate them.
6835-
Arg.print(Policy, ParmListOstream, /* IncludeType = */ false);
6836-
return;
6837-
}
6838-
6839-
Expr::EvalResult Res;
6840-
[[maybe_unused]] bool Success =
6841-
Arg.getAsExpr()->EvaluateAsConstantExpr(Res, Context);
6842-
assert(Success && "invalid non-type template argument?");
6843-
assert(!Res.Val.isAbsent() && "couldn't read the evaulation result?");
6844-
Res.Val.printPretty(ParmListOstream, Policy, Arg.getAsExpr()->getType(),
6845-
&Context);
6846-
};
6847-
6848-
for (size_t I = 0, E = std::max(DeclArgs.size(), SpecArgs.size()),
6849-
SE = SpecArgs.size();
6850-
I < E; ++I) {
6851-
if (I != 0)
6852-
ParmListOstream << ", ";
6853-
// If we have a specialized argument, use it. Otherwise fallback to a
6854-
// default argument.
6855-
TemplateArgPrinter(I < SE ? SpecArgs[I] : DeclArgs[I]);
6856-
}
6857-
6858-
ParmListOstream << ">";
7036+
Printer.Visit(Param);
68597037
}
68607038
return ParamList.str().str();
68617039
}
@@ -6873,26 +7051,39 @@ class FreeFunctionPrinter {
68737051
std::string getTemplateParameters(const clang::TemplateParameterList *TPL) {
68747052
std::string TemplateParams{"template <"};
68757053
bool FirstParam{true};
6876-
for (NamedDecl *Param : *TPL) {
7054+
for (const NamedDecl *Param : *TPL) {
68777055
if (!FirstParam)
68787056
TemplateParams += ", ";
68797057
FirstParam = false;
6880-
if (const auto *TemplateParam = dyn_cast<TemplateTypeParmDecl>(Param)) {
6881-
TemplateParams +=
6882-
TemplateParam->wasDeclaredWithTypename() ? "typename " : "class ";
6883-
if (TemplateParam->isParameterPack())
6884-
TemplateParams += "... ";
6885-
TemplateParams += TemplateParam->getNameAsString();
6886-
} else if (const auto *NonTypeParam =
6887-
dyn_cast<NonTypeTemplateParmDecl>(Param)) {
6888-
TemplateParams += NonTypeParam->getType().getAsString();
6889-
TemplateParams += " ";
6890-
TemplateParams += NonTypeParam->getNameAsString();
6891-
}
7058+
TemplateParams += getTemplateParameter(Param);
68927059
}
68937060
TemplateParams += "> ";
68947061
return TemplateParams;
68957062
}
7063+
7064+
/// Helper method to get text representation of a template parameter.
7065+
/// \param Param The template parameter.
7066+
std::string getTemplateParameter(const NamedDecl *Param) {
7067+
auto GetTypenameOrClass = [](const auto *Param) {
7068+
return Param->wasDeclaredWithTypename() ? "typename " : "class ";
7069+
};
7070+
if (const auto *TemplateParam = dyn_cast<TemplateTypeParmDecl>(Param)) {
7071+
std::string TemplateParamStr = GetTypenameOrClass(TemplateParam);
7072+
if (TemplateParam->isParameterPack())
7073+
TemplateParamStr += "... ";
7074+
TemplateParamStr += TemplateParam->getNameAsString();
7075+
return TemplateParamStr;
7076+
} else if (const auto *NonTypeParam =
7077+
dyn_cast<NonTypeTemplateParmDecl>(Param)) {
7078+
return NonTypeParam->getType().getAsString() + " " +
7079+
NonTypeParam->getNameAsString();
7080+
} else if (const auto *TTParam =
7081+
dyn_cast<TemplateTemplateParmDecl>(Param)) {
7082+
return getTemplateParameters(TTParam->getTemplateParameters()) + " " +
7083+
GetTypenameOrClass(TTParam) + TTParam->getNameAsString();
7084+
}
7085+
return "";
7086+
}
68967087
};
68977088

68987089
void SYCLIntegrationHeader::emit(raw_ostream &O) {

0 commit comments

Comments
 (0)