@@ -97,6 +97,15 @@ struct TermArgSources {
9797
9898} // namespace
9999
100+ // / Used for generating informative diagnostics.
101+ static Expr *getExprForPartitionOp (const PartitionOp &op) {
102+ SILInstruction *sourceInstr = op.getSourceInst (/* assertNonNull=*/ true );
103+ Expr *expr = sourceInstr->getLoc ().getAsASTNode <Expr>();
104+ assert (expr && " PartitionOp's source location should correspond to"
105+ " an AST node" );
106+ return expr;
107+ }
108+
100109// ===----------------------------------------------------------------------===//
101110// MARK: Main Computation
102111// ===----------------------------------------------------------------------===//
@@ -1300,52 +1309,99 @@ class ConsumeRequireAccumulator {
13001309 std::map<PartitionOp, std::set<PartitionOpAtDistance>>
13011310 requirementsForConsumptions;
13021311
1312+ SILFunction *fn;
1313+
13031314public:
1304- ConsumeRequireAccumulator () {}
1315+ ConsumeRequireAccumulator (SILFunction *fn) : fn(fn ) {}
13051316
13061317 void accumulateConsumedReason (PartitionOp requireOp, const ConsumedReason &consumedReason) {
13071318 for (auto [distance, consumeOps] : consumedReason.consumeOps )
13081319 for (auto consumeOp : consumeOps)
13091320 requirementsForConsumptions[consumeOp].insert ({requireOp, distance});
13101321 }
13111322
1312- // for each consumption in this ConsumeRequireAccumulator, call the passed
1313- // processConsumeOp closure on it, followed immediately by calling the passed
1314- // processRequireOp closure on the top `numRequiresPerConsume` operations
1315- // that access ("require") the region consumed. Sorting is by lowest distance
1316- // first, then arbitrarily. This is used for final diagnostic output.
1317- void forEachConsumeRequire (
1318- llvm::function_ref<void (const PartitionOp& consumeOp, unsigned numProcessed, unsigned numSkipped)>
1319- processConsumeOp,
1320- llvm::function_ref<void(const PartitionOp& requireOp)>
1321- processRequireOp,
1322- unsigned numRequiresPerConsume = UINT_MAX) const {
1323+ void
1324+ emitErrorsForConsumeRequire (unsigned numRequiresPerConsume = UINT_MAX) const {
13231325 for (auto [consumeOp, requireOps] : requirementsForConsumptions) {
13241326 unsigned numProcessed = std::min ({(unsigned ) requireOps.size (),
13251327 (unsigned ) numRequiresPerConsume});
1326- processConsumeOp (consumeOp, numProcessed, requireOps.size () - numProcessed);
1328+
1329+ // First process our consume ops.
1330+ unsigned numDisplayed = numProcessed;
1331+ unsigned numHidden = requireOps.size () - numProcessed;
1332+ if (!tryDiagnoseAsCallSite (consumeOp, numDisplayed, numHidden)) {
1333+ assert (false && " no consumptions besides callsites implemented yet" );
1334+
1335+ // default to more generic diagnostic
1336+ auto expr = getExprForPartitionOp (consumeOp);
1337+ auto diag = fn->getASTContext ().Diags .diagnose (
1338+ expr->getLoc (), diag::consumption_yields_race, numDisplayed,
1339+ numDisplayed != 1 , numHidden > 0 , numHidden);
1340+ if (auto sourceExpr = consumeOp.getSourceExpr ())
1341+ diag.highlight (sourceExpr->getSourceRange ());
1342+ return ;
1343+ }
1344+
13271345 unsigned numRequiresToProcess = numRequiresPerConsume;
13281346 for (auto [requireOp, _] : requireOps) {
13291347 // ensures at most numRequiresPerConsume requires are processed per consume
1330- if (numRequiresToProcess-- == 0 ) break ;
1331- processRequireOp (requireOp);
1348+ if (numRequiresToProcess-- == 0 )
1349+ break ;
1350+ auto expr = getExprForPartitionOp (requireOp);
1351+ fn->getASTContext ()
1352+ .Diags .diagnose (expr->getLoc (), diag::possible_racy_access_site)
1353+ .highlight (expr->getSourceRange ());
13321354 }
13331355 }
13341356 }
13351357
13361358 SWIFT_DEBUG_DUMP { print (llvm::dbgs ()); }
13371359
13381360 void print (llvm::raw_ostream &os) const {
1339- forEachConsumeRequire (
1340- [&](const PartitionOp &consumeOp, unsigned numProcessed,
1341- unsigned numSkipped) {
1342- os << " ┌──╼ CONSUME: " ;
1343- consumeOp.print (os);
1344- },
1345- [&](const PartitionOp &requireOp) {
1346- os << " ├╼ REQUIRE: " ;
1347- requireOp.print (os);
1348- });
1361+ for (auto [consumeOp, requireOps] : requirementsForConsumptions) {
1362+ os << " ┌──╼ CONSUME: " ;
1363+ consumeOp.print (os);
1364+
1365+ for (auto &[requireOp, _] : requireOps) {
1366+ os << " ├╼ REQUIRE: " ;
1367+ requireOp.print (os);
1368+ }
1369+ }
1370+ }
1371+
1372+ private:
1373+ // / Try to interpret this consumeOp as a source-level callsite (ApplyExpr),
1374+ // / and report a diagnostic including actor isolation crossing information
1375+ // / returns true iff one was succesfully formed and emitted.
1376+ bool tryDiagnoseAsCallSite (const PartitionOp &consumeOp,
1377+ unsigned numDisplayed, unsigned numHidden) const {
1378+ SILInstruction *sourceInst =
1379+ consumeOp.getSourceInst (/* assertNonNull=*/ true );
1380+ ApplyExpr *apply = sourceInst->getLoc ().getAsASTNode <ApplyExpr>();
1381+ if (!apply)
1382+ // consumption does not correspond to an apply expression
1383+ return false ;
1384+ auto isolationCrossing = apply->getIsolationCrossing ();
1385+ if (!isolationCrossing) {
1386+ assert (false && " ApplyExprs should be consuming only if"
1387+ " they are isolation crossing" );
1388+ return false ;
1389+ }
1390+ auto argExpr = consumeOp.getSourceExpr ();
1391+ if (!argExpr)
1392+ assert (false &&
1393+ " sourceExpr should be populated for ApplyExpr consumptions" );
1394+
1395+ sourceInst->getFunction ()
1396+ ->getASTContext ()
1397+ .Diags
1398+ .diagnose (argExpr->getLoc (), diag::call_site_consumption_yields_race,
1399+ argExpr->findOriginalType (),
1400+ isolationCrossing.value ().getCallerIsolation (),
1401+ isolationCrossing.value ().getCalleeIsolation (), numDisplayed,
1402+ numDisplayed != 1 , numHidden > 0 , numHidden)
1403+ .highlight (argExpr->getSourceRange ());
1404+ return true ;
13491405 }
13501406};
13511407
@@ -1560,8 +1616,9 @@ class RaceTracer {
15601616 }
15611617
15621618public:
1563- RaceTracer (const BasicBlockData<BlockPartitionState>& blockStates)
1564- : blockStates(blockStates) {}
1619+ RaceTracer (SILFunction *fn,
1620+ const BasicBlockData<BlockPartitionState> &blockStates)
1621+ : blockStates(blockStates), accumulator(fn) {}
15651622
15661623 void traceUseOfConsumedValue (PartitionOp use, TrackableValueID consumedVal) {
15671624 accumulator.accumulateConsumedReason (
@@ -1601,9 +1658,7 @@ class PartitionAnalysis {
16011658 [this ](SILBasicBlock *block) {
16021659 return BlockPartitionState (block, translator);
16031660 }),
1604- raceTracer (blockStates),
1605- function(fn),
1606- solved(false ) {
1661+ raceTracer (fn, blockStates), function(fn), solved(false ) {
16071662 // initialize the entry block as needing an update, and having a partition
16081663 // that places all its non-sendable args in a single region
16091664 blockStates[fn->getEntryBlock ()].needsUpdate = true ;
@@ -1692,15 +1747,6 @@ class PartitionAnalysis {
16921747 return false ;
16931748 }
16941749
1695- // used for generating informative diagnostics
1696- Expr *getExprForPartitionOp (const PartitionOp& op) {
1697- SILInstruction *sourceInstr = op.getSourceInst (/* assertNonNull=*/ true );
1698- Expr *expr = sourceInstr->getLoc ().getAsASTNode <Expr>();
1699- assert (expr && " PartitionOp's source location should correspond to"
1700- " an AST node" );
1701- return expr;
1702- }
1703-
17041750 // once the fixpoint has been solved for, run one more pass over each basic
17051751 // block, reporting any failures due to requiring consumed regions in the
17061752 // fixpoint state
@@ -1710,7 +1756,7 @@ class PartitionAnalysis {
17101756 LLVM_DEBUG (
17111757 llvm::dbgs () << " Emitting diagnostics for function "
17121758 << function->getName () << " \n " );
1713- RaceTracer tracer = blockStates;
1759+ RaceTracer tracer (function, blockStates) ;
17141760
17151761 for (auto [_, blockState] : blockStates) {
17161762 // populate the raceTracer with all requires of consumed valued found
@@ -1738,40 +1784,12 @@ class PartitionAnalysis {
17381784 LLVM_DEBUG (llvm::dbgs () << " Accumulator Complete:\n " ;
17391785 raceTracer.getAccumulator ().print (llvm::dbgs ()););
17401786
1741- // ask the raceTracer to report diagnostics at the consumption sites
1742- // for all the racy requirement sites entered into it above
1743- raceTracer.getAccumulator ().forEachConsumeRequire (
1744- /* diagnoseConsume=*/
1745- [&](const PartitionOp& consumeOp,
1746- unsigned numDisplayed, unsigned numHidden) {
1747-
1748- if (tryDiagnoseAsCallSite (consumeOp, numDisplayed, numHidden))
1749- return ;
1750-
1751- assert (false && " no consumptions besides callsites implemented yet" );
1752-
1753- // default to more generic diagnostic
1754- auto expr = getExprForPartitionOp (consumeOp);
1755- auto diag = function->getASTContext ().Diags .diagnose (
1756- expr->getLoc (), diag::consumption_yields_race,
1757- numDisplayed, numDisplayed != 1 , numHidden > 0 , numHidden);
1758- if (auto sourceExpr = consumeOp.getSourceExpr ())
1759- diag.highlight (sourceExpr->getSourceRange ());
1760- },
1761-
1762- /* diagnoseRequire=*/
1763- [&](const PartitionOp& requireOp) {
1764- auto expr = getExprForPartitionOp (requireOp);
1765- function->getASTContext ().Diags .diagnose (
1766- expr->getLoc (), diag::possible_racy_access_site)
1767- .highlight (expr->getSourceRange ());
1768- },
1787+ // Ask the raceTracer to report diagnostics at the consumption sites for all
1788+ // the racy requirement sites entered into it above.
1789+ raceTracer.getAccumulator ().emitErrorsForConsumeRequire (
17691790 NUM_REQUIREMENTS_TO_DIAGNOSE);
17701791 }
17711792
1772- // try to interpret this consumeOp as a source-level callsite (ApplyExpr),
1773- // and report a diagnostic including actor isolation crossing information
1774- // returns true iff one was succesfully formed and emitted
17751793 bool tryDiagnoseAsCallSite (
17761794 const PartitionOp& consumeOp, unsigned numDisplayed, unsigned numHidden) {
17771795 SILInstruction *sourceInst = consumeOp.getSourceInst (/* assertNonNull=*/ true );
0 commit comments