@@ -475,13 +475,18 @@ struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
   }
 
   void visitStoreInst(StoreInst &SI) {
-    // Storing the pointer escapes it.
-    if (U->get() == SI.getValueOperand())
-      return PI.setEscapedAndAborted(&SI);
-    // Writes to the pointer are UB w/ __grid_constant__, but do not force a
-    // copy.
-    if (!IsGridConstant)
-      return PI.setAborted(&SI);
+    // Storing the pointer value (as opposed to storing through it) escapes it.
+    // For grid_constant params, this is allowed - we can pass the generic
+    // pointer. For non-grid-constant, this requires a copy.
+    if (U->get() == SI.getValueOperand()) {
+      if (IsGridConstant)
+        return PI.setEscaped(&SI);
+      else
+        return PI.setEscapedAndAborted(&SI);
+    }
+    // Writes through the pointer to param space are UB w/ __grid_constant__,
+    // and param space is read-only on CUDA, so we need to force a copy.
+    return PI.setAborted(&SI);
   }
 
   void visitAddrSpaceCastInst(AddrSpaceCastInst &ASC) {
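To make the three outcomes in this hunk concrete, here is a minimal standalone sketch of the same decision table; the enum and function names are illustrative only and are not part of the PtrUseVisitor API:

```c++
#include <cassert>

// Illustrative stand-ins for the escape/abort outcomes recorded in PtrInfo.
enum class StoreAction { Escape, EscapeAndAbort, Abort };

// storesPointerValue: the param pointer itself is the value being stored
// (it escapes), as opposed to a store *through* the pointer.
StoreAction classifyStore(bool storesPointerValue, bool isGridConstant) {
  if (storesPointerValue)
    // __grid_constant__ params may escape via the generic pointer; plain
    // byval params escape AND force a local copy.
    return isGridConstant ? StoreAction::Escape
                          : StoreAction::EscapeAndAbort;
  // Any write through the pointer targets read-only param space, so it
  // always forces a copy.
  return StoreAction::Abort;
}

int main() {
  assert(classifyStore(true, true) == StoreAction::Escape);
  assert(classifyStore(true, false) == StoreAction::EscapeAndAbort);
  assert(classifyStore(false, false) == StoreAction::Abort);
}
```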
@@ -529,10 +534,22 @@ void copyByValParam(Function &F, Argument &Arg) {
   // the use of the byval parameter with this alloca instruction.
   AllocA->setAlignment(
       Arg.getParamAlign().value_or(DL.getPrefTypeAlign(StructType)));
-  Arg.replaceAllUsesWith(AllocA);
 
+  // Must create ArgInParam before replacing uses of Arg.
+  // createNVVMInternalAddrspaceWrap needs to use Arg as an operand.
   CallInst *ArgInParam = createNVVMInternalAddrspaceWrap(IRB, Arg);
 
+  // Replace all uses of Arg with AllocA, except the use in ArgInParam.
+  // Note: we can't use replaceAllUsesWith because it would replace
+  // ArgInParam's operand too, creating a circular dependency.
+  SmallVector<Use *, 8> UsesToReplace;
+  for (Use &U : Arg.uses()) {
+    if (U.getUser() != ArgInParam)
+      UsesToReplace.push_back(&U);
+  }
+  for (Use *U : UsesToReplace)
+    U->set(AllocA);
+
   // Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
   // addrspacecast preserves alignment. Since params are constant, this load
   // is definitely not volatile.
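The collect-first loop in this hunk is the usual way to do a filtered RAUW: calling Use::set() while iterating the use list would mutate the list under the iterator. A self-contained sketch of the same idiom, with replaceUsesExcept as a hypothetical helper name:

```c++
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
using namespace llvm;

// Replace every use of Old with New, except uses owned by Keeper.
// Uses are collected before mutation because U->set() unlinks U from
// Old's use list, which would invalidate the range-based iteration.
static void replaceUsesExcept(Value &Old, Value *New, User *Keeper) {
  SmallVector<Use *, 8> ToReplace;
  for (Use &U : Old.uses())
    if (U.getUser() != Keeper)
      ToReplace.push_back(&U);
  for (Use *U : ToReplace)
    U->set(New);
}
```

LLVM also provides Value::replaceUsesWithIf(New, ShouldReplace), which performs the same filtered replacement in one call; the explicit two-pass form in the patch keeps the dependency on ArgInParam obvious at the call site.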
@@ -578,10 +595,12 @@ static void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg) {
 
   // We can't access byval arg directly and need a pointer. on sm_70+ we have
   // ability to take a pointer to the argument without making a local copy.
-  // However, we're still not allowed to write to it. If the user specified
-  // `__grid_constant__` for the argument, we'll consider escaped pointer as
-  // read-only.
-  if (IsGridConstant || (HasCvtaParam && ArgUseIsReadOnly)) {
+  // However, param space is read-only, so we can only use this optimization
+  // if the argument is not written to.
+  // Grid constant params can escape (pointer passed to functions) but cannot
+  // have stores. Non-grid-constant params must be fully read-only.
+  if ((IsGridConstant && !PI.isAborted()) ||
+      (HasCvtaParam && ArgUseIsReadOnly)) {
     LLVM_DEBUG(dbgs() << "Using non-copy pointer to " << *Arg << "\n");
     // Replace all argument pointer uses (which might include a device function
     // call) with a cast to the generic address space using cvta.param
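The reworked gate in this hunk reads as a small predicate. A hypothetical condensation (not code from the patch; the parameter names mirror the locals in handleByValParam):

```c++
// Decide whether the byval arg can be used through cvta.param directly,
// avoiding a local copy.
static bool canUseParamPointerDirectly(bool IsGridConstant, bool UseAborted,
                                       bool HasCvtaParam,
                                       bool ArgUseIsReadOnly) {
  // __grid_constant__: escapes are fine, but a store through the pointer
  // (recorded as an abort by ArgUseChecker) still forces a local copy.
  bool GridConstOK = IsGridConstant && !UseAborted;
  // Plain byval: needs cvta.param (sm_70+) and strictly read-only uses.
  bool ReadOnlyOK = HasCvtaParam && ArgUseIsReadOnly;
  return GridConstOK || ReadOnlyOK;
}
```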