@@ -127,14 +127,35 @@ being slightly faster.
127127"""
128128cudacall
129129
130- # FIXME : can we make this infer properly?
131- cudacall (f, types:: Tuple , args... ; kwargs... ) =
132- cudacall (f, Base. to_tuple_type (types), args... ; kwargs... )
130+ cudacall (f:: F , types:: Tuple , args:: Vararg{Any,N} ; kwargs... ) where {N,F} =
131+ cudacall (f, _to_tuple_type (types), args... ; kwargs... )
132+
133+ function cudacall (f:: F , types:: Type{T} , args:: Vararg{Any,N} ; kwargs... ) where {T,N,F}
134+ convert_arguments (
135+ ((pointers:: Vararg{Any,M} ,) where {M}) -> launch (f, pointers... ; kwargs... ),
136+ types,
137+ args...
138+ )
139+ end
133140
134- function cudacall (f, types:: Type , args... ; kwargs... )
135- convert_arguments (types, args... ) do pointers...
136- launch (f, pointers... ; kwargs... )
141+ # From `julia/base/reflection.jl`, adjusted to add specialization on `t`.
142+ function _to_tuple_type (t)
143+ if isa (t, Tuple) || isa (t, AbstractArray) || isa (t, SimpleVector)
144+ t = Tuple{t... }
145+ end
146+ if isa (t, Type) && t <: Tuple
147+ for p in (Base. unwrap_unionall (t):: DataType ). parameters
148+ if isa (p, Core. TypeofVararg)
149+ p = Base. unwrapva (p)
150+ end
151+ if ! (isa (p, Type) || isa (p, TypeVar))
152+ error (" argument tuple type must contain only types" )
153+ end
154+ end
155+ else
156+ error (" expected tuple type" )
137157 end
158+ t
138159end
139160
140161
0 commit comments