1- export elementwise_binary!, elementwise_trinary!,
2- permutation!, contraction!, reduction!
3-
41const ModeType = AbstractVector{<: Union{Char, Integer} }
52
63# remove the CUTENSOR_ prefix from some common enums,
@@ -13,7 +10,7 @@ const ModeType = AbstractVector{<:Union{Char, Integer}}
1310is_unary (op:: cutensorOperator_t ) = (op ∈ (OP_IDENTITY, OP_SQRT, OP_RELU, OP_CONJ, OP_RCP))
1411is_binary (op:: cutensorOperator_t ) = (op ∈ (OP_ADD, OP_MUL, OP_MAX, OP_MIN))
1512
16- function elementwise_trinary ! (
13+ function elementwise_trinary_execute ! (
1714 @nospecialize (alpha:: Number ),
1815 @nospecialize (A:: DenseCuArray ), Ainds:: ModeType , opA:: cutensorOperator_t ,
1916 @nospecialize (beta:: Number ),
@@ -43,12 +40,7 @@ function elementwise_trinary!(
4340 plan
4441 end
4542
46- scalar_type = actual_plan. scalar_type
47- cutensorElementwiseTrinaryExecute (handle (), actual_plan,
48- Ref {scalar_type} (alpha), A,
49- Ref {scalar_type} (beta), B,
50- Ref {scalar_type} (gamma), C, D,
51- stream ())
43+ elementwise_trinary_execute! (actual_plan, alpha, A, beta, B, gamma, C, D)
5244
5345 if plan === nothing
5446 CUDA. unsafe_free! (actual_plan)
@@ -57,6 +49,23 @@ function elementwise_trinary!(
5749 return D
5850end
5951
52+ function elementwise_trinary_execute! (plan:: CuTensorPlan ,
53+ @nospecialize (alpha:: Number ),
54+ @nospecialize (A:: DenseCuArray ),
55+ @nospecialize (beta:: Number ),
56+ @nospecialize (B:: DenseCuArray ),
57+ @nospecialize (gamma:: Number ),
58+ @nospecialize (C:: DenseCuArray ),
59+ @nospecialize (D:: DenseCuArray ))
60+ scalar_type = plan. scalar_type
61+ cutensorElementwiseTrinaryExecute (handle (), plan,
62+ Ref {scalar_type} (alpha), A,
63+ Ref {scalar_type} (beta), B,
64+ Ref {scalar_type} (gamma), C, D,
65+ stream ())
66+ return D
67+ end
68+
6069function plan_elementwise_trinary (
6170 @nospecialize (A:: DenseCuArray ), Ainds:: ModeType , opA:: cutensorOperator_t ,
6271 @nospecialize (B:: DenseCuArray ), Binds:: ModeType , opB:: cutensorOperator_t ,
@@ -104,7 +113,7 @@ function plan_elementwise_trinary(
104113 CuTensorPlan (desc[], plan_pref[]; workspacePref= workspace)
105114end
106115
107- function elementwise_binary ! (
116+ function elementwise_binary_execute ! (
108117 @nospecialize (alpha:: Number ),
109118 @nospecialize (A:: DenseCuArray ), Ainds:: ModeType , opA:: cutensorOperator_t ,
110119 @nospecialize (gamma:: Number ),
@@ -130,11 +139,7 @@ function elementwise_binary!(
130139 plan
131140 end
132141
133- scalar_type = actual_plan. scalar_type
134- cutensorElementwiseBinaryExecute (handle (), actual_plan,
135- Ref {scalar_type} (alpha), A,
136- Ref {scalar_type} (gamma), C, D,
137- stream ())
142+ elementwise_binary_execute! (actual_plan, alpha, A, gamma, C, D)
138143
139144 if plan === nothing
140145 CUDA. unsafe_free! (actual_plan)
@@ -143,6 +148,20 @@ function elementwise_binary!(
143148 return D
144149end
145150
151+ function elementwise_binary_execute! (plan:: CuTensorPlan ,
152+ @nospecialize (alpha:: Number ),
153+ @nospecialize (A:: DenseCuArray ),
154+ @nospecialize (gamma:: Number ),
155+ @nospecialize (C:: DenseCuArray ),
156+ @nospecialize (D:: DenseCuArray ))
157+ scalar_type = plan. scalar_type
158+ cutensorElementwiseBinaryExecute (handle (), plan,
159+ Ref {scalar_type} (alpha), A,
160+ Ref {scalar_type} (gamma), C, D,
161+ stream ())
162+ return D
163+ end
164+
146165function plan_elementwise_binary (
147166 @nospecialize (A:: DenseCuArray ), Ainds:: ModeType , opA:: cutensorOperator_t ,
148167 @nospecialize (C:: DenseCuArray ), Cinds:: ModeType , opC:: cutensorOperator_t ,
@@ -183,7 +202,7 @@ function plan_elementwise_binary(
183202 CuTensorPlan (desc[], plan_pref[]; workspacePref= workspace)
184203end
185204
186- function permutation ! (
205+ function permute ! (
187206 @nospecialize (alpha:: Number ),
188207 @nospecialize (A:: DenseCuArray ), Ainds:: ModeType , opA:: cutensorOperator_t ,
189208 @nospecialize (B:: DenseCuArray ), Binds:: ModeType ;
@@ -206,10 +225,7 @@ function permutation!(
206225 plan
207226 end
208227
209- scalar_type = actual_plan. scalar_type
210- cutensorPermute (handle (), actual_plan,
211- Ref {scalar_type} (alpha), A, B,
212- stream ())
228+ permute! (actual_plan, alpha, A, B)
213229
214230 if plan === nothing
215231 CUDA. unsafe_free! (actual_plan)
@@ -218,6 +234,17 @@ function permutation!(
218234 return B
219235end
220236
237+ function permute! (plan:: CuTensorPlan ,
238+ @nospecialize (alpha:: Number ),
239+ @nospecialize (A:: DenseCuArray ),
240+ @nospecialize (B:: DenseCuArray ))
241+ scalar_type = plan. scalar_type
242+ cutensorPermute (handle (), plan,
243+ Ref {scalar_type} (alpha), A, B,
244+ stream ())
245+ return B
246+ end
247+
221248function plan_permutation (
222249 @nospecialize (A:: DenseCuArray ), Ainds:: ModeType , opA:: cutensorOperator_t ,
223250 @nospecialize (B:: DenseCuArray ), Binds:: ModeType ;
@@ -249,7 +276,7 @@ function plan_permutation(
249276 CuTensorPlan (desc[], plan_pref[]; workspacePref= workspace)
250277end
251278
252- function contraction ! (
279+ function contract ! (
253280 @nospecialize (alpha:: Number ),
254281 @nospecialize (A:: DenseCuArray ), Ainds:: ModeType , opA:: cutensorOperator_t ,
255282 @nospecialize (B:: DenseCuArray ), Binds:: ModeType , opB:: cutensorOperator_t ,
@@ -275,11 +302,7 @@ function contraction!(
275302 plan
276303 end
277304
278- scalar_type = actual_plan. scalar_type
279- cutensorContract (handle (), actual_plan,
280- Ref {scalar_type} (alpha), A, B,
281- Ref {scalar_type} (beta), C, C,
282- actual_plan. workspace, sizeof (actual_plan. workspace), stream ())
305+ contract! (actual_plan, alpha, A, B, beta, C)
283306
284307 if plan === nothing
285308 CUDA. unsafe_free! (actual_plan)
@@ -288,6 +311,20 @@ function contraction!(
288311 return C
289312end
290313
314+ function contract! (plan:: CuTensorPlan ,
315+ @nospecialize (alpha:: Number ),
316+ @nospecialize (A:: DenseCuArray ),
317+ @nospecialize (B:: DenseCuArray ),
318+ @nospecialize (beta:: Number ),
319+ @nospecialize (C:: DenseCuArray ))
320+ scalar_type = plan. scalar_type
321+ cutensorContract (handle (), plan,
322+ Ref {scalar_type} (alpha), A, B,
323+ Ref {scalar_type} (beta), C, C,
324+ plan. workspace, sizeof (plan. workspace), stream ())
325+ return C
326+ end
327+
291328function plan_contraction (
292329 @nospecialize (A:: DenseCuArray ), Ainds:: ModeType , opA:: cutensorOperator_t ,
293330 @nospecialize (B:: DenseCuArray ), Binds:: ModeType , opB:: cutensorOperator_t ,
@@ -330,7 +367,7 @@ function plan_contraction(
330367 CuTensorPlan (desc[], plan_pref[]; workspacePref= workspace)
331368end
332369
333- function reduction ! (
370+ function reduce ! (
334371 @nospecialize (alpha:: Number ),
335372 @nospecialize (A:: DenseCuArray ), Ainds:: ModeType , opA:: cutensorOperator_t ,
336373 @nospecialize (beta:: Number ),
@@ -353,11 +390,7 @@ function reduction!(
353390 plan
354391 end
355392
356- scalar_type = actual_plan. scalar_type
357- cutensorReduce (handle (), actual_plan,
358- Ref {scalar_type} (alpha), A,
359- Ref {scalar_type} (beta), C, C,
360- actual_plan. workspace, sizeof (actual_plan. workspace), stream ())
393+ reduce! (actual_plan, alpha, A, beta, C)
361394
362395 if plan === nothing
363396 CUDA. unsafe_free! (actual_plan)
@@ -366,6 +399,19 @@ function reduction!(
366399 return C
367400end
368401
402+ function reduce! (plan:: CuTensorPlan ,
403+ @nospecialize (alpha:: Number ),
404+ @nospecialize (A:: DenseCuArray ),
405+ @nospecialize (beta:: Number ),
406+ @nospecialize (C:: DenseCuArray ))
407+ scalar_type = plan. scalar_type
408+ cutensorReduce (handle (), plan,
409+ Ref {scalar_type} (alpha), A,
410+ Ref {scalar_type} (beta), C, C,
411+ plan. workspace, sizeof (plan. workspace), stream ())
412+ return C
413+ end
414+
369415function plan_reduction (
370416 @nospecialize (A:: DenseCuArray ), Ainds:: ModeType , opA:: cutensorOperator_t ,
371417 @nospecialize (C:: DenseCuArray ), Cinds:: ModeType , opC:: cutensorOperator_t ,
0 commit comments