From 62d86701c14c5dbcab8b2c0521f9c0f63aeb50e4 Mon Sep 17 00:00:00 2001 From: Sima Nerush <2002ssn@gmail.com> Date: Fri, 26 Jan 2024 18:48:35 -0800 Subject: [PATCH 1/2] [AST] Use interface type to obtain the depth of parameter pack's element type. --- lib/AST/GenericEnvironment.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/AST/GenericEnvironment.cpp b/lib/AST/GenericEnvironment.cpp index cbe1dcfda14f6..38dc35e817342 100644 --- a/lib/AST/GenericEnvironment.cpp +++ b/lib/AST/GenericEnvironment.cpp @@ -734,10 +734,11 @@ GenericEnvironment::mapElementTypeIntoPackContext(Type type) const { type = type->mapTypeOutOfContext(); + auto interfaceType = element->getInterfaceType(); + llvm::SmallDenseMap packParamForElement; - auto elementDepth = - sig.getInnermostGenericParams().front()->getDepth() + 1; + auto elementDepth = interfaceType->getRootGenericParam()->getDepth(); for (auto *genericParam : sig.getGenericParams()) { if (!genericParam->isParameterPack()) From 0b167b55b1e8ccb48c46588a9a974e63bf9be267 Mon Sep 17 00:00:00 2001 From: Sima Nerush <2002ssn@gmail.com> Date: Fri, 26 Jan 2024 18:53:58 -0800 Subject: [PATCH 2/2] [ConstraintSystem] Cache pack element generic environments associated with`for`-`in` loops over parameter packs to apply in `getPackElementEnvironment`. --- include/swift/Sema/ConstraintSystem.h | 9 +++++++++ include/swift/Sema/SyntacticElementTarget.h | 17 +++++++++++++---- lib/AST/GenericEnvironment.cpp | 2 ++ lib/Sema/CSGen.cpp | 6 ++++++ lib/Sema/CSSolver.cpp | 14 ++++++++++++++ lib/Sema/ConstraintSystem.cpp | 7 +++++-- lib/Sema/SyntacticElementTarget.cpp | 5 +++-- lib/Sema/TypeCheckConstraints.cpp | 6 ++++-- lib/Sema/TypeCheckStmt.cpp | 17 +++++++++++++++-- lib/Sema/TypeChecker.h | 3 ++- test/stmt/foreach.swift | 10 ++++++++++ 11 files changed, 83 insertions(+), 13 deletions(-) diff --git a/include/swift/Sema/ConstraintSystem.h b/include/swift/Sema/ConstraintSystem.h index 834bb0fa574cf..81d1b0d2928ae 100644 --- a/include/swift/Sema/ConstraintSystem.h +++ b/include/swift/Sema/ConstraintSystem.h @@ -1543,6 +1543,10 @@ class Solution { llvm::MapVector PackEnvironments; + /// The outer pack element generic environment to use when dealing with nested + /// pack iteration (see \c getPackElementEnvironment). + llvm::SmallVector PackElementGenericEnvironments; + /// The locators of \c Defaultable constraints whose defaults were used. llvm::DenseSet DefaultedConstraints; @@ -2344,6 +2348,8 @@ class ConstraintSystem { llvm::SmallMapVector PackEnvironments; + llvm::SmallVector PackElementGenericEnvironments; + /// The set of functions that have been transformed by a result builder. llvm::MapVector resultBuilderTransformed; @@ -2833,6 +2839,9 @@ class ConstraintSystem { /// The length of \c PackEnvironments. unsigned numPackEnvironments; + /// The length of \c PackElementGenericEnvironments. + unsigned numPackElementGenericEnvironments; + /// The length of \c DefaultedConstraints. unsigned numDefaultedConstraints; diff --git a/include/swift/Sema/SyntacticElementTarget.h b/include/swift/Sema/SyntacticElementTarget.h index 4a87c152396cc..fee597f9c2d8e 100644 --- a/include/swift/Sema/SyntacticElementTarget.h +++ b/include/swift/Sema/SyntacticElementTarget.h @@ -160,6 +160,7 @@ class SyntacticElementTarget { DeclContext *dc; Pattern *pattern; bool ignoreWhereClause; + GenericEnvironment *packElementEnv; ForEachStmtInfo info; } forEachStmt; @@ -239,11 +240,13 @@ class SyntacticElementTarget { } SyntacticElementTarget(ForEachStmt *stmt, DeclContext *dc, - bool ignoreWhereClause) + bool ignoreWhereClause, + GenericEnvironment *packElementEnv) : kind(Kind::forEachStmt) { forEachStmt.stmt = stmt; forEachStmt.dc = dc; forEachStmt.ignoreWhereClause = ignoreWhereClause; + forEachStmt.packElementEnv = packElementEnv; } /// Form a target for the initialization of a pattern from an expression. @@ -259,9 +262,10 @@ class SyntacticElementTarget { unsigned patternBindingIndex, bool bindPatternVarsOneWay); /// Form a target for a for-in loop. - static SyntacticElementTarget forForEachStmt(ForEachStmt *stmt, - DeclContext *dc, - bool ignoreWhereClause = false); + static SyntacticElementTarget + forForEachStmt(ForEachStmt *stmt, DeclContext *dc, + bool ignoreWhereClause = false, + GenericEnvironment *packElementEnv = nullptr); /// Form a target for a property with an attached property wrapper that is /// initialized out-of-line. @@ -536,6 +540,11 @@ class SyntacticElementTarget { return forEachStmt.ignoreWhereClause; } + GenericEnvironment *getPackElementEnv() const { + assert(isForEachStmt()); + return forEachStmt.packElementEnv; + } + const ForEachStmtInfo &getForEachStmtInfo() const { assert(isForEachStmt()); return forEachStmt.info; diff --git a/lib/AST/GenericEnvironment.cpp b/lib/AST/GenericEnvironment.cpp index 38dc35e817342..2d34111519b8f 100644 --- a/lib/AST/GenericEnvironment.cpp +++ b/lib/AST/GenericEnvironment.cpp @@ -793,6 +793,8 @@ Type BuildForwardingSubstitutions::operator()(SubstitutableType *type) const { auto param = type->castTo(); if (!param->isParameterPack()) return resultType; + if (resultType->is()) + return resultType; return PackType::getSingletonPackExpansion(resultType); } return Type(); diff --git a/lib/Sema/CSGen.cpp b/lib/Sema/CSGen.cpp index 158a03529ad05..22c45206c2b61 100644 --- a/lib/Sema/CSGen.cpp +++ b/lib/Sema/CSGen.cpp @@ -4912,6 +4912,12 @@ bool ConstraintSystem::generateConstraints( } case SyntacticElementTarget::Kind::forEachStmt: { + + // Cache the outer generic environment, if it exists. + if (target.getPackElementEnv()) { + PackElementGenericEnvironments.push_back(target.getPackElementEnv()); + } + // For a for-each statement, generate constraints for the pattern, where // clause, and sequence traversal. auto resultTarget = generateForEachStmtConstraints(*this, target); diff --git a/lib/Sema/CSSolver.cpp b/lib/Sema/CSSolver.cpp index 9cad13b4c068d..ccc16cbd4cce4 100644 --- a/lib/Sema/CSSolver.cpp +++ b/lib/Sema/CSSolver.cpp @@ -247,6 +247,9 @@ Solution ConstraintSystem::finalize() { for (const auto &packEnv : PackEnvironments) solution.PackEnvironments.insert(packEnv); + for (const auto &packEltGenericEnv : PackElementGenericEnvironments) + solution.PackElementGenericEnvironments.push_back(packEltGenericEnv); + return solution; } @@ -316,6 +319,12 @@ void ConstraintSystem::applySolution(const Solution &solution) { PackEnvironments.insert(packEnvironment); } + // Register the solutions's pack element generic environments. + for (auto &packElementGenericEnvironment : + solution.PackElementGenericEnvironments) { + PackElementGenericEnvironments.push_back(packElementGenericEnvironment); + } + // Register the defaulted type variables. DefaultedConstraints.insert(solution.DefaultedConstraints.begin(), solution.DefaultedConstraints.end()); @@ -647,6 +656,7 @@ ConstraintSystem::SolverScope::SolverScope(ConstraintSystem &cs) numOpenedPackExpansionTypes = cs.OpenedPackExpansionTypes.size(); numPackExpansionEnvironments = cs.PackExpansionEnvironments.size(); numPackEnvironments = cs.PackEnvironments.size(); + numPackElementGenericEnvironments = cs.PackElementGenericEnvironments.size(); numDefaultedConstraints = cs.DefaultedConstraints.size(); numAddedNodeTypes = cs.addedNodeTypes.size(); numAddedKeyPathComponentTypes = cs.addedKeyPathComponentTypes.size(); @@ -736,6 +746,10 @@ ConstraintSystem::SolverScope::~SolverScope() { // Remove any pack environments. truncate(cs.PackEnvironments, numPackEnvironments); + // Remove any pack element generic environments. + truncate(cs.PackElementGenericEnvironments, + numPackElementGenericEnvironments); + // Remove any defaulted type variables. truncate(cs.DefaultedConstraints, numDefaultedConstraints); diff --git a/lib/Sema/ConstraintSystem.cpp b/lib/Sema/ConstraintSystem.cpp index a2494d76cb58d..dabefbbb33b28 100644 --- a/lib/Sema/ConstraintSystem.cpp +++ b/lib/Sema/ConstraintSystem.cpp @@ -792,9 +792,11 @@ ConstraintSystem::getPackElementEnvironment(ConstraintLocator *locator, shapeClass->mapTypeOutOfContext()->getCanonicalType()); auto &ctx = getASTContext(); + auto *contextEnv = PackElementGenericEnvironments.empty() + ? DC->getGenericEnvironmentOfContext() + : PackElementGenericEnvironments.back(); auto elementSig = ctx.getOpenedElementSignature( - DC->getGenericSignatureOfContext().getCanonicalSignature(), shapeParam); - auto *contextEnv = DC->getGenericEnvironmentOfContext(); + contextEnv->getGenericSignature().getCanonicalSignature(), shapeParam); auto contextSubs = contextEnv->getForwardingSubstitutionMap(); return GenericEnvironment::forOpenedElement(elementSig, uuidAndShape.first, shapeParam, contextSubs); @@ -4403,6 +4405,7 @@ size_t Solution::getTotalMemory() const { OpenedPackExpansionTypes.getMemorySize() + PackExpansionEnvironments.getMemorySize() + size_in_bytes(PackEnvironments) + + PackElementGenericEnvironments.size() + (DefaultedConstraints.size() * sizeof(void *)) + ImplicitCallAsFunctionRoots.getMemorySize() + nodeTypes.getMemorySize() + diff --git a/lib/Sema/SyntacticElementTarget.cpp b/lib/Sema/SyntacticElementTarget.cpp index c6ef0bff463f5..27d753d6f888e 100644 --- a/lib/Sema/SyntacticElementTarget.cpp +++ b/lib/Sema/SyntacticElementTarget.cpp @@ -178,8 +178,9 @@ SyntacticElementTarget SyntacticElementTarget::forInitialization( SyntacticElementTarget SyntacticElementTarget::forForEachStmt(ForEachStmt *stmt, DeclContext *dc, - bool ignoreWhereClause) { - SyntacticElementTarget target(stmt, dc, ignoreWhereClause); + bool ignoreWhereClause, + GenericEnvironment *packElementEnv) { + SyntacticElementTarget target(stmt, dc, ignoreWhereClause, packElementEnv); return target; } diff --git a/lib/Sema/TypeCheckConstraints.cpp b/lib/Sema/TypeCheckConstraints.cpp index 1531c5846ebc6..59f9ab75624e5 100644 --- a/lib/Sema/TypeCheckConstraints.cpp +++ b/lib/Sema/TypeCheckConstraints.cpp @@ -896,7 +896,8 @@ bool TypeChecker::typeCheckPatternBinding(PatternBindingDecl *PBD, return hadError; } -bool TypeChecker::typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt) { +bool TypeChecker::typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt, + GenericEnvironment *packElementEnv) { auto &Context = dc->getASTContext(); FrontendStatsTracer statsTracer(Context.Stats, "typecheck-for-each", stmt); PrettyStackTraceStmt stackTrace(Context, "type-checking-for-each", stmt); @@ -912,7 +913,8 @@ bool TypeChecker::typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt) { return true; }; - auto target = SyntacticElementTarget::forForEachStmt(stmt, dc); + auto target = SyntacticElementTarget::forForEachStmt( + stmt, dc, /*ignoreWhereClause=*/false, packElementEnv); if (!typeCheckTarget(target)) return failed(); diff --git a/lib/Sema/TypeCheckStmt.cpp b/lib/Sema/TypeCheckStmt.cpp index 6383a68e7a711..9b4138b00fadf 100644 --- a/lib/Sema/TypeCheckStmt.cpp +++ b/lib/Sema/TypeCheckStmt.cpp @@ -1000,6 +1000,8 @@ class StmtChecker : public StmtVisitor { StmtChecker(DeclContext *DC) : Ctx(DC->getASTContext()), DC(DC) { } + llvm::SmallVector genericSigStack; + //===--------------------------------------------------------------------===// // Helper Functions. //===--------------------------------------------------------------------===// @@ -1434,7 +1436,10 @@ class StmtChecker : public StmtVisitor { } Stmt *visitForEachStmt(ForEachStmt *S) { - if (TypeChecker::typeCheckForEachBinding(DC, S)) + GenericEnvironment *genericSignature = + genericSigStack.empty() ? nullptr : genericSigStack.back(); + + if (TypeChecker::typeCheckForEachBinding(DC, S, genericSignature)) return nullptr; // Type-check the body of the loop. @@ -1442,9 +1447,17 @@ class StmtChecker : public StmtVisitor { checkLabeledStmtShadowing(getASTContext(), sourceFile, S); BraceStmt *Body = S->getBody(); + + if (auto packExpansion = + dyn_cast(S->getParsedSequence())) + genericSigStack.push_back(packExpansion->getGenericEnvironment()); + typeCheckStmt(Body); S->setBody(Body); - + + if (isa(S->getParsedSequence())) + genericSigStack.pop_back(); + return S; } diff --git a/lib/Sema/TypeChecker.h b/lib/Sema/TypeChecker.h index 49cad37212f89..27d91bf233a57 100644 --- a/lib/Sema/TypeChecker.h +++ b/lib/Sema/TypeChecker.h @@ -759,7 +759,8 @@ bool typeCheckPatternBinding(PatternBindingDecl *PBD, unsigned patternNumber, /// Type-check a for-each loop's pattern binding and sequence together. /// /// \returns true if a failure occurred. -bool typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt); +bool typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt, + GenericEnvironment *packElementEnv); /// Compute the set of captures for the given function or closure. void computeCaptures(AnyFunctionRef AFR); diff --git a/test/stmt/foreach.swift b/test/stmt/foreach.swift index 77caccc3d57f1..043fafb0c0a02 100644 --- a/test/stmt/foreach.swift +++ b/test/stmt/foreach.swift @@ -330,4 +330,14 @@ do { // expected-error@-1 {{'where' clause in pack iteration is not supported}} } } + + func nested(value: repeat each T, value1: repeat each U) { + for e1 in repeat each value { + for _ in [] {} + for e2 in repeat each value1 { + let y = e1 // Ok + } + let x = e1 // Ok + } + } }