2121#include " TypeCheckType.h"
2222#include " swift/AST/GenericEnvironment.h"
2323#include " swift/AST/ParameterList.h"
24+ #include " swift/AST/TypeVisitor.h"
2425#include " swift/Basic/Statistic.h"
2526#include " llvm/ADT/SetVector.h"
2627#include " llvm/ADT/SmallString.h"
@@ -1464,6 +1465,364 @@ static ArrayRef<OverloadChoice> partitionSIMDOperators(
14641465 return scratch;
14651466}
14661467
1468+ // / Retrieve the type that will be used when matching the given overload.
1469+ static Type getEffectiveOverloadType (const OverloadChoice &overload) {
1470+ switch (overload.getKind ()) {
1471+ case OverloadChoiceKind::Decl:
1472+ // Declaration choices are handled below.
1473+ break ;
1474+
1475+ case OverloadChoiceKind::BaseType:
1476+ case OverloadChoiceKind::DeclViaBridge:
1477+ case OverloadChoiceKind::DeclViaDynamic:
1478+ case OverloadChoiceKind::DeclViaUnwrappedOptional:
1479+ case OverloadChoiceKind::DynamicMemberLookup:
1480+ case OverloadChoiceKind::KeyPathApplication:
1481+ case OverloadChoiceKind::TupleIndex:
1482+ return Type ();
1483+ }
1484+
1485+ auto decl = overload.getDecl ();
1486+
1487+ // Retrieve the interface type.
1488+ auto type = decl->getInterfaceType ();
1489+ if (!type) {
1490+ decl->getASTContext ().getLazyResolver ()->resolveDeclSignature (decl);
1491+ type = decl->getInterfaceType ();
1492+ if (!type) {
1493+ return Type ();
1494+ }
1495+ }
1496+
1497+ // If we have a generic function type, drop the generic signature; we don't
1498+ // need it for this comparison.
1499+ if (auto genericFn = type->getAs <GenericFunctionType>()) {
1500+ type = FunctionType::get (genericFn->getParams (),
1501+ genericFn->getResult (),
1502+ genericFn->getExtInfo ());
1503+ }
1504+
1505+ // If this declaration is within a type context, bail out.
1506+ if (decl->getDeclContext ()->isTypeContext ()) {
1507+ return Type ();
1508+ }
1509+
1510+ return type;
1511+ }
1512+
1513+ namespace {
1514+ // / Type visitor that extracts the common type between two types, when
1515+ // / possible.
1516+ class CommonTypeVisitor : public TypeVisitor <CommonTypeVisitor, Type, Type> {
1517+ // / Perform a "leaf" match for types, which does not consider the children.
1518+ Type handleLeafMatch (Type type1, Type type2) {
1519+ if (type1->isEqual (type2))
1520+ return type1;
1521+
1522+ return handleMismatch (type1, type2);
1523+ }
1524+
1525+ // / Handle a mismatch between two types.
1526+ Type handleMismatch (Type type1, Type type2) {
1527+ return Type ();
1528+ }
1529+
1530+ public:
1531+ Type visitTupleType (TupleType *tuple1, Type type2) {
1532+ if (tuple1->isEqual (type2))
1533+ return Type (tuple1);
1534+
1535+ auto tuple2 = type2->getAs <TupleType>();
1536+ if (!tuple2) {
1537+ return handleMismatch (Type (tuple1), type2);
1538+ }
1539+
1540+ // Check for structural similarity between the two tuple types.
1541+ auto elements1 = tuple1->getElements ();
1542+ auto elements2 = tuple2->getElements ();
1543+ if (elements1.size () != elements2.size ()) {
1544+ return handleMismatch (Type (tuple1), type2);
1545+ }
1546+
1547+ for (unsigned i : indices (elements1)) {
1548+ const auto &elt1 = elements1[i];
1549+ const auto &elt2 = elements2[i];
1550+ if (elt1.getName () != elt2.getName () ||
1551+ elt1.getParameterFlags () != elt2.getParameterFlags ()) {
1552+ return handleMismatch (Type (tuple1), type2);
1553+ }
1554+ }
1555+
1556+ // Recurse on the element types.
1557+ SmallVector<TupleTypeElt, 4 > newElements;
1558+ newElements.reserve (elements1.size ());
1559+ for (unsigned i : indices (elements1)) {
1560+ const auto &elt1 = elements1[i];
1561+ const auto &elt2 = elements2[i];
1562+ Type elementType = visit (elt1.getRawType (), elt2.getRawType ());
1563+ if (!elementType) {
1564+ return handleMismatch (Type (tuple1), type2);
1565+ }
1566+
1567+ newElements.push_back (elt1.getWithType (elementType));
1568+ }
1569+ return TupleType::get (newElements, tuple1->getASTContext ());
1570+ }
1571+
1572+ Type visitReferenceStorageType (ReferenceStorageType *refStorage1,
1573+ Type type2) {
1574+ if (refStorage1->isEqual (type2))
1575+ return Type (refStorage1);
1576+
1577+ auto refStorage2 = type2->getAs <ReferenceStorageType>();
1578+ if (!refStorage2 ||
1579+ refStorage1->getOwnership () != refStorage2->getOwnership ()) {
1580+ return handleMismatch (Type (refStorage1), type2);
1581+ }
1582+
1583+ Type newReferentType = visit (refStorage1->getReferentType (),
1584+ refStorage2->getReferentType ());
1585+ if (!newReferentType) {
1586+ return handleMismatch (Type (refStorage1), type2);
1587+ }
1588+
1589+ return ReferenceStorageType::get (newReferentType,
1590+ refStorage1->getOwnership (),
1591+ refStorage1->getASTContext ());
1592+ }
1593+
1594+ Type visitAnyMetatypeType (AnyMetatypeType *metatype1, Type type2) {
1595+ if (metatype1->isEqual (type2))
1596+ return Type (metatype1);
1597+
1598+
1599+ auto metatype2 = type2->getAs <AnyMetatypeType>();
1600+ if (!metatype2) {
1601+ return handleMismatch (Type (metatype1), type2);
1602+ }
1603+
1604+ if (metatype1->getKind () != metatype2->getKind () ||
1605+ metatype1->hasRepresentation () != metatype2->hasRepresentation () ||
1606+ (metatype1->hasRepresentation () &&
1607+ metatype2->getRepresentation () != metatype2->getRepresentation ())) {
1608+ return handleMismatch (Type (metatype1), type2);
1609+ }
1610+
1611+ Type newInstanceType = visit (metatype1->getInstanceType (),
1612+ metatype2->getInstanceType ());
1613+ if (!newInstanceType) {
1614+ return handleMismatch (Type (metatype1), type2);
1615+ }
1616+
1617+ Optional<MetatypeRepresentation> representation;
1618+ if (metatype1->hasRepresentation ())
1619+ representation = metatype1->getRepresentation ();
1620+
1621+ if (metatype1->getKind () == TypeKind::Metatype)
1622+ return MetatypeType::get (newInstanceType, representation);
1623+
1624+ assert (metatype1->getKind () == TypeKind::ExistentialMetatype);
1625+ return ExistentialMetatypeType::get (newInstanceType, representation);
1626+ }
1627+
1628+ Type visitFunctionType (FunctionType *function1, Type type2) {
1629+ if (function1->isEqual (type2))
1630+ return Type (function1);
1631+
1632+ auto function2 = type2->getAs <FunctionType>();
1633+ if (!function2 ||
1634+ function1->getExtInfo () != function2->getExtInfo () ||
1635+ function1->getNumParams () != function2->getNumParams ()) {
1636+ return handleMismatch (Type (function1), type2);
1637+ }
1638+
1639+ // Check for a structural match between the parameters.
1640+ auto params1 = function1->getParams ();
1641+ auto params2 = function2->getParams ();
1642+ for (unsigned i : indices (params1)) {
1643+ const auto ¶m1 = params1[i];
1644+ const auto ¶m2 = params2[i];
1645+ if (param1.getLabel () != param2.getLabel () ||
1646+ param1.getParameterFlags () != param2.getParameterFlags ()) {
1647+ return handleMismatch (Type (function1), type2);
1648+ }
1649+ }
1650+
1651+ Type newResultType = visit (function1->getResult (), function2->getResult ());
1652+ if (!newResultType) {
1653+ return handleMismatch (Type (function1), type2);
1654+ }
1655+
1656+ SmallVector<AnyFunctionType::Param, 4 > newParams;
1657+ newParams.reserve (params1.size ());
1658+ for (unsigned i : indices (params1)) {
1659+ const auto ¶m1 = params1[i];
1660+ const auto ¶m2 = params2[i];
1661+ Type newParamType = visit (param1.getPlainType (), param2.getPlainType ());
1662+ if (!newParamType) {
1663+ return handleMismatch (Type (function1), type2);
1664+ }
1665+
1666+ newParams.push_back (AnyFunctionType::Param (newParamType,
1667+ param1.getLabel (),
1668+ param1.getParameterFlags ()));
1669+ }
1670+
1671+ return FunctionType::get (newParams, newResultType, function1->getExtInfo ());
1672+ }
1673+
1674+ Type visitGenericFunctionType (GenericFunctionType *function1, Type type2) {
1675+ llvm_unreachable (" Caller should have eliminated these" );
1676+ }
1677+
1678+ Type visitLValueType (LValueType *lvalue1, Type type2) {
1679+ if (lvalue1->isEqual (type2))
1680+ return Type (lvalue1);
1681+
1682+ auto lvalue2 = type2->getAs <LValueType>();
1683+ if (!lvalue2) {
1684+ return handleMismatch (Type (lvalue1), type2);
1685+ }
1686+
1687+ Type newObjectType =
1688+ visit (lvalue1->getObjectType (), lvalue2->getObjectType ());
1689+ if (!newObjectType) {
1690+ return handleMismatch (Type (lvalue1), type2);
1691+ }
1692+
1693+ return LValueType::get (newObjectType);
1694+ }
1695+
1696+ Type visitInOutType (InOutType *inout1, Type type2) {
1697+ if (inout1->isEqual (type2))
1698+ return Type (inout1);
1699+
1700+ auto inout2 = type2->getAs <InOutType>();
1701+ if (!inout2) {
1702+ return handleMismatch (Type (inout1), type2);
1703+ }
1704+
1705+ Type newObjectType =
1706+ visit (inout1->getObjectType (), inout2->getObjectType ());
1707+ if (!newObjectType) {
1708+ return handleMismatch (Type (inout1), type2);
1709+ }
1710+
1711+ return LValueType::get (newObjectType);
1712+ }
1713+
1714+ Type visitSugarType (SugarType *sugar1, Type type2) {
1715+ if (sugar1->isEqual (type2))
1716+ return Type (sugar1);
1717+
1718+ // FIXME: Reconstitute sugar.
1719+ return visit (Type (sugar1->getSinglyDesugaredType ()), type2);
1720+ }
1721+
1722+ #define FAILURE_CASE (Class ) \
1723+ Type visit##Class##Type(Class##Type *type1, Type type2) { \
1724+ return Type (); \
1725+ }
1726+
1727+ #define LEAF_CASE (Class ) \
1728+ Type visit##Class##Type(Class##Type *type1, Type type2) { \
1729+ return handleLeafMatch (Type (type1), type2); \
1730+ }
1731+
1732+ FAILURE_CASE (Error)
1733+ FAILURE_CASE (Unresolved)
1734+ LEAF_CASE (Builtin)
1735+ LEAF_CASE (Nominal) // FIXME: We can do a more specific match here.
1736+ LEAF_CASE (BoundGeneric) // FIXME: We can do a more specific match here.
1737+ FAILURE_CASE (UnboundGeneric)
1738+ LEAF_CASE (Module)
1739+ LEAF_CASE (DynamicSelf) // FIXME: Can we do better here?
1740+ LEAF_CASE (Substitutable)
1741+ LEAF_CASE (DependentMember)
1742+ LEAF_CASE (SILFunction)
1743+ LEAF_CASE (SILBlockStorage)
1744+ LEAF_CASE (SILBox)
1745+ LEAF_CASE (SILToken)
1746+ LEAF_CASE (ProtocolComposition)
1747+ LEAF_CASE (TypeVariable) // FIXME: Could do better here when we create vars
1748+
1749+ #undef LEAF_CASE
1750+ #undef FAILURE_CASE
1751+ };
1752+
1753+ }
1754+
1755+ Type ConstraintSystem::findCommonOverloadType (
1756+ ArrayRef<OverloadChoice> choices,
1757+ ArrayRef<OverloadChoice> outerAlternatives,
1758+ ConstraintLocator *locator) {
1759+ // Local function to consider this s new overload choice, updating the
1760+ // "common type". Returns true if this overload cannot be integrated into
1761+ // the common type, at which point there is no "common type".
1762+ Type commonType;
1763+ auto considerOverload = [&](const OverloadChoice &overload) -> bool {
1764+ // If we can't even get a type for the overload, there's nothing more to
1765+ // do.
1766+ Type overloadType = getEffectiveOverloadType (overload);
1767+ if (!overloadType) {
1768+ return true ;
1769+ }
1770+
1771+ // If this is the first overload, record it's type as the common type.
1772+ if (!commonType) {
1773+ commonType = overloadType;
1774+ return false ;
1775+ }
1776+
1777+ // Find the common type between the current common type and the new
1778+ // overload's type.
1779+ commonType = CommonTypeVisitor ().visit (commonType, overloadType);
1780+ if (!commonType) {
1781+ return true ;
1782+ }
1783+
1784+ return false ;
1785+ };
1786+
1787+ // Consider all of the choices and outer alternatives.
1788+ for (const auto &choice : choices) {
1789+ if (considerOverload (choice))
1790+ return Type ();
1791+ }
1792+ for (const auto &choice : outerAlternatives) {
1793+ if (considerOverload (choice))
1794+ return Type ();
1795+ }
1796+
1797+ assert (commonType && " We can't get here without having a common type" );
1798+
1799+ // If our common type contains any generic parameters, open them up into
1800+ // type variables.
1801+ if (commonType->hasTypeParameter ()) {
1802+ llvm::SmallDenseMap<const GenericTypeParamType *, TypeVariableType *>
1803+ openedGenericParams;
1804+ commonType = commonType.transformRec ([&](TypeBase *type) -> Optional<Type> {
1805+ if (auto genericParam = dyn_cast<GenericTypeParamType>(type)) {
1806+ auto canGenericParam = GenericTypeParamType::get (
1807+ genericParam->getDepth (),
1808+ genericParam->getIndex (),
1809+ type->getASTContext ());
1810+ auto knownTypeVar = openedGenericParams.find (canGenericParam);
1811+ if (knownTypeVar != openedGenericParams.end ())
1812+ return Type (knownTypeVar->second );
1813+
1814+ auto typeVar = createTypeVariable (locator);
1815+ openedGenericParams[canGenericParam] = typeVar;
1816+ return Type (typeVar);
1817+ }
1818+
1819+ return None;
1820+ });
1821+ }
1822+
1823+ return commonType;
1824+ }
1825+
14671826void ConstraintSystem::addOverloadSet (Type boundType,
14681827 ArrayRef<OverloadChoice> choices,
14691828 DeclContext *useDC,
@@ -1478,6 +1837,13 @@ void ConstraintSystem::addOverloadSet(Type boundType,
14781837 return ;
14791838 }
14801839
1840+ // If we can compute a common type for the overload set, bind that type.
1841+ if (Type commonType = findCommonOverloadType (choices, outerAlternatives,
1842+ locator)) {
1843+ addConstraint (ConstraintKind::Bind, boundType, commonType, locator);
1844+ boundType = commonType;
1845+ }
1846+
14811847 tryOptimizeGenericDisjunction (*this , choices, favoredChoice);
14821848
14831849 SmallVector<OverloadChoice, 4 > scratchChoices;
0 commit comments