@@ -478,10 +478,9 @@ struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
478478 // Storing the pointer escapes it.
479479 if (U->get () == SI.getValueOperand ())
480480 return PI.setEscapedAndAborted (&SI);
481- // Writes to the pointer are UB w/ __grid_constant__, but do not force a
482- // copy.
483- if (!IsGridConstant)
484- return PI.setAborted (&SI);
481+ // Writes to the pointer are UB w/ __grid_constant__, but param space is
482+ // read-only on CUDA so we need to force a copy.
483+ return PI.setAborted (&SI);
485484 }
486485
487486 void visitAddrSpaceCastInst (AddrSpaceCastInst &ASC) {
@@ -529,10 +528,22 @@ void copyByValParam(Function &F, Argument &Arg) {
529528 // the use of the byval parameter with this alloca instruction.
530529 AllocA->setAlignment (
531530 Arg.getParamAlign ().value_or (DL.getPrefTypeAlign (StructType)));
532- Arg.replaceAllUsesWith (AllocA);
533531
532+ // CRITICAL: Must create ArgInParam BEFORE replacing uses of Arg.
533+ // createNVVMInternalAddrspaceWrap needs to use Arg as an operand.
534534 CallInst *ArgInParam = createNVVMInternalAddrspaceWrap (IRB, Arg);
535535
536+ // Replace all uses of Arg with AllocA, except the use in ArgInParam.
537+ // We can't use replaceAllUsesWith because it would replace ArgInParam's
538+ // operand too, creating a circular dependency.
539+ SmallVector<Use *, 8 > UsesToReplace;
540+ for (Use &U : Arg.uses ()) {
541+ if (U.getUser () != ArgInParam)
542+ UsesToReplace.push_back (&U);
543+ }
544+ for (Use *U : UsesToReplace)
545+ U->set (AllocA);
546+
536547 // Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
537548 // addrspacecast preserves alignment. Since params are constant, this load
538549 // is definitely not volatile.
@@ -578,10 +589,9 @@ static void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg) {
578589
579590 // We can't access byval arg directly and need a pointer. on sm_70+ we have
580591 // ability to take a pointer to the argument without making a local copy.
581- // However, we're still not allowed to write to it. If the user specified
582- // `__grid_constant__` for the argument, we'll consider escaped pointer as
583- // read-only.
584- if (IsGridConstant || (HasCvtaParam && ArgUseIsReadOnly)) {
592+ // However, param space is read-only, so we can only use this optimization
593+ // if the argument is not written to.
594+ if ((IsGridConstant || HasCvtaParam) && ArgUseIsReadOnly) {
585595 LLVM_DEBUG (dbgs () << " Using non-copy pointer to " << *Arg << " \n " );
586596 // Replace all argument pointer uses (which might include a device function
587597 // call) with a cast to the generic address space using cvta.param
@@ -599,10 +609,14 @@ static void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg) {
599609 ParamSpaceArg, IRB.getPtrTy (ADDRESS_SPACE_GENERIC),
600610 Arg->getName () + " .gen" );
601611
602- Arg->replaceAllUsesWith (GenericArg);
603-
604- // Do not replace Arg in the cast to param space
605- ParamSpaceArg->setOperand (0 , Arg);
612+ // Collect all uses of Arg except the one in ParamSpaceArg, then replace them
613+ SmallVector<Use *> UsesToReplace;
614+ for (Use &U : Arg->uses ()) {
615+ if (U.getUser () != ParamSpaceArg)
616+ UsesToReplace.push_back (&U);
617+ }
618+ for (Use *U : UsesToReplace)
619+ U->set (GenericArg);
606620 } else
607621 copyByValParam (*Func, *Arg);
608622}
0 commit comments