Skip to content

Commit 68451a2

Browse files
author
Deepak Battini
committed
SGEMM matrix multiplication example. XArray, a generic array that is more user-friendly than the standard .NET array
1 parent 437b693 commit 68451a2

File tree

13 files changed

+717
-46
lines changed

13 files changed

+717
-46
lines changed

doc-gen/doc-gen.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
</PropertyGroup>
77

88
<ItemGroup>
9-
<PackageReference Include="docfx.console" Version="2.47.0">
9+
<PackageReference Include="docfx.console" Version="2.54.0">
1010
<PrivateAssets>all</PrivateAssets>
1111
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
1212
</PackageReference>
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
using Amplifier.OpenCL;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Text;
5+
6+
namespace AmplifierExamples.Kernels
{
    /// <summary>
    /// Kernel container for the single-precision matrix multiplication (SGEMM) example.
    /// </summary>
    public class SGEMMKernals : OpenCLFunctions
    {
        /// <summary>
        /// Naive matrix multiplication: every work-item accumulates one element of C
        /// as the sum over k of A[k * M + row] * B[col * K + k] and stores it at
        /// C[col * M + row].
        /// </summary>
        /// <param name="M">Row bound; also the stride used to index A and C.</param>
        /// <param name="N">Column bound; work-items whose column index is not below N do nothing.</param>
        /// <param name="K">Length of the accumulated dimension; also the stride used to index B.</param>
        /// <param name="A">Left operand, read at <c>k * M + row</c>.</param>
        /// <param name="B">Right operand, read at <c>col * K + k</c>.</param>
        /// <param name="C">Result buffer, written at <c>col * M + row</c>.</param>
        [OpenCLKernel]
        void MatMul(int M, int N, int K, [Global]float[] A, [Global]float[] B, [Global]float[] C)
        {
            // Dimension 0 of the global work size selects the row, dimension 1 the column.
            int globalRow = get_global_id(0);
            int globalCol = get_global_id(1);

            // Guard against work-items that fall outside the M x N result (for example
            // when the global work size is larger than the matrix); without it they
            // would read and write past the ends of the buffers.
            if (globalRow < M && globalCol < N)
            {
                float acc = 0.0f;
                for (int k = 0; k < K; k++)
                {
                    acc += A[k * M + globalRow] * B[globalCol * K + k];
                }

                C[globalCol * M + globalRow] = acc;
            }
        }
    }
}

examples/AmplifierExamples/Program.cs

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@ static void Main(string[] args)
1313
//Console.WriteLine("---------------------Basic example---------------------------");
1414
//example = new SimpleKernelEx();
1515
//example.Execute();
16-
//Console.WriteLine("\n---------------------Basic example---------------------------");
16+
Console.WriteLine("\n---------------------Basic example---------------------------");
1717

18-
PrintThreeEmptyLines();
18+
//PrintThreeEmptyLines();
1919

20-
Console.WriteLine("---------------------Array Loop example---------------------");
21-
example = new ArrayForLoopEx();
22-
example.Execute();
23-
Console.WriteLine("\n---------------------Array Loop example---------------------");
20+
//Console.WriteLine("---------------------Array Loop example---------------------");
21+
//example = new ArrayForLoopEx();
22+
//example.Execute();
23+
//Console.WriteLine("\n---------------------Array Loop example---------------------");
2424

2525
//PrintThreeEmptyLines();
2626

@@ -31,10 +31,10 @@ static void Main(string[] args)
3131

3232
//PrintThreeEmptyLines();
3333

34-
Console.WriteLine("--------------------Save and load example-------------------");
35-
example = new SaveAndLoadEx();
36-
example.Execute();
37-
Console.WriteLine("\n--------------------Save and load example-------------------");
34+
//Console.WriteLine("--------------------Save and load example-------------------");
35+
//example = new SaveAndLoadEx();
36+
//example.Execute();
37+
//Console.WriteLine("\n--------------------Save and load example-------------------");
3838

3939
//PrintThreeEmptyLines();
4040

@@ -49,6 +49,9 @@ static void Main(string[] args)
4949
//example = new WithStructEx();
5050
//example.Execute();
5151

52+
Console.WriteLine("---------------------Matrix multiplication example---------------------------");
53+
example = new MatrixMulExample();
54+
example.Execute();
5255

5356
Console.ReadLine();
5457
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
using Amplifier;
2+
using AmplifierExamples.Kernels;
3+
using System;
4+
using System.Collections.Generic;
5+
using System.Text;
6+
7+
namespace AmplifierExamples
{
    /// <summary>
    /// Example that multiplies two constant-filled matrices with the SGEMM kernel
    /// and prints every element of the result.
    /// </summary>
    public class MatrixMulExample : IExample
    {
        /// <summary>
        /// Compiles the kernels, fills the operands on the device, runs
        /// <c>MatMul</c> and writes the result to the console.
        /// </summary>
        public void Execute()
        {
            //Create instance of OpenCL compiler
            var compiler = new OpenCLCompiler();

            //Select a default device
            compiler.UseDevice(0);

            //Compile the kernels. SimpleKernels is compiled as well; presumably it
            //provides the Fill kernel used below (TODO confirm).
            compiler.CompileKernel(typeof(SGEMMKernals), typeof(SimpleKernels));

            //Matrix dimensions: x is M x K, y is K x N, z is M x N
            int M = 30;
            int N = 30;
            int K = 20;

            var x = new XArray(new long[] { M, K }, DType.Float32);
            // The second dimension of y must be N (not M) to match how the kernel
            // indexes B; the two only coincide here because M == N.
            var y = new XArray(new long[] { K, N }, DType.Float32);
            var z = new OutArray(new long[] { M, N }, DType.Float32) { IsElementWise = false };

            //Get the execution engine
            var exec = compiler.GetExec();

            exec.Fill(x, 2);
            exec.Fill(y, 3);
            // r is never read afterwards; presumably a host-side snapshot kept for
            // inspection in a debugger.
            var r = y.ToArray();

            exec.MatMul(M, N, K, x, y, z);
            r = z.ToArray();

            //Print the result
            Console.WriteLine("\nResult----");
            for (int i = 0; i < z.Count; i++)
            {
                Console.Write(z[i] + " ");
            }

            // Terminate the result line so later console output starts on a fresh line.
            Console.WriteLine();
        }
    }
}

examples/AmplifierExamples/SimpleKernelEx.cs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,33 +35,33 @@ public void Execute()
3535
}
3636

3737
//Create variable a, b and r
38-
var x = new float[] { 1, 2, 3, 4, 5, 6, 7, 8, 9 };
39-
var y = new float[9];
40-
var z = new float[9];
38+
var x = new XArray(new float[] { 1, 2, 3, 4, 5, 6, 7, 8, 9 }).Reshape(3, 3);
39+
var y = new XArray(new long[] { 3, 3 });
40+
var z = new XArray(new long[] { 3, 3 }, DType.Float32);
4141

4242
//Get the execution engine
4343
var exec = compiler.GetExec();
4444

4545
//Execute fill kernel method
4646
exec.Fill(y, 0.5f);
47-
47+
var r = y.ToArray();
4848
//Execute AddData kernel method
49-
exec.AddData(x, y, z);
50-
49+
exec.AddTensor(x, y, z);
50+
r = z.ToArray();
5151
//Execute AddHalf kernel method
52-
var xhalf = Array.ConvertAll(x, v => (half)v);
53-
var yhalf = Array.ConvertAll(y, v => (half)v);
54-
exec.AddHalf(xhalf, yhalf);
55-
z = Array.ConvertAll(yhalf, v => (float)v);
52+
//var xhalf = Array.ConvertAll(x, v => (half)v);
53+
//var yhalf = Array.ConvertAll(y, v => (half)v);
54+
//exec.AddHalf(xhalf, yhalf);
55+
//z = Array.ConvertAll(yhalf, v => (float)v);
5656

5757
//Execuete SAXPY kernel method
5858
exec.SAXPY(x, y, 2f);
5959

6060
//Print the result
6161
Console.WriteLine("\nResult----");
62-
for (int i = 0; i < y.Length; i++)
62+
for (int i = 0; i < y.Count; i++)
6363
{
64-
Console.Write(y.GetValue(i) + " ");
64+
Console.Write(y[i] + " ");
6565
}
6666
}
6767
}

src/Amplifier.Net/Amplifier.csproj

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
<PropertyGroup>
44
<AssemblyName>Amplifier</AssemblyName>
55
<RootNamespace>Amplifier</RootNamespace>
6-
<Version>0.5.0</Version>
6+
<Version>1.0.8</Version>
77
<Authors>Deepak Battini</Authors>
88
<Company>Tech Quantum</Company>
99
<Description>Amplifier allows .NET developers to easily run complex applications with intensive mathematical computation on Intel CPU/GPU, NVIDIA, AMD without writing any additional C kernel code. Write your function in .NET and Amplifier will take care of running it on your favorite hardware.</Description>
@@ -17,6 +17,9 @@
1717
<TargetFramework>netstandard2.0</TargetFramework>
1818
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
1919
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
20+
<PackageReleaseNotes>1.0.8-------------------
21+
* XArray, with InArray and OutArray, which is more advanced and powerful than the standard .NET Array and supports reshaping and transposing.
22+
* SGEMM matrix multiplication example</PackageReleaseNotes>
2023
</PropertyGroup>
2124

2225
<PropertyGroup Condition="'$(Configuration)|$(TargetFramework)|$(Platform)'=='Release|net46|AnyCPU'">
@@ -36,11 +39,11 @@
3639
<LangVersion>latest</LangVersion>
3740
</PropertyGroup>
3841
<ItemGroup>
39-
<PackageReference Include="Humanizer.Core" Version="2.7.9" />
42+
<PackageReference Include="Humanizer.Core" Version="2.8.11" />
4043
<PackageReference Include="Newtonsoft.Json" Version="12.0.3" />
4144
<PackageReference Include="Newtonsoft.Json.Bson" Version="1.0.2" />
42-
<PackageReference Include="System.Collections.Immutable" Version="1.7.0" />
43-
<PackageReference Include="System.Reflection.Metadata" Version="1.8.0" />
45+
<PackageReference Include="System.Collections.Immutable" Version="1.7.1" />
46+
<PackageReference Include="System.Reflection.Metadata" Version="1.8.1" />
4447
<PackageReference Include="System.ValueTuple" Version="4.5.0" />
4548
</ItemGroup>
4649

src/Amplifier.Net/DType.cs

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Amplifier
{
    /// <summary>
    /// Element types supported by the array types in this library.
    /// Commented-out members are reserved values that are not supported yet.
    /// </summary>
    public enum DType
    {
        Float32 = 0,
        Float64 = 1,
        //Int8 = 2,
        //Int16 = 3,
        Int32 = 4,
        //Int64 = 5,
        UInt8 = 6,
        //UInt16 = 7,
        //UInt32 = 8,
        //UInt64 = 9
    }

    /// <summary>
    /// Extension methods that describe a <see cref="DType"/>.
    /// </summary>
    public static class DTypeExtensions
    {
        /// <summary>
        /// Returns the size in bytes of a single element of the given type.
        /// </summary>
        /// <param name="value">The element type.</param>
        /// <returns>The element size in bytes.</returns>
        /// <exception cref="NotSupportedException">The element type is not supported.</exception>
        public static int Size(this DType value)
        {
            switch (value)
            {
                case DType.Float32: return 4;
                case DType.Float64: return 8;
                case DType.Int32: return 4;
                case DType.UInt8: return 1;
                //case DType.Int8: return 1;
                //case DType.UInt32: return 4;
                //case DType.Int16: return 2;
                //case DType.UInt16: return 2;
                //case DType.Int64: return 8;
                //case DType.UInt64: return 8;
                default:
                    throw new NotSupportedException("Element type " + value + " not supported.");
            }
        }

        /// <summary>
        /// Returns the CLR type that corresponds to the given element type.
        /// </summary>
        /// <param name="value">The element type.</param>
        /// <returns>The matching CLR <see cref="Type"/>.</returns>
        /// <exception cref="NotSupportedException">The element type is not supported.</exception>
        public static Type ToCLRType(this DType value)
        {
            switch (value)
            {
                case DType.Float32: return typeof(float);
                case DType.Float64: return typeof(double);
                case DType.Int32: return typeof(int);
                case DType.UInt8: return typeof(byte);
                //case DType.Int8: return typeof(sbyte);
                //case DType.UInt32: return typeof(uint);
                //case DType.Int16: return typeof(short);
                //case DType.UInt16: return typeof(ushort);
                //case DType.Int64: return typeof(long);
                //case DType.UInt64: return typeof(ulong);
                default:
                    throw new NotSupportedException("Element type " + value + " not supported.");
            }
        }
    }

    /// <summary>
    /// Creates <see cref="DType"/> values from CLR types.
    /// </summary>
    public static class DTypeBuilder
    {
        /// <summary>
        /// Returns the <see cref="DType"/> that corresponds to a CLR type.
        /// Array, by-ref and pointer types are mapped through their element type,
        /// so <c>typeof(float[])</c> yields <see cref="DType.Float32"/>.
        /// </summary>
        /// <param name="type">The CLR type (or an array/by-ref/pointer of it).</param>
        /// <returns>The matching <see cref="DType"/>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="type"/> is null.</exception>
        /// <exception cref="NotSupportedException">No <see cref="DType"/> corresponds to the CLR type.</exception>
        public static DType FromCLRType(Type type)
        {
            if (type == null)
            {
                throw new ArgumentNullException(nameof(type));
            }

            // Unwrap arrays, by-ref and pointer types (e.g. "Single[]", "Single[,]")
            // so they keep resolving to their element type.
            Type elementType = type;
            while (elementType.HasElementType)
            {
                elementType = elementType.GetElementType();
            }

            // Compare types exactly. Matching on a substring of the type name would
            // also accept UInt32 as Int32, SByte as UInt8, and any unrelated type
            // whose name happens to contain one of these words.
            if (elementType == typeof(float))
            {
                return DType.Float32;
            }

            if (elementType == typeof(double))
            {
                return DType.Float64;
            }

            if (elementType == typeof(int))
            {
                return DType.Int32;
            }

            if (elementType == typeof(byte))
            {
                return DType.UInt8;
            }

            throw new NotSupportedException("No corresponding DType value for CLR type " + type);
        }
    }
}

src/Amplifier.Net/OpenCL/Cloo/ComputeCommandQueue.Added.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,21 @@ public void ReadFromMemory(GenericArrayMemory source, ref Array r, bool blocking
394394
events.Add(new ComputeEvent(newEventHandle[0], this));
395395
}
396396

397+
/// <summary>
/// Enqueues a command that reads data from a buffer into the native memory of an XArray.
/// </summary>
/// <param name="source">The buffer to read from.</param>
/// <param name="r">The destination array; data is written to its <c>NativePtr</c>.</param>
/// <param name="blocking">Passed through to the read command as its blocking flag.</param>
/// <param name="offset">Offset into the buffer, in elements, at which reading starts.</param>
/// <param name="events">Events to wait on; when the collection is writable, the event of this command is appended to it.</param>
/// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
public void ReadFromMemory(GenericArrayMemory source, ref XArray r, bool blocking, long offset, ICollection<ComputeEventBase> events)
{
    if (source == null)
    {
        throw new ArgumentNullException(nameof(source));
    }

    long region = r.Count;
    if (region == 0)
    {
        // Nothing to read; also avoids the division by zero below.
        return;
    }

    IntPtr destinationOffsetPtr = r.NativePtr;

    // Element size in bytes, derived from the buffer size and the element count.
    long size = source.Size / region;

    CLEventHandle[] eventHandles = ComputeTools.ExtractHandles(events, out var eventWaitListSize);
    bool eventsWritable = events != null && !events.IsReadOnly;
    CLEventHandle[] newEventHandle = eventsWritable ? new CLEventHandle[1] : null;

    ComputeErrorCode error = CL12.EnqueueReadBuffer(Handle, new CLMemoryHandle(source.Handle.Value), blocking, new IntPtr(offset * size), new IntPtr(region * size), destinationOffsetPtr, eventWaitListSize, eventHandles, newEventHandle);
    ComputeException.ThrowOnError(error);

    if (eventsWritable)
    {
        events.Add(new ComputeEvent(newEventHandle[0], this));
    }
}
411+
397412
/// <summary>
398413
/// Enqueues a command to read data from a buffer.
399414
/// </summary>

src/Amplifier.Net/OpenCL/Cloo/GenericArrayMemory.cs

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,40 @@ public GenericArrayMemory(ComputeContext context, ComputeMemoryFlags flags) : ba
1515
{
1616
}
1717

18-
public GenericArrayMemory(ComputeContext context, ComputeMemoryFlags flags, object obj) : base(context, flags)
18+
public GenericArrayMemory(ComputeContext context, ComputeMemoryFlags flags, Array array) : base(context, flags)
1919
{
20-
Array array = (Array)obj;
21-
2220
if (array.Length == 0)
2321
return;
24-
22+
2523
int size = Marshal.SizeOf(array.GetValue(0).GetType()) * array.Length;
2624
var hostPtr = IntPtr.Zero;
2725
if ((flags & (ComputeMemoryFlags.CopyHostPointer | ComputeMemoryFlags.UseHostPointer)) != ComputeMemoryFlags.None)
2826
{
29-
var datagch = GCHandle.Alloc(obj, GCHandleType.Pinned);
27+
var datagch = GCHandle.Alloc(array, GCHandleType.Pinned);
3028
hostPtr = datagch.AddrOfPinnedObject();
3129
}
30+
3231
ComputeErrorCode error = ComputeErrorCode.Success;
3332
var handle = CL12.CreateBuffer(context.Handle, flags, new IntPtr(size), hostPtr, out error);
3433

3534
this.Size = size;
3635
this.Handle = handle;
3736
}
37+
38+
/// <summary>
/// Creates a device buffer sized to hold every element of an XArray.
/// </summary>
/// <param name="context">The compute context that owns the buffer.</param>
/// <param name="flags">Memory flags; the array's native pointer is only handed to the buffer when
/// <c>CopyHostPointer</c> or <c>UseHostPointer</c> is set.</param>
/// <param name="obj">The array whose element type and count determine the buffer size.</param>
/// <exception cref="ArgumentNullException"><paramref name="obj"/> is null.</exception>
public GenericArrayMemory(ComputeContext context, ComputeMemoryFlags flags, XArray obj) : base(context, flags)
{
    if (obj == null)
    {
        throw new ArgumentNullException(nameof(obj));
    }

    // Same behaviour as the Array overload: an empty array gets no buffer.
    if (obj.Count == 0)
    {
        return;
    }

    // Widen before multiplying so large arrays cannot overflow a 32-bit product.
    long size = (long)obj.DataType.Size() * obj.Count;

    var hostPtr = IntPtr.Zero;
    if ((flags & (ComputeMemoryFlags.CopyHostPointer | ComputeMemoryFlags.UseHostPointer)) != ComputeMemoryFlags.None)
    {
        hostPtr = obj.NativePtr;
    }

    ComputeErrorCode error = ComputeErrorCode.Success;
    var handle = CL12.CreateBuffer(context.Handle, flags, new IntPtr(size), hostPtr, out error);

    // Surface allocation failures here instead of leaving an invalid handle
    // that only fails later, when the buffer is first used.
    ComputeException.ThrowOnError(error);

    this.Size = size;
    this.Handle = handle;
}
3853
}
3954
}

0 commit comments

Comments
 (0)