2727#include " swift/AST/TypeVisitor.h"
2828#include " swift/AST/ExistentialLayout.h"
2929#include " swift/Basic/Defer.h"
30+ #include " swift/AST/ASTPrinter.h"
3031
3132using namespace swift ;
3233
@@ -210,7 +211,6 @@ static bool checkAdHocRequirementAccessControl(
210211 return true ;
211212
212213 // === check access control
213- // TODO(distributed): this is for ad-hoc requirements and is likely too naive
214214 if (func->getEffectiveAccess () == decl->getEffectiveAccess ()) {
215215 return false ;
216216 }
@@ -222,6 +222,73 @@ static bool checkAdHocRequirementAccessControl(
222222 return true ;
223223}
224224
225+ static bool diagnoseMissingAdHocProtocolRequirement (ASTContext &C, Identifier identifier, NominalTypeDecl *decl) {
226+ assert (decl);
227+ auto FixitLocation = decl->getBraces ().Start ;
228+
229+ // Prepare the indent (same as `printRequirementStub`)
230+ StringRef ExtraIndent;
231+ StringRef CurrentIndent =
232+ Lexer::getIndentationForLine (C.SourceMgr , decl->getStartLoc (), &ExtraIndent);
233+
234+ llvm::SmallString<128 > Text;
235+ llvm::raw_svector_ostream OS (Text);
236+ ExtraIndentStreamPrinter Printer (OS, CurrentIndent);
237+
238+ Printer.printNewline ();
239+ Printer.printIndent ();
240+ Printer << (decl->getFormalAccess () == AccessLevel::Public ? " public " : " " );
241+
242+ if (identifier == C.Id_remoteCall ) {
243+ Printer << " func remoteCall<Act, Err, Res>("
244+ " on actor: Act, "
245+ " target: RemoteCallTarget, "
246+ " invocation: inout InvocationEncoder, "
247+ " throwing: Err.Type, "
248+ " returning: Res.Type) "
249+ " async throws -> Res "
250+ " where Act: DistributedActor, "
251+ " Act.ID == ActorID, "
252+ " Err: Error, "
253+ " Res: SerializationRequirement" ;
254+ } else if (identifier == C.Id_remoteCallVoid ) {
255+ Printer << " func remoteCallVoid<Act, Err>("
256+ " on actor: Act, "
257+ " target: RemoteCallTarget, "
258+ " invocation: inout InvocationEncoder, "
259+ " throwing: Err.Type"
260+ " ) async throws "
261+ " where Act: DistributedActor, "
262+ " Act.ID == ActorID, "
263+ " Err: Error" ;
264+ } else if (identifier == C.Id_recordArgument ) {
265+ Printer << " mutating func recordArgument<Value: SerializationRequirement>(_ argument: RemoteCallArgument<Value>) throws" ;
266+ } else if (identifier == C.Id_recordReturnType ) {
267+ Printer << " mutating func recordReturnType<Res: SerializationRequirement>(_ resultType: Res.Type) throws" ;
268+ } else if (identifier == C.Id_decodeNextArgument ) {
269+ Printer << " mutating func decodeNextArgument<Argument: SerializationRequirement>() throws -> Argument" ;
270+ } else if (identifier == C.Id_onReturn ) {
271+ Printer << " func onReturn<Success: SerializationRequirement>(value: Success) async throws" ;
272+ } else {
273+ llvm_unreachable (" Unknown identifier for diagnosing ad-hoc missing requirement." );
274+ }
275+
276+ // / Print the "{ <#code#> }" placeholder body
277+ Printer << " {\n " ;
278+ Printer << ExtraIndent << getCodePlaceholder ();
279+ Printer.printNewline ();
280+ Printer.printIndent ();
281+ Printer << " }\n " ;
282+
283+ decl->diagnose (
284+ diag::distributed_actor_system_conformance_missing_adhoc_requirement,
285+ decl, identifier);
286+ decl->diagnose (diag::missing_witnesses_general)
287+ .fixItInsertAfter (FixitLocation, Text.str ());
288+
289+ return true ;
290+ }
291+
225292bool swift::checkDistributedActorSystemAdHocProtocolRequirements (
226293 ASTContext &C,
227294 ProtocolDecl *Proto,
@@ -238,53 +305,21 @@ bool swift::checkDistributedActorSystemAdHocProtocolRequirements(
238305 auto remoteCallDecl =
239306 C.getRemoteCallOnDistributedActorSystem (decl, /* isVoidReturn=*/ false );
240307 if (!remoteCallDecl && diagnose) {
241- auto identifier = C.Id_remoteCall ;
242- decl->diagnose (
243- diag::distributed_actor_system_conformance_missing_adhoc_requirement,
244- decl, identifier);
245- decl->diagnose (
246- diag::note_distributed_actor_system_conformance_missing_adhoc_requirement,
247- Proto->getName (), identifier,
248- " func remoteCall<Act, Err, Res>(\n "
249- " on actor: Act,\n "
250- " target: RemoteCallTarget,\n "
251- " invocation: inout InvocationEncoder,\n "
252- " throwing: Err.Type,\n "
253- " returning: Res.Type\n "
254- " ) async throws -> Res\n "
255- " where Act: DistributedActor,\n "
256- " Act.ID == ActorID,\n "
257- " Err: Error,\n "
258- " Res: SerializationRequirement\n " );
259- anyMissingAdHocRequirements = true ;
308+ anyMissingAdHocRequirements = diagnoseMissingAdHocProtocolRequirement (C, C.Id_remoteCall , decl);
260309 }
261- if (checkAdHocRequirementAccessControl (decl, Proto, remoteCallDecl))
310+ if (checkAdHocRequirementAccessControl (decl, Proto, remoteCallDecl)) {
262311 anyMissingAdHocRequirements = true ;
312+ }
263313
264314 // - remoteCallVoid
265315 auto remoteCallVoidDecl =
266316 C.getRemoteCallOnDistributedActorSystem (decl, /* isVoidReturn=*/ true );
267317 if (!remoteCallVoidDecl && diagnose) {
268- auto identifier = C.Id_remoteCallVoid ;
269- decl->diagnose (
270- diag::distributed_actor_system_conformance_missing_adhoc_requirement,
271- decl, identifier);
272- decl->diagnose (
273- diag::note_distributed_actor_system_conformance_missing_adhoc_requirement,
274- Proto->getName (), identifier,
275- " func remoteCallVoid<Act, Err>(\n "
276- " on actor: Act,\n "
277- " target: RemoteCallTarget,\n "
278- " invocation: inout InvocationEncoder,\n "
279- " throwing: Err.Type\n "
280- " ) async throws\n "
281- " where Act: DistributedActor,\n "
282- " Act.ID == ActorID,\n "
283- " Err: Error\n " );
284- anyMissingAdHocRequirements = true ;
318+ anyMissingAdHocRequirements = diagnoseMissingAdHocProtocolRequirement (C, C.Id_remoteCallVoid , decl);
285319 }
286- if (checkAdHocRequirementAccessControl (decl, Proto, remoteCallVoidDecl))
320+ if (checkAdHocRequirementAccessControl (decl, Proto, remoteCallVoidDecl)) {
287321 anyMissingAdHocRequirements = true ;
322+ }
288323
289324 return anyMissingAdHocRequirements;
290325 }
@@ -295,32 +330,20 @@ bool swift::checkDistributedActorSystemAdHocProtocolRequirements(
295330 // - recordArgument
296331 auto recordArgumentDecl = C.getRecordArgumentOnDistributedInvocationEncoder (decl);
297332 if (!recordArgumentDecl) {
298- auto identifier = C.Id_recordArgument ;
299- decl->diagnose (
300- diag::distributed_actor_system_conformance_missing_adhoc_requirement,
301- decl, identifier);
302- decl->diagnose (diag::note_distributed_actor_system_conformance_missing_adhoc_requirement,
303- Proto->getName (), identifier,
304- " mutating func recordArgument<Value: SerializationRequirement>(_ argument: RemoteCallArgument<Value>) throws\n " );
305- anyMissingAdHocRequirements = true ;
333+ anyMissingAdHocRequirements = diagnoseMissingAdHocProtocolRequirement (C, C.Id_recordArgument , decl);
306334 }
307- if (checkAdHocRequirementAccessControl (decl, Proto, recordArgumentDecl))
335+ if (checkAdHocRequirementAccessControl (decl, Proto, recordArgumentDecl)) {
308336 anyMissingAdHocRequirements = true ;
337+ }
309338
310339 // - recordReturnType
311340 auto recordReturnTypeDecl = C.getRecordReturnTypeOnDistributedInvocationEncoder (decl);
312341 if (!recordReturnTypeDecl) {
313- auto identifier = C.Id_recordReturnType ;
314- decl->diagnose (
315- diag::distributed_actor_system_conformance_missing_adhoc_requirement,
316- decl, identifier);
317- decl->diagnose (diag::note_distributed_actor_system_conformance_missing_adhoc_requirement,
318- Proto->getName (), identifier,
319- " mutating func recordReturnType<Res: SerializationRequirement>(_ resultType: Res.Type) throws\n " );
320- anyMissingAdHocRequirements = true ;
342+ anyMissingAdHocRequirements = diagnoseMissingAdHocProtocolRequirement (C, C.Id_recordReturnType , decl);
321343 }
322- if (checkAdHocRequirementAccessControl (decl, Proto, recordReturnTypeDecl))
344+ if (checkAdHocRequirementAccessControl (decl, Proto, recordReturnTypeDecl)) {
323345 anyMissingAdHocRequirements = true ;
346+ }
324347
325348 return anyMissingAdHocRequirements;
326349 }
@@ -331,17 +354,11 @@ bool swift::checkDistributedActorSystemAdHocProtocolRequirements(
331354 // - decodeNextArgument
332355 auto decodeNextArgumentDecl = C.getDecodeNextArgumentOnDistributedInvocationDecoder (decl);
333356 if (!decodeNextArgumentDecl) {
334- auto identifier = C.Id_decodeNextArgument ;
335- decl->diagnose (
336- diag::distributed_actor_system_conformance_missing_adhoc_requirement,
337- decl, identifier);
338- decl->diagnose (diag::note_distributed_actor_system_conformance_missing_adhoc_requirement,
339- Proto->getName (), identifier,
340- " mutating func decodeNextArgument<Argument: SerializationRequirement>() throws -> Argument\n " );
341- anyMissingAdHocRequirements = true ;
357+ anyMissingAdHocRequirements = diagnoseMissingAdHocProtocolRequirement (C, C.Id_decodeNextArgument , decl);
342358 }
343- if (checkAdHocRequirementAccessControl (decl, Proto, decodeNextArgumentDecl))
359+ if (checkAdHocRequirementAccessControl (decl, Proto, decodeNextArgumentDecl)) {
344360 anyMissingAdHocRequirements = true ;
361+ }
345362
346363 return anyMissingAdHocRequirements;
347364 }
@@ -352,19 +369,11 @@ bool swift::checkDistributedActorSystemAdHocProtocolRequirements(
352369 // - onReturn
353370 auto onReturnDecl = C.getOnReturnOnDistributedTargetInvocationResultHandler (decl);
354371 if (!onReturnDecl) {
355- auto identifier = C.Id_onReturn ;
356- decl->diagnose (
357- diag::distributed_actor_system_conformance_missing_adhoc_requirement,
358- decl, identifier);
359- decl->diagnose (
360- diag::note_distributed_actor_system_conformance_missing_adhoc_requirement,
361- Proto->getName (), identifier,
362- " func onReturn<Success: SerializationRequirement>(value: "
363- " Success) async throws\n " );
364- anyMissingAdHocRequirements = true ;
372+ anyMissingAdHocRequirements = diagnoseMissingAdHocProtocolRequirement (C, C.Id_onReturn , decl);
365373 }
366- if (checkAdHocRequirementAccessControl (decl, Proto, onReturnDecl))
374+ if (checkAdHocRequirementAccessControl (decl, Proto, onReturnDecl)) {
367375 anyMissingAdHocRequirements = true ;
376+ }
368377
369378 return anyMissingAdHocRequirements;
370379 }
0 commit comments