Skip to content

Commit f1c3201

Browse files
committed
TODO [SYCL] Fix circular dependency in byval parameter handling in NVPTXLowerArgs
1 parent c3cdb9f commit f1c3201

File tree

2 files changed

+27
-16
lines changed

2 files changed

+27
-16
lines changed

llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -478,10 +478,9 @@ struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
478478
// Storing the pointer escapes it.
479479
if (U->get() == SI.getValueOperand())
480480
return PI.setEscapedAndAborted(&SI);
481-
// Writes to the pointer are UB w/ __grid_constant__, but do not force a
482-
// copy.
483-
if (!IsGridConstant)
484-
return PI.setAborted(&SI);
481+
// Writes to the pointer are UB w/ __grid_constant__, but param space is
482+
// read-only on CUDA so we need to force a copy.
483+
return PI.setAborted(&SI);
485484
}
486485

487486
void visitAddrSpaceCastInst(AddrSpaceCastInst &ASC) {
@@ -529,10 +528,22 @@ void copyByValParam(Function &F, Argument &Arg) {
529528
// the use of the byval parameter with this alloca instruction.
530529
AllocA->setAlignment(
531530
Arg.getParamAlign().value_or(DL.getPrefTypeAlign(StructType)));
532-
Arg.replaceAllUsesWith(AllocA);
533531

532+
// CRITICAL: Must create ArgInParam BEFORE replacing uses of Arg.
533+
// createNVVMInternalAddrspaceWrap needs to use Arg as an operand.
534534
CallInst *ArgInParam = createNVVMInternalAddrspaceWrap(IRB, Arg);
535535

536+
// Replace all uses of Arg with AllocA, except the use in ArgInParam.
537+
// We can't use replaceAllUsesWith because it would replace ArgInParam's
538+
// operand too, creating a circular dependency.
539+
SmallVector<Use *, 8> UsesToReplace;
540+
for (Use &U : Arg.uses()) {
541+
if (U.getUser() != ArgInParam)
542+
UsesToReplace.push_back(&U);
543+
}
544+
for (Use *U : UsesToReplace)
545+
U->set(AllocA);
546+
536547
// Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
537548
// addrspacecast preserves alignment. Since params are constant, this load
538549
// is definitely not volatile.
@@ -578,10 +589,9 @@ static void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg) {
578589

579590
// We can't access byval arg directly and need a pointer. On sm_70+ we have
580591
// ability to take a pointer to the argument without making a local copy.
581-
// However, we're still not allowed to write to it. If the user specified
582-
// `__grid_constant__` for the argument, we'll consider escaped pointer as
583-
// read-only.
584-
if (IsGridConstant || (HasCvtaParam && ArgUseIsReadOnly)) {
592+
// However, param space is read-only, so we can only use this optimization
593+
// if the argument is not written to.
594+
if ((IsGridConstant || HasCvtaParam) && ArgUseIsReadOnly) {
585595
LLVM_DEBUG(dbgs() << "Using non-copy pointer to " << *Arg << "\n");
586596
// Replace all argument pointer uses (which might include a device function
587597
// call) with a cast to the generic address space using cvta.param
@@ -599,10 +609,14 @@ static void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg) {
599609
ParamSpaceArg, IRB.getPtrTy(ADDRESS_SPACE_GENERIC),
600610
Arg->getName() + ".gen");
601611

602-
Arg->replaceAllUsesWith(GenericArg);
603-
604-
// Do not replace Arg in the cast to param space
605-
ParamSpaceArg->setOperand(0, Arg);
612+
// Collect all uses of Arg except the one in ParamSpaceArg, then replace them
613+
SmallVector<Use *> UsesToReplace;
614+
for (Use &U : Arg->uses()) {
615+
if (U.getUser() != ParamSpaceArg)
616+
UsesToReplace.push_back(&U);
617+
}
618+
for (Use *U : UsesToReplace)
619+
U->set(GenericArg);
606620
} else
607621
copyByValParam(*Func, *Arg);
608622
}

sycl/test-e2e/FreeFunctionKernels/structs_with_special_types_as_kernel_paramters.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
// RUN: %{build} -o %t.out
22
// RUN: %{run} %t.out
33

4-
// XFAIL: target-nvidia
5-
// XFAIL-TRACKER: https://github.com/intel/llvm/issues/20908
6-
74
// This test verifies whether a struct that contains either sycl::local_accessor or
85
// sycl::accessor can be used with the free function kernels extension.
96

0 commit comments

Comments
 (0)