Skip to content

Commit a5ab106

Browse files
committed
TODO [SYCL] Fix circular dependency in byval parameter handling in NVPTXLowerArgs
1 parent c3cdb9f commit a5ab106

File tree

2 files changed

+52
-31
lines changed

2 files changed

+52
-31
lines changed

llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp

Lines changed: 52 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,15 @@
7979
//
8080
// define void @foo({i32*, i32*}* byval %input) {
8181
// %b_param = addrspacecat ptr %input to ptr addrspace(101)
82-
// %b_ptr = getelementptr {ptr, ptr}, ptr addrspace(101) %b_param, i64 0, i32 1
83-
// %b = load ptr, ptr addrspace(101) %b_ptr
84-
// %b_global = addrspacecast ptr %b to ptr addrspace(1)
85-
// ; use %b_generic
82+
// %b_ptr = getelementptr {ptr, ptr}, ptr addrspace(101) %b_param, i64 0,
83+
// i32 1 %b = load ptr, ptr addrspace(101) %b_ptr %b_global = addrspacecast
84+
// ptr %b to ptr addrspace(1) ; use %b_generic
8685
// }
8786
//
88-
// Create a local copy of kernel byval parameters used in a way that *might* mutate
89-
// the parameter, by storing it in an alloca. Mutations to "grid_constant" parameters
90-
// are undefined behaviour, and don't require local copies.
87+
// Create a local copy of kernel byval parameters used in a way that *might*
88+
// mutate the parameter, by storing it in an alloca. Mutations to
89+
// "grid_constant" parameters are undefined behaviour, and don't require
90+
// local copies.
9191
//
9292
// define void @foo(ptr byval(%struct.s) align 4 %input) {
9393
// store i32 42, ptr %input
@@ -124,11 +124,11 @@
124124
//
125125
// define void @foo(ptr byval(%struct.s) %input) {
126126
// %input1 = addrspacecast ptr %input to ptr addrspace(101)
127-
// ; the following intrinsic converts pointer to generic. We don't use an addrspacecast
128-
// ; to prevent generic -> param -> generic from getting cancelled out
129-
// %input1.gen = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) %input1)
130-
// %call = call i32 @escape(ptr %input1.gen)
131-
// ret void
127+
// ; the following intrinsic converts pointer to generic. We don't use an
128+
// addrspacecast ; to prevent generic -> param -> generic from getting
129+
// cancelled out %input1.gen = call ptr
130+
// @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) %input1) %call =
131+
// call i32 @escape(ptr %input1.gen) ret void
132132
// }
133133
//
134134
// TODO: merge this pass with NVPTXInferAddressSpaces so that other passes don't
@@ -475,13 +475,18 @@ struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
475475
}
476476

477477
void visitStoreInst(StoreInst &SI) {
478-
// Storing the pointer escapes it.
479-
if (U->get() == SI.getValueOperand())
480-
return PI.setEscapedAndAborted(&SI);
481-
// Writes to the pointer are UB w/ __grid_constant__, but do not force a
482-
// copy.
483-
if (!IsGridConstant)
484-
return PI.setAborted(&SI);
478+
// Storing the pointer value (as opposed to storing through it) escapes it.
479+
// For grid_constant params, this is allowed - we can pass the generic
480+
// pointer. For non-grid-constant, this requires a copy.
481+
if (U->get() == SI.getValueOperand()) {
482+
if (IsGridConstant)
483+
return PI.setEscaped(&SI);
484+
else
485+
return PI.setEscapedAndAborted(&SI);
486+
}
487+
// Writes through the pointer to param space are UB w/ __grid_constant__,
488+
// and param space is read-only on CUDA, so we need to force a copy.
489+
return PI.setAborted(&SI);
485490
}
486491

487492
void visitAddrSpaceCastInst(AddrSpaceCastInst &ASC) {
@@ -529,10 +534,22 @@ void copyByValParam(Function &F, Argument &Arg) {
529534
// the use of the byval parameter with this alloca instruction.
530535
AllocA->setAlignment(
531536
Arg.getParamAlign().value_or(DL.getPrefTypeAlign(StructType)));
532-
Arg.replaceAllUsesWith(AllocA);
533537

538+
// CRITICAL: Must create ArgInParam BEFORE replacing uses of Arg.
539+
// createNVVMInternalAddrspaceWrap needs to use Arg as an operand.
534540
CallInst *ArgInParam = createNVVMInternalAddrspaceWrap(IRB, Arg);
535541

542+
// Replace all uses of Arg with AllocA, except the use in ArgInParam.
543+
// We can't use replaceAllUsesWith because it would replace ArgInParam's
544+
// operand too, creating a circular dependency.
545+
SmallVector<Use *, 8> UsesToReplace;
546+
for (Use &U : Arg.uses()) {
547+
if (U.getUser() != ArgInParam)
548+
UsesToReplace.push_back(&U);
549+
}
550+
for (Use *U : UsesToReplace)
551+
U->set(AllocA);
552+
536553
// Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
537554
// addrspacecast preserves alignment. Since params are constant, this load
538555
// is definitely not volatile.
@@ -578,10 +595,12 @@ static void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg) {
578595

579596
// We can't access byval arg directly and need a pointer. on sm_70+ we have
580597
// ability to take a pointer to the argument without making a local copy.
581-
// However, we're still not allowed to write to it. If the user specified
582-
// `__grid_constant__` for the argument, we'll consider escaped pointer as
583-
// read-only.
584-
if (IsGridConstant || (HasCvtaParam && ArgUseIsReadOnly)) {
598+
// However, param space is read-only, so we can only use this optimization
599+
// if the argument is not written to.
600+
// Grid constant params can escape (pointer passed to functions) but cannot
601+
// have stores. Non-grid-constant params must be fully read-only.
602+
if ((IsGridConstant && !PI.isAborted()) ||
603+
(HasCvtaParam && ArgUseIsReadOnly)) {
585604
LLVM_DEBUG(dbgs() << "Using non-copy pointer to " << *Arg << "\n");
586605
// Replace all argument pointer uses (which might include a device function
587606
// call) with a cast to the generic address space using cvta.param
@@ -599,10 +618,15 @@ static void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg) {
599618
ParamSpaceArg, IRB.getPtrTy(ADDRESS_SPACE_GENERIC),
600619
Arg->getName() + ".gen");
601620

602-
Arg->replaceAllUsesWith(GenericArg);
603-
604-
// Do not replace Arg in the cast to param space
605-
ParamSpaceArg->setOperand(0, Arg);
621+
// Collect all uses of Arg except the one in ParamSpaceArg, then replace
622+
// them
623+
SmallVector<Use *> UsesToReplace;
624+
for (Use &U : Arg->uses()) {
625+
if (U.getUser() != ParamSpaceArg)
626+
UsesToReplace.push_back(&U);
627+
}
628+
for (Use *U : UsesToReplace)
629+
U->set(GenericArg);
606630
} else
607631
copyByValParam(*Func, *Arg);
608632
}

sycl/test-e2e/FreeFunctionKernels/structs_with_special_types_as_kernel_paramters.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
// RUN: %{build} -o %t.out
22
// RUN: %{run} %t.out
33

4-
// XFAIL: target-nvidia
5-
// XFAIL-TRACKER: https://github.com/intel/llvm/issues/20908
6-
74
// This test verifies whether struct that contains either sycl::local_accesor or
85
// sycl::accessor can be used with free function kernels extension.
96

0 commit comments

Comments
 (0)