@@ -24,7 +24,7 @@ function elementwise_trinary!(
2424 opABC:: cutensorOperator_t ;
2525 workspace:: cutensorWorksizePreference_t = WORKSPACE_DEFAULT,
2626 algo:: cutensorAlgo_t = ALGO_DEFAULT,
27- compute_type:: Union{Type, cutensorComputeDescriptor_t , Nothing} = nothing ,
27+ compute_type:: Union{DataType, cutensorComputeDescriptorEnum , Nothing} = nothing ,
2828 plan:: Union{CuTensorPlan, Nothing} = nothing )
2929
3030 actual_compute_type = if compute_type === nothing
@@ -66,7 +66,7 @@ function plan_elementwise_trinary(
6666 jit:: cutensorJitMode_t = JIT_MODE_NONE,
6767 workspace:: cutensorWorksizePreference_t = WORKSPACE_DEFAULT,
6868 algo:: cutensorAlgo_t = ALGO_DEFAULT,
69- compute_type:: Union{Type, cutensorComputeDescriptor_t , Nothing} = nothing )
69+ compute_type:: Union{DataType, cutensorComputeDescriptorEnum , Nothing} = nothing )
7070 ! is_unary (opA) && throw (ArgumentError (" opA must be a unary op!" ))
7171 ! is_unary (opB) && throw (ArgumentError (" opB must be a unary op!" ))
7272 ! is_unary (opC) && throw (ArgumentError (" opC must be a unary op!" ))
@@ -96,7 +96,7 @@ function plan_elementwise_trinary(
9696 descC, modeC, opC,
9797 descD, modeD,
9898 opAB, opABC,
99- compute_type )
99+ actual_compute_type )
100100
101101 plan_pref = Ref {cutensorPlanPreference_t} ()
102102 cutensorCreatePlanPreference (handle (), plan_pref, algo, jit)
@@ -112,7 +112,7 @@ function elementwise_binary!(
112112 @nospecialize (D:: DenseCuArray ), Dinds:: ModeType , opAC:: cutensorOperator_t ;
113113 workspace:: cutensorWorksizePreference_t = WORKSPACE_DEFAULT,
114114 algo:: cutensorAlgo_t = ALGO_DEFAULT,
115- compute_type:: Union{Type, cutensorComputeDescriptor_t , Nothing} = nothing ,
115+ compute_type:: Union{DataType, cutensorComputeDescriptorEnum , Nothing} = nothing ,
116116 plan:: Union{CuTensorPlan, Nothing} = nothing )
117117
118118 actual_compute_type = if compute_type === nothing
@@ -150,7 +150,7 @@ function plan_elementwise_binary(
150150 jit:: cutensorJitMode_t = JIT_MODE_NONE,
151151 workspace:: cutensorWorksizePreference_t = WORKSPACE_DEFAULT,
152152 algo:: cutensorAlgo_t = ALGO_DEFAULT,
153- compute_type:: Union{Type, cutensorComputeDescriptor_t , Nothing} = eltype (C))
153+ compute_type:: Union{DataType, cutensorComputeDescriptorEnum , Nothing} = eltype (C))
154154 ! is_unary (opA) && throw (ArgumentError (" opA must be a unary op!" ))
155155 ! is_unary (opC) && throw (ArgumentError (" opC must be a unary op!" ))
156156 ! is_binary (opAC) && throw (ArgumentError (" opAC must be a binary op!" ))
@@ -189,7 +189,7 @@ function permutation!(
189189 @nospecialize (B:: DenseCuArray ), Binds:: ModeType ;
190190 workspace:: cutensorWorksizePreference_t = WORKSPACE_DEFAULT,
191191 algo:: cutensorAlgo_t = ALGO_DEFAULT,
192- compute_type:: Union{Type, cutensorComputeDescriptor_t , Nothing} = nothing ,
192+ compute_type:: Union{DataType, cutensorComputeDescriptorEnum , Nothing} = nothing ,
193193 plan:: Union{CuTensorPlan, Nothing} = nothing )
194194
195195 actual_compute_type = if compute_type === nothing
@@ -224,8 +224,7 @@ function plan_permutation(
224224 jit:: cutensorJitMode_t = JIT_MODE_NONE,
225225 workspace:: cutensorWorksizePreference_t = WORKSPACE_DEFAULT,
226226 algo:: cutensorAlgo_t = ALGO_DEFAULT,
227- compute_type:: Union{Type, cutensorComputeDescriptor_t, Nothing} = nothing )
228- # !is_unary(opPsi) && throw(ArgumentError("opPsi must be a unary op!"))
227+ compute_type:: Union{DataType, cutensorComputeDescriptorEnum, Nothing} = nothing )
229228 descA = CuTensorDescriptor (A)
230229 descB = CuTensorDescriptor (B)
231230
@@ -260,7 +259,7 @@ function contraction!(
260259 jit:: cutensorJitMode_t = JIT_MODE_NONE,
261260 workspace:: cutensorWorksizePreference_t = WORKSPACE_DEFAULT,
262261 algo:: cutensorAlgo_t = ALGO_DEFAULT,
263- compute_type:: Union{Type, cutensorComputeDescriptor_t , Nothing} = nothing ,
262+ compute_type:: Union{DataType, cutensorComputeDescriptorEnum , Nothing} = nothing ,
264263 plan:: Union{CuTensorPlan, Nothing} = nothing )
265264
266265 actual_compute_type = if compute_type === nothing
@@ -269,7 +268,6 @@ function contraction!(
269268 compute_type
270269 end
271270
272- # XXX : save these as parameters of the plan?
273271 actual_plan = if plan === nothing
274272 plan_contraction (A, Ainds, opA, B, Binds, opB, C, Cinds, opC, opOut;
275273 jit, workspace, algo, compute_type= actual_compute_type)
@@ -298,7 +296,7 @@ function plan_contraction(
298296 jit:: cutensorJitMode_t = JIT_MODE_NONE,
299297 workspace:: cutensorWorksizePreference_t = WORKSPACE_DEFAULT,
300298 algo:: cutensorAlgo_t = ALGO_DEFAULT,
301- compute_type:: Union{Type, cutensorComputeDescriptor_t , Nothing} = nothing )
299+ compute_type:: Union{DataType, cutensorComputeDescriptorEnum , Nothing} = nothing )
302300 ! is_unary (opA) && throw (ArgumentError (" opA must be a unary op!" ))
303301 ! is_unary (opB) && throw (ArgumentError (" opB must be a unary op!" ))
304302 ! is_unary (opC) && throw (ArgumentError (" opC must be a unary op!" ))
@@ -340,7 +338,7 @@ function reduction!(
340338 opReduce:: cutensorOperator_t ;
341339 workspace:: cutensorWorksizePreference_t = WORKSPACE_DEFAULT,
342340 algo:: cutensorAlgo_t = ALGO_DEFAULT,
343- compute_type:: Union{Type, cutensorComputeDescriptor_t , Nothing} = nothing ,
341+ compute_type:: Union{DataType, cutensorComputeDescriptorEnum , Nothing} = nothing ,
344342 plan:: Union{CuTensorPlan, Nothing} = nothing )
345343
346344 actual_compute_type = if compute_type === nothing
@@ -375,7 +373,7 @@ function plan_reduction(
375373 jit:: cutensorJitMode_t = JIT_MODE_NONE,
376374 workspace:: cutensorWorksizePreference_t = WORKSPACE_DEFAULT,
377375 algo:: cutensorAlgo_t = ALGO_DEFAULT,
378- compute_type:: Union{Type, cutensorComputeDescriptor_t , Nothing} = nothing )
376+ compute_type:: Union{DataType, cutensorComputeDescriptorEnum , Nothing} = nothing )
379377 ! is_unary (opA) && throw (ArgumentError (" opA must be a unary op!" ))
380378 ! is_unary (opC) && throw (ArgumentError (" opC must be a unary op!" ))
381379 ! is_binary (opReduce) && throw (ArgumentError (" opReduce must be a binary op!" ))
0 commit comments