@@ -340,52 +340,19 @@ swift::getDistributedSerializationRequirements(
340340 if (existentialRequirementTy->isAny ())
341341 return true ; // we're done here, any means there are no requirements
342342
343- if (!existentialRequirementTy->isExistentialType ()) {
344- // SerializationRequirement must be an existential type
345- return false ;
346- }
347-
348- ExistentialType *serialReqType = existentialRequirementTy
349- ->castTo <ExistentialType>();
343+ auto *serialReqType = existentialRequirementTy->getAs <ExistentialType>();
350344 if (!serialReqType || serialReqType->hasError ()) {
351345 return false ;
352346 }
353347
354- auto desugaredTy = serialReqType->getConstraintType ()->getDesugaredType ();
355- auto flattenedRequirements =
356- flattenDistributedSerializationTypeToRequiredProtocols (
357- desugaredTy);
358- for (auto p : flattenedRequirements) {
348+ auto layout = serialReqType->getExistentialLayout ();
349+ for (auto p : layout.getProtocols ()) {
359350 requirementProtos.insert (p);
360351 }
361352
362353 return true ;
363354}
364355
365- llvm::SmallPtrSet<ProtocolDecl *, 2 >
366- swift::flattenDistributedSerializationTypeToRequiredProtocols (
367- TypeBase *serializationRequirement) {
368- llvm::SmallPtrSet<ProtocolDecl *, 2 > serializationReqs;
369- if (auto composition =
370- serializationRequirement->getAs <ProtocolCompositionType>()) {
371- for (auto member : composition->getMembers ()) {
372- if (auto comp = member->getAs <ProtocolCompositionType>()) {
373- for (auto protocol :
374- flattenDistributedSerializationTypeToRequiredProtocols (comp)) {
375- serializationReqs.insert (protocol);
376- }
377- } else if (auto *protocol = member->getAs <ProtocolType>()) {
378- serializationReqs.insert (protocol->getDecl ());
379- }
380- }
381- } else {
382- auto protocol = serializationRequirement->castTo <ProtocolType>()->getDecl ();
383- serializationReqs.insert (protocol);
384- }
385-
386- return serializationReqs;
387- }
388-
389356bool swift::checkDistributedSerializationRequirementIsExactlyCodable (
390357 ASTContext &C,
391358 const llvm::SmallPtrSetImpl<ProtocolDecl *> &allRequirements) {
@@ -565,25 +532,19 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn)
565532
566533 // --- Check requirement: conforms_to: Act DistributedActor
567534 auto actorReq = requirements[0 ];
568- auto distActorTy = C.getProtocol (KnownProtocolKind::DistributedActor)
569- ->getInterfaceType ()
570- ->getMetatypeInstanceType ();
571535 if (actorReq.getKind () != RequirementKind::Conformance) {
572536 return false ;
573537 }
574- if (!actorReq.getSecondType ()->isEqual (distActorTy )) {
538+ if (!actorReq.getProtocolDecl ()->isSpecificProtocol (KnownProtocolKind::DistributedActor )) {
575539 return false ;
576540 }
577541
578542 // --- Check requirement: conforms_to: Err Error
579543 auto errorReq = requirements[1 ];
580- auto errorTy = C.getProtocol (KnownProtocolKind::Error)
581- ->getInterfaceType ()
582- ->getMetatypeInstanceType ();
583544 if (errorReq.getKind () != RequirementKind::Conformance) {
584545 return false ;
585546 }
586- if (!errorReq.getSecondType ()->isEqual (errorTy )) {
547+ if (!errorReq.getProtocolDecl ()->isSpecificProtocol (KnownProtocolKind::Error )) {
587548 return false ;
588549 }
589550
@@ -598,10 +559,9 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn)
598559 assert (ResParam && " Non void function, yet no Res generic parameter found" );
599560 if (auto func = dyn_cast<FuncDecl>(this )) {
600561 auto resultType = func->mapTypeIntoContext (func->getResultInterfaceType ())
601- ->getMetatypeInstanceType ()
602- ->getDesugaredType ();
562+ ->getMetatypeInstanceType ();
603563 auto resultParamType = func->mapTypeIntoContext (
604- ResParam->getInterfaceType ()-> getMetatypeInstanceType ());
564+ ResParam->getDeclaredInterfaceType ());
605565 // The result of the function must be the `Res` generic argument.
606566 if (!resultType->isEqual (resultParamType)) {
607567 return false ;
@@ -797,12 +757,10 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordArgument() const
797757
798758 // the <Value> of the RemoteCallArgument<Value>
799759 auto remoteCallArgValueGenericTy =
800- mapTypeIntoContext (argGenericParams[0 ]->getInterfaceType ())
801- ->getDesugaredType ()
802- ->getMetatypeInstanceType ();
760+ mapTypeIntoContext (argGenericParams[0 ]->getDeclaredInterfaceType ());
803761 // expected (the <Value> from the recordArgument<Value>)
804762 auto expectedGenericParamTy = mapTypeIntoContext (
805- ArgumentParam->getInterfaceType ()-> getMetatypeInstanceType ());
763+ ArgumentParam->getDeclaredInterfaceType ());
806764
807765 if (!remoteCallArgValueGenericTy->isEqual (expectedGenericParamTy)) {
808766 return false ;
@@ -932,11 +890,10 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordReturnType() con
932890 // ...
933891
934892 auto resultType = func->mapTypeIntoContext (argumentParam->getInterfaceType ())
935- ->getMetatypeInstanceType ()
936- ->getDesugaredType ();
893+ ->getMetatypeInstanceType ();
937894
938895 auto resultParamType = func->mapTypeIntoContext (
939- ArgumentParam->getInterfaceType ()-> getMetatypeInstanceType ());
896+ ArgumentParam->getDeclaredInterfaceType ());
940897
941898 // The result of the function must be the `Res` generic argument.
942899 if (!resultType->isEqual (resultParamType)) {
@@ -1046,13 +1003,10 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordErrorType() cons
10461003
10471004 // --- Check requirement: conforms_to: Err Error
10481005 auto errorReq = requirements[0 ];
1049- auto errorTy = C.getProtocol (KnownProtocolKind::Error)
1050- ->getInterfaceType ()
1051- ->getMetatypeInstanceType ();
10521006 if (errorReq.getKind () != RequirementKind::Conformance) {
10531007 return false ;
10541008 }
1055- if (!errorReq.getSecondType ()->isEqual (errorTy )) {
1009+ if (!errorReq.getProtocolDecl ()->isSpecificProtocol (KnownProtocolKind::Error )) {
10561010 return false ;
10571011 }
10581012
@@ -1139,10 +1093,9 @@ AbstractFunctionDecl::isDistributedTargetInvocationDecoderDecodeNextArgument() c
11391093 // --- Check: Argument: SerializationRequirement
11401094 GenericTypeParamDecl *ArgumentParam = genericParams->getParams ()[0 ];
11411095 auto resultType = func->mapTypeIntoContext (func->getResultInterfaceType ())
1142- ->getMetatypeInstanceType ()
1143- ->getDesugaredType ();
1096+ ->getMetatypeInstanceType ();
11441097 auto resultParamType = func->mapTypeIntoContext (
1145- ArgumentParam->getInterfaceType ()-> getMetatypeInstanceType ());
1098+ ArgumentParam->getDeclaredInterfaceType ());
11461099 // The result of the function must be the `Res` generic argument.
11471100 if (!resultType->isEqual (resultParamType)) {
11481101 return false ;
@@ -1237,11 +1190,10 @@ AbstractFunctionDecl::isDistributedTargetInvocationResultHandlerOnReturn() const
12371190 // === Check generic parameters in detail
12381191 // --- Check: Argument: SerializationRequirement
12391192 GenericTypeParamDecl *ArgumentParam = genericParams->getParams ()[0 ];
1240- auto argumentType = func->mapTypeIntoContext (valueParam->getInterfaceType ())
1241- ->getMetatypeInstanceType ()
1242- ->getDesugaredType ();
1193+ auto argumentType = func->mapTypeIntoContext (
1194+ valueParam->getInterfaceType ()->getMetatypeInstanceType ());
12431195 auto resultParamType = func->mapTypeIntoContext (
1244- ArgumentParam->getInterfaceType ()-> getMetatypeInstanceType ());
1196+ ArgumentParam->getDeclaredInterfaceType ());
12451197 // The result of the function must be the `Res` generic argument.
12461198 if (!argumentType->isEqual (resultParamType)) {
12471199 return false ;
@@ -1269,35 +1221,19 @@ swift::extractDistributedSerializationRequirements(
12691221 auto DA = C.getDistributedActorDecl ();
12701222 auto daSerializationReqAssocType =
12711223 DA->getAssociatedType (C.Id_SerializationRequirement );
1272- auto daSystemSerializationReqTy = daSerializationReqAssocType->getInterfaceType ();
12731224
12741225 for (auto req : allRequirements) {
1275- if (req.getSecondType ()->isAny ()) {
1276- continue ;
1277- }
1278- if (!req.getFirstType ()->hasDependentMember ())
1226+ // FIXME: Seems unprincipled
1227+ if (req.getKind () != RequirementKind::SameType &&
1228+ req.getKind () != RequirementKind::Conformance)
12791229 continue ;
12801230
12811231 if (auto dependentMemberType =
1282- req.getFirstType ()->castTo <DependentMemberType>()) {
1283- auto dependentTy =
1284- dependentMemberType->getAssocType ()->getInterfaceType ();
1285-
1286- if (dependentTy->isEqual (daSystemSerializationReqTy)) {
1287- auto requirementProto = req.getSecondType ();
1288- if (auto proto = dyn_cast_or_null<ProtocolDecl>(
1289- requirementProto->getAnyNominal ())) {
1290- serializationReqs.insert (proto);
1291- } else {
1292- auto serialReqType = requirementProto->castTo <ExistentialType>()
1293- ->getConstraintType ()
1294- ->getDesugaredType ();
1295- auto flattenedRequirements =
1296- flattenDistributedSerializationTypeToRequiredProtocols (
1297- serialReqType);
1298- for (auto p : flattenedRequirements) {
1299- serializationReqs.insert (p);
1300- }
1232+ req.getFirstType ()->getAs <DependentMemberType>()) {
1233+ if (dependentMemberType->getAssocType () == daSerializationReqAssocType) {
1234+ auto layout = req.getSecondType ()->getExistentialLayout ();
1235+ for (auto p : layout.getProtocols ()) {
1236+ serializationReqs.insert (p);
13011237 }
13021238 }
13031239 }
0 commit comments