
Commit 880403a

Anton Oresten authored and committed
add methods for BFloat16 type with tests in cuDNN
1 parent: ca67075

File tree: 6 files changed, +27 -1 lines

  lib/cudnn/Project.toml
  lib/cudnn/src/util.jl
  lib/cudnn/test/Project.toml
  lib/cudnn/test/activation.jl
  lib/cudnn/test/pooling.jl
  lib/cudnn/test/softmax.jl

lib/cudnn/Project.toml
Lines changed: 3 additions & 1 deletion

@@ -1,15 +1,17 @@
 name = "cuDNN"
 uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
-authors = ["Tim Besard <tim.besard@gmail.com>"]
 version = "1.4.6"
+authors = ["Tim Besard <tim.besard@gmail.com>"]

 [deps]
+BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
 CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 CUDA_Runtime_Discovery = "1af6417a-86b4-443c-805f-a4643ffb695f"
 CUDNN_jll = "62b44479-cb7b-5706-934f-f13b2eb2e645"

 [compat]
+BFloat16s = "0.6.0"
 CEnum = "0.2, 0.3, 0.4, 0.5"
 CUDA = "~5.9"
 CUDA_Runtime_Discovery = "0.2, 0.3, 1"
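
Note on the new compat bound: under Pkg's default caret semantics for 0.x packages, BFloat16s = "0.6.0" admits any version in [0.6.0, 0.7.0), so patch releases of BFloat16s are picked up automatically while a breaking 0.7 release is excluded.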

lib/cudnn/src/util.jl
Lines changed: 6 additions & 0 deletions

@@ -1,12 +1,16 @@
+using BFloat16s: BFloat16
+
 # For low level cudnn functions that require a pointer to a number
 cptr(x,a::DenseCuArray{Float64})=Float64[x]
 cptr(x,a::DenseCuArray{Float32})=Float32[x]
 cptr(x,a::DenseCuArray{Float16})=Float32[x]
+cptr(x,a::DenseCuArray{BFloat16})=Float32[x]

 # Conversion between Julia and cuDNN datatypes
 cudnnDataType(::Type{Float16})=CUDNN_DATA_HALF
 cudnnDataType(::Type{Float32})=CUDNN_DATA_FLOAT
 cudnnDataType(::Type{Float64})=CUDNN_DATA_DOUBLE
+cudnnDataType(::Type{BFloat16})=CUDNN_DATA_BFLOAT16
 cudnnDataType(::Type{Int8}) = CUDNN_DATA_INT8
 cudnnDataType(::Type{UInt8}) = CUDNN_DATA_UINT8
 cudnnDataType(::Type{Int32}) = CUDNN_DATA_INT32

@@ -17,6 +21,7 @@ cudnnDataType(::Type{Int32}) = CUDNN_DATA_INT32
 juliaDataType(a)=(a==CUDNN_DATA_HALF ? Float16 :
                   a==CUDNN_DATA_FLOAT ? Float32 :
                   a==CUDNN_DATA_DOUBLE ? Float64 :
+                  a==CUDNN_DATA_BFLOAT16 ? BFloat16 :
                   a==CUDNN_DATA_INT8 ? Int8 :
                   a==CUDNN_DATA_UINT8 ? UInt8 :
                   a==CUDNN_DATA_INT32 ? Int32 : error())

@@ -35,6 +40,7 @@ scalingParameter(T, val) = error("Unknown tensor type $T")
 scalingParameter(::Type{Float16}, val) = Ref{Float32}(val)
 scalingParameter(::Type{Float32}, val) = Ref{Float32}(val)
 scalingParameter(::Type{Float64}, val) = Ref{Float64}(val)
+scalingParameter(::Type{BFloat16}, val) = Ref{Float32}(val)


 # Create temporary reserveSpace. Use 128 to avoid alignment issues.
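
The new methods mirror the existing Float16 ones: tensor data is stored as bfloat16, but cptr and scalingParameter still hand cuDNN Float32 scalars, the same convention the Float16 path already follows for alpha/beta on reduced-precision tensors. A minimal sketch of what the new mappings do (these helpers are internal and unexported, so this assumes explicit imports from cuDNN):

    using BFloat16s: BFloat16
    using cuDNN: cudnnDataType, juliaDataType, scalingParameter, CUDNN_DATA_BFLOAT16

    # the new Julia <-> cuDNN type mapping round-trips
    @assert cudnnDataType(BFloat16) == CUDNN_DATA_BFLOAT16
    @assert juliaDataType(CUDNN_DATA_BFLOAT16) == BFloat16

    # like Float16, BFloat16 scaling factors are passed as Float32
    @assert scalingParameter(BFloat16, 2) isa Ref{Float32}
    @assert scalingParameter(BFloat16, 2)[] == 2.0f0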

lib/cudnn/test/Project.toml
Lines changed: 1 addition & 0 deletions

@@ -1,4 +1,5 @@
 [deps]
+BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
 Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"

lib/cudnn/test/activation.jl
Lines changed: 8 additions & 0 deletions

@@ -1,3 +1,4 @@
+using BFloat16s: BFloat16
 using cuDNN:
     cudnnActivationForward,
     cudnnActivationForward!,

@@ -59,3 +60,10 @@ activationtest(coef=2,mode=CUDNN_ACTIVATION_CLIPPED_RELU)
 activationtest(coef=2,mode=CUDNN_ACTIVATION_ELU)
 activationtest(alpha=2)
 activationtest(beta=2)
+
+# BFloat16 tests
+(ax,ay) = randn.(BFloat16, (10,10))
+(cx,cy) = CuArray.((ax,ay))
+activationtest(mode=CUDNN_ACTIVATION_SIGMOID)
+activationtest(mode=CUDNN_ACTIVATION_RELU)
+activationtest(mode=CUDNN_ACTIVATION_TANH)
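
Two details worth noting in these test lines. First, randn.(BFloat16, (10,10)) broadcasts randn over the tuple, so ax and ay are two independent length-10 Vector{BFloat16}s, not a 10x10 matrix. Second, rebinding ax, ay, cx, cy replaces the arrays the activationtest helper defined earlier in this file operates on, so the three calls that follow exercise the bfloat16 path end to end. A standalone sketch of the same path outside the test harness (assumes a CUDA-capable GPU):

    using CUDA
    using BFloat16s: BFloat16
    using cuDNN: cudnnActivationForward, CUDNN_ACTIVATION_RELU

    x = CuArray(randn(BFloat16, 10))
    y = cudnnActivationForward(x; mode = CUDNN_ACTIVATION_RELU)
    @assert eltype(y) == BFloat16          # output stays in bfloat16
    @assert Array(y) == max.(Array(x), 0)  # relu introduces no rounding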

lib/cudnn/test/pooling.jl
Lines changed: 2 additions & 0 deletions

@@ -1,4 +1,5 @@
 using CUDA, Random
+using BFloat16s: BFloat16
 import NNlib
 using cuDNN:
     cudnnPoolingForward,

@@ -87,5 +88,6 @@ pooltest(padding = 1)
 pooltest(stride = 1)
 pooltest(format = CUDNN_TENSOR_NHWC)
 pooltest(dataType = Float16)
+pooltest(dataType = BFloat16)
 pooltest(alpha = 2)
 pooltest(beta = 2)
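
Since pooltest already takes the element type as a dataType keyword, bfloat16 coverage is a single extra call. A standalone sketch of the same operation (assumes a CUDA GPU; the window keyword and the max-pooling/stride-equals-window defaults are assumptions about cudnnPoolingForward's signature, check its docstring):

    using CUDA
    using BFloat16s: BFloat16
    using cuDNN: cudnnPoolingForward

    x = CuArray(rand(BFloat16, 8, 8, 3, 2))  # W x H x C x N
    y = cudnnPoolingForward(x; window = 2)   # 2x2 pooling, stride = window
    @assert size(y) == (4, 4, 3, 2)
    @assert eltype(y) == BFloat16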

lib/cudnn/test/softmax.jl
Lines changed: 7 additions & 0 deletions

@@ -1,3 +1,4 @@
+using BFloat16s: BFloat16
 using cuDNN:
     cudnnSoftmaxForward,
     cudnnSoftmaxForward!,

@@ -43,3 +44,9 @@ softmaxtest(mode=CUDNN_SOFTMAX_MODE_CHANNEL)
 softmaxtest(algo=CUDNN_SOFTMAX_FAST)
 softmaxtest(algo=CUDNN_SOFTMAX_ACCURATE)
 softmaxtest(algo=CUDNN_SOFTMAX_LOG)
+
+# BFloat16 tests
+ax,ay = randn(BFloat16,10,10),randn(BFloat16,10,10)
+cx,cy = CuArray.((ax,ay))
+softmaxtest()
+softmaxtest(algo=CUDNN_SOFTMAX_LOG)
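
As in the activation tests, rebinding ax, ay, cx, cy to BFloat16 data makes the existing softmaxtest helper run the bfloat16 path. Numerically this is a friendly case for the type: bfloat16 keeps Float32's 8-bit exponent, so the exp inside softmax saturates no earlier than it would in single precision, and only significand precision (8 bits vs 24) is lost. A direct-call sketch (assumes a CUDA GPU; keyword defaults are left to the library):

    using CUDA
    using BFloat16s: BFloat16
    using cuDNN: cudnnSoftmaxForward

    x = CuArray(randn(BFloat16, 10, 10))
    y = cudnnSoftmaxForward(x)
    @assert eltype(y) == BFloat16
    @assert all(0 .<= Array(y) .<= 1)  # softmax outputs are probabilities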
