@@ -220,16 +220,22 @@ public override void Execute(string functionName, params object[] args)
220220 var ndobject = ( Array ) args . FirstOrDefault ( x => ( x . GetType ( ) . IsArray ) ) ;
221221 long length = ndobject != null ? ndobject . Length : 1 ;
222222
223- var buffers = BuildKernelArguments ( args , kernel , length ) ;
223+ var method = KernelFunctions . FirstOrDefault ( x => ( x . Name == functionName ) ) ;
224+
225+ var buffers = BuildKernelArguments ( method , args , kernel , length ) ;
224226 commands . Execute ( kernel , null , new long [ ] { length } , null , null ) ;
225227
226- for ( int i = 0 ; i < args . Length ; i ++ )
228+ for ( int i = 0 ; i < args . Length ; i ++ )
227229 {
228230 if ( ! args [ i ] . GetType ( ) . IsArray )
229231 continue ;
230232
231- Array r = ( Array ) args [ i ] ;
232- commands . ReadFromMemory ( buffers [ i ] , ref r , true , 0 , null ) ;
233+ var ioMode = method . Parameters . ElementAt ( i ) . Value . IOMode ;
234+ if ( ioMode == IOMode . InOut || ioMode == IOMode . Out )
235+ {
236+ Array r = ( Array ) args [ i ] ;
237+ commands . ReadFromMemory ( buffers [ i ] , ref r , true , 0 , null ) ;
238+ }
233239 buffers [ i ] . Dispose ( ) ;
234240 }
235241 }
@@ -277,11 +283,11 @@ private void ValidateArgs(string functionName, object[] args)
277283
278284 if ( args [ i ] . GetType ( ) . IsPrimitive )
279285 {
280- args [ i ] = Convert . ChangeType ( args [ i ] , Type . GetType ( parameter . Value ) ) ;
286+ args [ i ] = Convert . ChangeType ( args [ i ] , Type . GetType ( parameter . Value . TypeName ) ) ;
281287 }
282288 else if ( args [ i ] . GetType ( ) . IsArray )
283289 {
284- if ( parameter . Value != args [ i ] . GetType ( ) . FullName )
290+ if ( parameter . Value . TypeName != args [ i ] . GetType ( ) . FullName )
285291 throw new ExecutionException ( string . Format ( "Data type mismatch for parameter {0}. Expected is {1} but got {2}" ,
286292 parameter . Key ,
287293 ( parameter . Value ,
@@ -417,7 +423,20 @@ CSharpDecompiler cSharpDecompiler
417423 var k = new KernelFunction ( ) { Name = item . Name } ;
418424 foreach ( var p in item . Parameters )
419425 {
420- k . Parameters . Add ( p . Name , p . Type . FullName ) ;
426+ var isInput = p . GetAttributes ( ) . Any ( x => x . AttributeType . Name == "InputAttribute" ) ;
427+ var isOutput = p . GetAttributes ( ) . Any ( x => x . AttributeType . Name == "OutputAttribute" ) ;
428+ var mode = IOMode . InOut ;
429+ if ( isInput )
430+ mode = IOMode . In ;
431+ if ( isOutput )
432+ mode = IOMode . Out ;
433+ if ( isInput && isOutput )
434+ mode = IOMode . InOut ;
435+ k . Parameters . Add ( p . Name , new FunctionParameter
436+ {
437+ TypeName = p . Type . FullName ,
438+ IOMode = mode
439+ } ) ;
421440 }
422441
423442 KernelFunctions . Add ( k ) ;
@@ -469,12 +488,13 @@ private void CreateKernels(string code)
469488 /// Builds the kernel arguments.
470489 /// </summary>
471490 /// <typeparam name="TSource">The type of the source.</typeparam>
491+ /// <param name="method">The method.</param>
472492 /// <param name="inputs">The inputs.</param>
473493 /// <param name="kernel">The kernel.</param>
474494 /// <param name="length">The length.</param>
475495 /// <param name="returnInputVariable">The return result.</param>
476496 /// <returns></returns>
477- private Dictionary < int , GenericArrayMemory > BuildKernelArguments ( object [ ] inputs , ComputeKernel kernel , long length , int ? returnInputVariable = null )
497+ private Dictionary < int , GenericArrayMemory > BuildKernelArguments ( KernelFunction method , object [ ] inputs , ComputeKernel kernel , long length , int ? returnInputVariable = null )
478498 {
479499 int i = 0 ;
480500 Dictionary < int , GenericArrayMemory > result = new Dictionary < int , GenericArrayMemory > ( ) ;
@@ -484,9 +504,13 @@ private Dictionary<int, GenericArrayMemory> BuildKernelArguments(object[] inputs
484504 int size = 0 ;
485505 if ( item . GetType ( ) . IsArray )
486506 {
487-
488- var datagch = GCHandle . Alloc ( item , GCHandleType . Pinned ) ;
489- GenericArrayMemory mem = new GenericArrayMemory ( _context , ComputeMemoryFlags . ReadWrite | ComputeMemoryFlags . CopyHostPointer , item ) ;
507+ var mode = method . Parameters . ElementAt ( i ) . Value . IOMode ;
508+ var flag = ComputeMemoryFlags . ReadWrite ;
509+ if ( mode == IOMode . Out )
510+ flag |= ComputeMemoryFlags . AllocateHostPointer ;
511+ else
512+ flag |= ComputeMemoryFlags . CopyHostPointer ;
513+ GenericArrayMemory mem = new GenericArrayMemory ( _context , flag , item ) ;
490514 kernel . SetMemoryArgument ( i , mem ) ;
491515 result . Add ( i , mem ) ;
492516 }
0 commit comments