Skip to content

Commit 0605577

Browse files
committed
[NVPTX] fix ptr escapes and byval params in kernel args
1 parent c3cdb9f commit 0605577

File tree

1 file changed

+31
-12
lines changed

1 file changed

+31
-12
lines changed

llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp

Lines changed: 31 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -475,13 +475,18 @@ struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
475475
}
476476

477477
void visitStoreInst(StoreInst &SI) {
478-
// Storing the pointer escapes it.
479-
if (U->get() == SI.getValueOperand())
480-
return PI.setEscapedAndAborted(&SI);
481-
// Writes to the pointer are UB w/ __grid_constant__, but do not force a
482-
// copy.
483-
if (!IsGridConstant)
484-
return PI.setAborted(&SI);
478+
// Storing the pointer value (as opposed to storing through it) escapes it.
479+
// For grid_constant params, this is allowed - we can pass the generic
480+
// pointer. For non-grid-constant, this requires a copy.
481+
if (U->get() == SI.getValueOperand()) {
482+
if (IsGridConstant)
483+
return PI.setEscaped(&SI);
484+
else
485+
return PI.setEscapedAndAborted(&SI);
486+
}
487+
// Writes through the pointer to param space are UB w/ __grid_constant__,
488+
// and param space is read-only on CUDA, so we need to force a copy.
489+
return PI.setAborted(&SI);
485490
}
486491

487492
void visitAddrSpaceCastInst(AddrSpaceCastInst &ASC) {
@@ -529,10 +534,22 @@ void copyByValParam(Function &F, Argument &Arg) {
529534
// the use of the byval parameter with this alloca instruction.
530535
AllocA->setAlignment(
531536
Arg.getParamAlign().value_or(DL.getPrefTypeAlign(StructType)));
532-
Arg.replaceAllUsesWith(AllocA);
533537

538+
// Must create ArgInParam before replacing uses of Arg.
539+
// createNVVMInternalAddrspaceWrap needs to use Arg as an operand.
534540
CallInst *ArgInParam = createNVVMInternalAddrspaceWrap(IRB, Arg);
535541

542+
// Replace all uses of Arg with AllocA, except the use in ArgInParam.
543+
// Note: we can't use replaceAllUsesWith because it would replace ArgInParam's
544+
// operand too, creating a circular dependency.
545+
SmallVector<Use *, 8> UsesToReplace;
546+
for (Use &U : Arg.uses()) {
547+
if (U.getUser() != ArgInParam)
548+
UsesToReplace.push_back(&U);
549+
}
550+
for (Use *U : UsesToReplace)
551+
U->set(AllocA);
552+
536553
// Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
537554
// addrspacecast preserves alignment. Since params are constant, this load
538555
// is definitely not volatile.
@@ -578,10 +595,12 @@ static void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg) {
578595

579596
// We can't access byval arg directly and need a pointer. On sm_70+ we have
580597
// the ability to take a pointer to the argument without making a local copy.
581-
// However, we're still not allowed to write to it. If the user specified
582-
// `__grid_constant__` for the argument, we'll consider escaped pointer as
583-
// read-only.
584-
if (IsGridConstant || (HasCvtaParam && ArgUseIsReadOnly)) {
598+
// However, param space is read-only, so we can only use this optimization
599+
// if the argument is not written to.
600+
// Grid constant params can escape (pointer passed to functions) but cannot
601+
// have stores. Non-grid-constant params must be fully read-only.
602+
if ((IsGridConstant && !PI.isAborted()) ||
603+
(HasCvtaParam && ArgUseIsReadOnly)) {
585604
LLVM_DEBUG(dbgs() << "Using non-copy pointer to " << *Arg << "\n");
586605
// Replace all argument pointer uses (which might include a device function
587606
// call) with a cast to the generic address space using cvta.param

0 commit comments

Comments
 (0)