@@ -6613,6 +6613,254 @@ static void PrintNSClosingBraces(raw_ostream &OS, const DeclContext *DC) {
66136613 [](raw_ostream &, const NamespaceDecl *) {}, OS, DC);
66146614}
66156615
6616+ // / Dedicated visitor which helps with printing of kernel arguments in forward
6617+ // / declarations of free function kernels which are declared as function
6618+ // / templates.
6619+ // /
6620+ // / Based on:
6621+ // / \code
6622+ // / template <typename T1, typename T2>
6623+ // / void foo(T1 a, int b, T2 c);
6624+ // / \endcode
6625+ // /
6626+ // / It prints into the output stream "T1, int, T2".
6627+ // /
6628+ // / The main complexity (which motivates addition of such visitor) comes from
6629+ // / the fact that there could be type aliases and default template arguments.
6630+ // / For example:
6631+ // / \code
6632+ // / template<typename T>
6633+ // / void kernel(sycl::accessor<T, 1>);
6634+ // / template void kernel(sycl::accessor<int, 1>);
6635+ // / \endcode
6636+ // / sycl::accessor has many template arguments which have default values. If
6637+ // / we iterate over non-canonicalized argument type, we don't get those default
6638+ // / values and we don't get necessary namespace qualifiers for all the template
6639+ // / arguments. If we iterate over canonicalized argument type, then all
6640+ // / references to T will be replaced with something like type-argument-X-Y.
6641+ // / What this visitor does is it iterates over both in sync, picking the right
6642+ // / values from one or another.
6643+ // /
6644+ // / The template argument visitor functions take an additional
6645+ // / ArrayRef<TemplateArgument> argument corresponding to the template arguments
6646+ // / of the outermost template. This is used by some of these functions for
6647+ // / mapping dependent template arguments.
6648+ // /
6649+ // / Moral of the story: drop integration header ASAP (but that is blocked
6650+ // / by support for 3rd-party host compilers, which is important).
6651+ class FreeFunctionTemplateKernelArgsPrinter
6652+ : public ConstTemplateArgumentVisitor<FreeFunctionTemplateKernelArgsPrinter,
6653+ void , ArrayRef<TemplateArgument>> {
6654+ raw_ostream &O;
6655+ PrintingPolicy &Policy;
6656+ ASTContext &Context;
6657+
6658+ using Base =
6659+ ConstTemplateArgumentVisitor<FreeFunctionTemplateKernelArgsPrinter, void ,
6660+ ArrayRef<TemplateArgument>>;
6661+
6662+ // Desugars a template argument. This helps avoid aliases.
6663+ static TemplateArgument DesugarTemplateArgument (const TemplateArgument &Arg) {
6664+ switch (Arg.getKind ()) {
6665+ case TemplateArgument::ArgKind::Type: {
6666+ QualType ArgTy = Arg.getAsType ();
6667+ return {QualType (ArgTy->getUnqualifiedDesugaredType (),
6668+ ArgTy.getCVRQualifiers ())};
6669+ }
6670+ case TemplateArgument::ArgKind::Template: {
6671+ TemplateName TN = Arg.getAsTemplate ();
6672+ while (std::optional<TemplateName> DesugaredTN =
6673+ TN.desugar (/* IgnoreDeduced=*/ false ))
6674+ TN = *DesugaredTN;
6675+ return {TN};
6676+ }
6677+ default :
6678+ return Arg;
6679+ }
6680+ }
6681+
6682+ void PrintDesugared (const TemplateArgument &Arg) {
6683+ DesugarTemplateArgument (Arg).print (Policy, O, /* IncludeType=*/ false );
6684+ }
6685+
6686+ public:
6687+ FreeFunctionTemplateKernelArgsPrinter (raw_ostream &O, PrintingPolicy &Policy,
6688+ ASTContext &Context)
6689+ : O(O), Policy(Policy), Context(Context) {}
6690+
6691+ void Visit (const ParmVarDecl *Param) {
6692+ // There are cases when we can't directly use neither the original
6693+ // argument type, nor its canonical version. An example would be:
6694+ // template<typename T>
6695+ // void kernel(sycl::accessor<T, 1>);
6696+ // template void kernel(sycl::accessor<int, 1>);
6697+ // Accessor has multiple non-type template arguments with default values
6698+ // and non-qualified type will not include necessary namespaces for all
6699+ // of them. Qualified type will have that information, but all references
6700+ // to T will be replaced to something like type-argument-0
6701+ // What we do instead is we iterate template arguments of both versions
6702+ // of a type in sync and take elements from one or another to get the best
6703+ // of both: proper references to template arguments of a kernel itself and
6704+ // fully-qualified names for enumerations.
6705+ //
6706+ // Moral of the story: drop integration header ASAP (but that is blocked
6707+ // by support for 3rd-party host compilers, which is important).
6708+ QualType T = Param->getType ();
6709+ QualType CT = T.getCanonicalType ();
6710+
6711+ const auto *TST = dyn_cast<TemplateSpecializationType>(T.getTypePtr ());
6712+ const auto *CTST = dyn_cast<TemplateSpecializationType>(CT.getTypePtr ());
6713+ if (!TST || !CTST) {
6714+ O << T.getDesugaredType (Context).getAsString (Policy);
6715+ return ;
6716+ }
6717+
6718+ const TemplateSpecializationType *TSTAsNonAlias =
6719+ TST->getAsNonAliasTemplateSpecializationType ();
6720+ if (TSTAsNonAlias)
6721+ TST = TSTAsNonAlias;
6722+
6723+ ArrayRef<TemplateArgument> SpecArgs = TST->template_arguments ();
6724+ ArrayRef<TemplateArgument> DeclArgs = CTST->template_arguments ();
6725+
6726+ const TemplateDecl *TD = CTST->getTemplateName ().getAsTemplateDecl ();
6727+ if (!TD->getIdentifier ())
6728+ TD = TST->getTemplateName ().getAsTemplateDecl ();
6729+ assert (TD->getIdentifier () &&
6730+ " Either the type or the canonical type should have an identifier." );
6731+ TD->printQualifiedName (O);
6732+
6733+ O << " <" ;
6734+ for (size_t I = 0 , E = std::max (DeclArgs.size (), SpecArgs.size ()),
6735+ SE = SpecArgs.size ();
6736+ I < E; ++I) {
6737+ if (I != 0 )
6738+ O << " , " ;
6739+ // If we have a specialized argument, use it. Otherwise fallback to a
6740+ // default argument.
6741+ // We pass specialized arguments in case there are references to them
6742+ // from other types.
6743+ // FIXME: passing SpecArgs here is incorrect. It refers to template
6744+ // arguments of a single function argument, but DeclArgs contain
6745+ // references (in form of depth-index) to template arguments of the
6746+ // function itself which results in incorrect integration header being
6747+ // produced.
6748+ Base::Visit (I < SE ? SpecArgs[I] : DeclArgs[I], SpecArgs);
6749+ }
6750+ O << " >" ;
6751+ }
6752+
6753+ // Internal version of the function above that is used when template argument
6754+ // is a template by itself
6755+ void Visit (const TemplateSpecializationType *T,
6756+ ArrayRef<TemplateArgument> SpecArgs) {
6757+ const TemplateDecl *TD = T->getTemplateName ().getAsTemplateDecl ();
6758+ const auto *TTPD = dyn_cast<TemplateTemplateParmDecl>(TD);
6759+ if (TTPD && !TTPD->getIdentifier ())
6760+ PrintDesugared (SpecArgs[TTPD->getIndex ()]);
6761+ else
6762+ TD->printQualifiedName (O);
6763+ O << " <" ;
6764+ ArrayRef<const TemplateArgument> DeclArgs = T->template_arguments ();
6765+ for (size_t I = 0 , E = DeclArgs.size (); I < E; ++I) {
6766+ if (I != 0 )
6767+ O << " , " ;
6768+ Base::Visit (DeclArgs[I], SpecArgs);
6769+ }
6770+ O << " >" ;
6771+ }
6772+
6773+ void VisitNullTemplateArgument (const TemplateArgument &,
6774+ ArrayRef<TemplateArgument>) {
6775+ llvm_unreachable (" If template argument has not been deduced, then we can't "
6776+ " forward-declare it, something went wrong" );
6777+ }
6778+
6779+ void VisitTypeTemplateArgument (const TemplateArgument &Arg,
6780+ ArrayRef<TemplateArgument> SpecArgs) {
6781+ TemplateArgument DesugaredArg = DesugarTemplateArgument (Arg);
6782+ // If we reference an existing template argument without a known identifier,
6783+ // print it instead.
6784+ const auto *TPT = dyn_cast<TemplateTypeParmType>(DesugaredArg.getAsType ());
6785+ if (TPT && !TPT->getIdentifier ()) {
6786+ PrintDesugared (SpecArgs[TPT->getIndex ()]);
6787+ return ;
6788+ }
6789+
6790+ const auto *TST =
6791+ dyn_cast<TemplateSpecializationType>(DesugaredArg.getAsType ());
6792+ if (TST && Arg.isInstantiationDependent ()) {
6793+ // This is an instantiation dependent template specialization, meaning
6794+ // that some of its arguments reference template arguments of the free
6795+ // function kernel itself.
6796+ Visit (TST, SpecArgs);
6797+ return ;
6798+ }
6799+
6800+ DesugaredArg.print (Policy, O, /* IncludeType = */ false );
6801+ }
6802+
6803+ void VisitDeclarationTemplateArgument (const TemplateArgument &,
6804+ ArrayRef<TemplateArgument>) {
6805+ llvm_unreachable (" Free function kernels cannot have non-type template "
6806+ " arguments which are pointers or references" );
6807+ }
6808+
6809+ void VisitNullPtrTemplateArgument (const TemplateArgument &,
6810+ ArrayRef<TemplateArgument>) {
6811+ llvm_unreachable (" Free function kernels cannot have non-type template "
6812+ " arguments which are pointers or references" );
6813+ }
6814+
6815+ void VisitIntegralTemplateArgument (const TemplateArgument &Arg,
6816+ ArrayRef<TemplateArgument>) {
6817+ PrintDesugared (Arg);
6818+ }
6819+
6820+ void VisitStructuralValueTemplateArgument (const TemplateArgument &Arg,
6821+ ArrayRef<TemplateArgument>) {
6822+ PrintDesugared (Arg);
6823+ }
6824+
6825+ void VisitTemplateTemplateArgument (const TemplateArgument &Arg,
6826+ ArrayRef<TemplateArgument>) {
6827+ PrintDesugared (Arg);
6828+ }
6829+
6830+ void VisitTemplateExpansionTemplateArgument (const TemplateArgument &Arg,
6831+ ArrayRef<TemplateArgument>) {
6832+ PrintDesugared (Arg);
6833+ }
6834+
6835+ void VisitExpressionTemplateArgument (const TemplateArgument &Arg,
6836+ ArrayRef<TemplateArgument>) {
6837+ Expr *E = Arg.getAsExpr ();
6838+ assert (E && " Failed to get an Expr for an Expression template arg?" );
6839+
6840+ if (Arg.isInstantiationDependent () ||
6841+ E->getType ()->isScopedEnumeralType ()) {
6842+ // Scoped enumerations can't be implicitly cast from integers, so
6843+ // we don't need to evaluate them.
6844+ // If expression is instantiation-dependent, then we can't evaluate it
6845+ // either, let's fallback to default printing mechanism.
6846+ PrintDesugared (Arg);
6847+ return ;
6848+ }
6849+
6850+ Expr::EvalResult Res;
6851+ [[maybe_unused]] bool Success =
6852+ Arg.getAsExpr ()->EvaluateAsConstantExpr (Res, Context);
6853+ assert (Success && " invalid non-type template argument?" );
6854+ assert (!Res.Val .isAbsent () && " couldn't read the evaulation result?" );
6855+ Res.Val .printPretty (O, Policy, Arg.getAsExpr ()->getType (), &Context);
6856+ }
6857+
6858+ void VisitPackTemplateArgument (const TemplateArgument &Arg,
6859+ ArrayRef<TemplateArgument>) {
6860+ PrintDesugared (Arg);
6861+ }
6862+ };
6863+
66166864class FreeFunctionPrinter {
66176865 raw_ostream &O;
66186866 PrintingPolicy &Policy;
@@ -6776,86 +7024,16 @@ class FreeFunctionPrinter {
67767024 llvm::raw_svector_ostream ParmListOstream{ParamList};
67777025 Policy.SuppressTagKeyword = true ;
67787026
6779- for (ParmVarDecl *Param : Parameters) {
7027+ FreeFunctionTemplateKernelArgsPrinter Printer (ParmListOstream, Policy,
7028+ Context);
7029+
7030+ for (const ParmVarDecl *Param : Parameters) {
67807031 if (FirstParam)
67817032 FirstParam = false ;
67827033 else
67837034 ParmListOstream << " , " ;
67847035
6785- // There are cases when we can't directly use neither the original
6786- // argument type, nor its canonical version. An example would be:
6787- // template<typename T>
6788- // void kernel(sycl::accessor<T, 1>);
6789- // template void kernel(sycl::accessor<int, 1>);
6790- // Accessor has multiple non-type template arguments with default values
6791- // and non-qualified type will not include necessary namespaces for all
6792- // of them. Qualified type will have that information, but all references
6793- // to T will be replaced to something like type-argument-0
6794- // What we do instead is we iterate template arguments of both versions
6795- // of a type in sync and take elements from one or another to get the best
6796- // of both: proper references to template arguments of a kernel itself and
6797- // fully-qualified names for enumerations.
6798- //
6799- // Moral of the story: drop integration header ASAP (but that is blocked
6800- // by support for 3rd-party host compilers, which is important).
6801- QualType T = Param->getType ();
6802- QualType CT = T.getCanonicalType ();
6803-
6804- const auto *TST = dyn_cast<TemplateSpecializationType>(T.getTypePtr ());
6805- const auto *CTST = dyn_cast<TemplateSpecializationType>(CT.getTypePtr ());
6806- if (!TST || !CTST) {
6807- ParmListOstream << T.getAsString (Policy);
6808- continue ;
6809- }
6810-
6811- const TemplateSpecializationType *TSTAsNonAlias =
6812- TST->getAsNonAliasTemplateSpecializationType ();
6813- if (TSTAsNonAlias)
6814- TST = TSTAsNonAlias;
6815-
6816- TemplateName CTN = CTST->getTemplateName ();
6817- CTN.getAsTemplateDecl ()->printQualifiedName (ParmListOstream);
6818- ParmListOstream << " <" ;
6819-
6820- ArrayRef<TemplateArgument> SpecArgs = TST->template_arguments ();
6821- ArrayRef<TemplateArgument> DeclArgs = CTST->template_arguments ();
6822-
6823- auto TemplateArgPrinter = [&](const TemplateArgument &Arg) {
6824- if (Arg.getKind () != TemplateArgument::ArgKind::Expression ||
6825- Arg.isInstantiationDependent ()) {
6826- Arg.print (Policy, ParmListOstream, /* IncludeType = */ false );
6827- return ;
6828- }
6829-
6830- Expr *E = Arg.getAsExpr ();
6831- assert (E && " Failed to get an Expr for an Expression template arg?" );
6832- if (E->getType ().getTypePtr ()->isScopedEnumeralType ()) {
6833- // Scoped enumerations can't be implicitly cast from integers, so
6834- // we don't need to evaluate them.
6835- Arg.print (Policy, ParmListOstream, /* IncludeType = */ false );
6836- return ;
6837- }
6838-
6839- Expr::EvalResult Res;
6840- [[maybe_unused]] bool Success =
6841- Arg.getAsExpr ()->EvaluateAsConstantExpr (Res, Context);
6842- assert (Success && " invalid non-type template argument?" );
6843- assert (!Res.Val .isAbsent () && " couldn't read the evaulation result?" );
6844- Res.Val .printPretty (ParmListOstream, Policy, Arg.getAsExpr ()->getType (),
6845- &Context);
6846- };
6847-
6848- for (size_t I = 0 , E = std::max (DeclArgs.size (), SpecArgs.size ()),
6849- SE = SpecArgs.size ();
6850- I < E; ++I) {
6851- if (I != 0 )
6852- ParmListOstream << " , " ;
6853- // If we have a specialized argument, use it. Otherwise fallback to a
6854- // default argument.
6855- TemplateArgPrinter (I < SE ? SpecArgs[I] : DeclArgs[I]);
6856- }
6857-
6858- ParmListOstream << " >" ;
7036+ Printer.Visit (Param);
68597037 }
68607038 return ParamList.str ().str ();
68617039 }
@@ -6873,26 +7051,39 @@ class FreeFunctionPrinter {
68737051 std::string getTemplateParameters (const clang::TemplateParameterList *TPL) {
68747052 std::string TemplateParams{" template <" };
68757053 bool FirstParam{true };
6876- for (NamedDecl *Param : *TPL) {
7054+ for (const NamedDecl *Param : *TPL) {
68777055 if (!FirstParam)
68787056 TemplateParams += " , " ;
68797057 FirstParam = false ;
6880- if (const auto *TemplateParam = dyn_cast<TemplateTypeParmDecl>(Param)) {
6881- TemplateParams +=
6882- TemplateParam->wasDeclaredWithTypename () ? " typename " : " class " ;
6883- if (TemplateParam->isParameterPack ())
6884- TemplateParams += " ... " ;
6885- TemplateParams += TemplateParam->getNameAsString ();
6886- } else if (const auto *NonTypeParam =
6887- dyn_cast<NonTypeTemplateParmDecl>(Param)) {
6888- TemplateParams += NonTypeParam->getType ().getAsString ();
6889- TemplateParams += " " ;
6890- TemplateParams += NonTypeParam->getNameAsString ();
6891- }
7058+ TemplateParams += getTemplateParameter (Param);
68927059 }
68937060 TemplateParams += " > " ;
68947061 return TemplateParams;
68957062 }
7063+
7064+ // / Helper method to get text representation of a template parameter.
7065+ // / \param Param The template parameter.
7066+ std::string getTemplateParameter (const NamedDecl *Param) {
7067+ auto GetTypenameOrClass = [](const auto *Param) {
7068+ return Param->wasDeclaredWithTypename () ? " typename " : " class " ;
7069+ };
7070+ if (const auto *TemplateParam = dyn_cast<TemplateTypeParmDecl>(Param)) {
7071+ std::string TemplateParamStr = GetTypenameOrClass (TemplateParam);
7072+ if (TemplateParam->isParameterPack ())
7073+ TemplateParamStr += " ... " ;
7074+ TemplateParamStr += TemplateParam->getNameAsString ();
7075+ return TemplateParamStr;
7076+ } else if (const auto *NonTypeParam =
7077+ dyn_cast<NonTypeTemplateParmDecl>(Param)) {
7078+ return NonTypeParam->getType ().getAsString () + " " +
7079+ NonTypeParam->getNameAsString ();
7080+ } else if (const auto *TTParam =
7081+ dyn_cast<TemplateTemplateParmDecl>(Param)) {
7082+ return getTemplateParameters (TTParam->getTemplateParameters ()) + " " +
7083+ GetTypenameOrClass (TTParam) + TTParam->getNameAsString ();
7084+ }
7085+ return " " ;
7086+ }
68967087};
68977088
68987089void SYCLIntegrationHeader::emit (raw_ostream &O) {
0 commit comments