 //
 // define void @foo({i32*, i32*}* byval %input) {
 //   %b_param = addrspacecast ptr %input to ptr addrspace(101)
-//   %b_ptr = getelementptr {ptr, ptr}, ptr addrspace(101) %b_param, i64 0, i32 1
-//   %b = load ptr, ptr addrspace(101) %b_ptr
-//   %b_global = addrspacecast ptr %b to ptr addrspace(1)
-//   ; use %b_generic
+//   %b_ptr = getelementptr {ptr, ptr}, ptr addrspace(101) %b_param, i64 0,
+//   i32 1 %b = load ptr, ptr addrspace(101) %b_ptr %b_global = addrspacecast
+//   ptr %b to ptr addrspace(1) ; use %b_generic
 // }
 //
-// Create a local copy of kernel byval parameters used in a way that *might* mutate
-// the parameter, by storing it in an alloca. Mutations to "grid_constant" parameters
-// are undefined behaviour, and don't require local copies.
+// Create a local copy of kernel byval parameters used in a way that *might*
+// mutate the parameter, by storing it in an alloca. Mutations to
+// "grid_constant" parameters are undefined behaviour, and don't require
+// local copies.
 //
 // define void @foo(ptr byval(%struct.s) align 4 %input) {
 //   store i32 42, ptr %input
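(The lowered output for this case is elided from the diff; as a rough sketch of the alloca copy described above, with illustrative value names, not the pass's literal output:)

    // define void @foo(ptr byval(%struct.s) align 4 %input) {
    //   %input.copy = alloca %struct.s, align 4
    //   ; ...copy the parameter bytes out of param space into %input.copy...
    //   store i32 42, ptr %input.copy   ; the write now targets the local copy
    // }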
 //
 // define void @foo(ptr byval(%struct.s) %input) {
 //   %input1 = addrspacecast ptr %input to ptr addrspace(101)
-//   ; the following intrinsic converts pointer to generic. We don't use an addrspacecast
-//   ; to prevent generic -> param -> generic from getting cancelled out
-//   %input1.gen = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) %input1)
-//   %call = call i32 @escape(ptr %input1.gen)
-//   ret void
+//   ; the following intrinsic converts pointer to generic. We don't use an
+//   addrspacecast ; to prevent generic -> param -> generic from getting
+//   cancelled out %input1.gen = call ptr
+//   @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) %input1) %call =
+//   call i32 @escape(ptr %input1.gen) ret void
 // }
 //
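Why an intrinsic instead of a second addrspacecast: a cast round trip is something later passes may legally fold away. A minimal illustration in the same comment-IR style (hypothetical values):

    //   %p.param = addrspacecast ptr %input to ptr addrspace(101)
    //   %p.gen = addrspacecast ptr addrspace(101) %p.param to ptr
    //   ; optimizers may fold %p.gen straight back to %input, cancelling the
    //   ; generic -> param -> generic round trip the intrinsic keeps opaque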
 // TODO: merge this pass with NVPTXInferAddressSpaces so that other passes don't
@@ -475,13 +475,18 @@ struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
   }
 
   void visitStoreInst(StoreInst &SI) {
-    // Storing the pointer escapes it.
-    if (U->get() == SI.getValueOperand())
-      return PI.setEscapedAndAborted(&SI);
-    // Writes to the pointer are UB w/ __grid_constant__, but do not force a
-    // copy.
-    if (!IsGridConstant)
-      return PI.setAborted(&SI);
+    // Storing the pointer value (as opposed to storing through it) escapes it.
+    // For grid_constant params, this is allowed - we can pass the generic
+    // pointer. For non-grid-constant, this requires a copy.
+    if (U->get() == SI.getValueOperand()) {
+      if (IsGridConstant)
+        return PI.setEscaped(&SI);
+      else
+        return PI.setEscapedAndAborted(&SI);
+    }
+    // Writes through the pointer to param space are UB w/ __grid_constant__,
+    // and param space is read-only on CUDA, so we need to force a copy.
+    return PI.setAborted(&SI);
   }
 
   void visitAddrSpaceCastInst(AddrSpaceCastInst &ASC) {
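The two store shapes the checker now distinguishes, in the file header's comment-IR style (a sketch; %slot and the stored value are hypothetical):

    //   store ptr %input, ptr %slot   ; stores the pointer value itself: escape
    //   store i32 0, ptr %input       ; stores through the pointer: a write to
    //                                 ; param space, UB for __grid_constant__,
    //                                 ; forces a local copy otherwise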
@@ -529,10 +534,22 @@ void copyByValParam(Function &F, Argument &Arg) {
   // the use of the byval parameter with this alloca instruction.
   AllocA->setAlignment(
       Arg.getParamAlign().value_or(DL.getPrefTypeAlign(StructType)));
-  Arg.replaceAllUsesWith(AllocA);
 
+  // CRITICAL: Must create ArgInParam BEFORE replacing uses of Arg.
+  // createNVVMInternalAddrspaceWrap needs to use Arg as an operand.
   CallInst *ArgInParam = createNVVMInternalAddrspaceWrap(IRB, Arg);
 
+  // Replace all uses of Arg with AllocA, except the use in ArgInParam.
+  // We can't use replaceAllUsesWith because it would replace ArgInParam's
+  // operand too, creating a circular dependency.
+  SmallVector<Use *, 8> UsesToReplace;
+  for (Use &U : Arg.uses()) {
+    if (U.getUser() != ArgInParam)
+      UsesToReplace.push_back(&U);
+  }
+  for (Use *U : UsesToReplace)
+    U->set(AllocA);
+
   // Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
   // addrspacecast preserves alignment. Since params are constant, this load
   // is definitely not volatile.
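This collect-then-replace pattern appears twice in the patch (here and again in handleByValParam below). A minimal standalone restatement, assuming only core LLVM headers; the helper name is hypothetical, and in-tree code could equally use Value::replaceUsesWithIf with a predicate:

    #include "llvm/ADT/SmallVector.h"
    #include "llvm/IR/Value.h"
    using namespace llvm;

    // Replace every use of Old with New, except uses owned by Keeper.
    // The uses are collected first because Use::set() edits Old's use-list,
    // which would invalidate iteration over Old.uses().
    static void replaceUsesExcept(Value &Old, Value *New, const User *Keeper) {
      SmallVector<Use *, 8> UsesToReplace;
      for (Use &U : Old.uses())
        if (U.getUser() != Keeper)
          UsesToReplace.push_back(&U);
      for (Use *U : UsesToReplace)
        U->set(New);
    }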
@@ -578,10 +595,12 @@ static void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg) {
 
   // We can't access byval arg directly and need a pointer. on sm_70+ we have
   // ability to take a pointer to the argument without making a local copy.
-  // However, we're still not allowed to write to it. If the user specified
-  // `__grid_constant__` for the argument, we'll consider escaped pointer as
-  // read-only.
-  if (IsGridConstant || (HasCvtaParam && ArgUseIsReadOnly)) {
+  // However, param space is read-only, so we can only use this optimization
+  // if the argument is not written to.
+  // Grid constant params can escape (pointer passed to functions) but cannot
+  // have stores. Non-grid-constant params must be fully read-only.
+  if ((IsGridConstant && !PI.isAborted()) ||
+      (HasCvtaParam && ArgUseIsReadOnly)) {
     LLVM_DEBUG(dbgs() << "Using non-copy pointer to " << *Arg << "\n");
     // Replace all argument pointer uses (which might include a device function
     // call) with a cast to the generic address space using cvta.param
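A restatement of the new gating condition as a named predicate (illustrative only, not code from the patch):

    // Take the no-copy path when either:
    //   - the arg is __grid_constant__ and no use aborted the walk (escapes
    //     are fine; writes through the pointer would be UB anyway), or
    //   - cvta.param is available and every use is read-only.
    bool UseNonCopyPointer = (IsGridConstant && !PI.isAborted()) ||
                             (HasCvtaParam && ArgUseIsReadOnly);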
@@ -599,10 +618,15 @@ static void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg) {
         ParamSpaceArg, IRB.getPtrTy(ADDRESS_SPACE_GENERIC),
         Arg->getName() + ".gen");
 
-    Arg->replaceAllUsesWith(GenericArg);
-
-    // Do not replace Arg in the cast to param space
-    ParamSpaceArg->setOperand(0, Arg);
+    // Collect all uses of Arg except the one in ParamSpaceArg, then replace
+    // them
+    SmallVector<Use *> UsesToReplace;
+    for (Use &U : Arg->uses()) {
+      if (U.getUser() != ParamSpaceArg)
+        UsesToReplace.push_back(&U);
+    }
+    for (Use *U : UsesToReplace)
+      U->set(GenericArg);
   } else
     copyByValParam(*Func, *Arg);
 }
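Net effect of the no-copy path, in the header's comment-IR style (a sketch with illustrative names; the exact param-space cast is whatever the code above emits, and it deliberately keeps %input itself as its operand):

    // define void @foo(ptr byval(%struct.s) %input) {
    //   %input.param = ...cast of %input to ptr addrspace(101)...
    //   %input.gen = addrspacecast ptr addrspace(101) %input.param to ptr
    //   ; every other former use of %input is rewritten to %input.gen
    // }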