[CUDNN] Support BFloat16 #2987
Conversation
Your PR requires formatting changes to meet the project's style guidelines. Suggested changes:

diff --git a/lib/cudnn/src/util.jl b/lib/cudnn/src/util.jl
index 8923ff9b5..c7ec0c2bd 100644
--- a/lib/cudnn/src/util.jl
+++ b/lib/cudnn/src/util.jl
@@ -4,13 +4,13 @@ using BFloat16s: BFloat16
cptr(x,a::DenseCuArray{Float64})=Float64[x]
cptr(x,a::DenseCuArray{Float32})=Float32[x]
cptr(x,a::DenseCuArray{Float16})=Float32[x]
-cptr(x,a::DenseCuArray{BFloat16})=Float32[x]
+cptr(x, a::DenseCuArray{BFloat16}) = Float32[x]
# Conversion between Julia and cuDNN datatypes
cudnnDataType(::Type{Float16})=CUDNN_DATA_HALF
cudnnDataType(::Type{Float32})=CUDNN_DATA_FLOAT
cudnnDataType(::Type{Float64})=CUDNN_DATA_DOUBLE
-cudnnDataType(::Type{BFloat16})=CUDNN_DATA_BFLOAT16
+cudnnDataType(::Type{BFloat16}) = CUDNN_DATA_BFLOAT16
cudnnDataType(::Type{Int8}) = CUDNN_DATA_INT8
cudnnDataType(::Type{UInt8}) = CUDNN_DATA_UINT8
cudnnDataType(::Type{Int32}) = CUDNN_DATA_INT32
@@ -21,7 +21,7 @@ cudnnDataType(::Type{Int32}) = CUDNN_DATA_INT32
juliaDataType(a)=(a==CUDNN_DATA_HALF ? Float16 :
a==CUDNN_DATA_FLOAT ? Float32 :
a==CUDNN_DATA_DOUBLE ? Float64 :
- a==CUDNN_DATA_BFLOAT16 ? BFloat16 :
+ a == CUDNN_DATA_BFLOAT16 ? BFloat16 :
a==CUDNN_DATA_INT8 ? Int8 :
a==CUDNN_DATA_UINT8 ? UInt8 :
a==CUDNN_DATA_INT32 ? Int32 : error())
diff --git a/lib/cudnn/test/activation.jl b/lib/cudnn/test/activation.jl
index 7b7f2f01a..4164e0231 100644
--- a/lib/cudnn/test/activation.jl
+++ b/lib/cudnn/test/activation.jl
@@ -62,8 +62,8 @@ activationtest(alpha=2)
activationtest(beta=2)
# BFloat16 tests
-(ax,ay) = randn.(BFloat16, (10,10))
-(cx,cy) = CuArray.((ax,ay))
-activationtest(mode=CUDNN_ACTIVATION_SIGMOID)
-activationtest(mode=CUDNN_ACTIVATION_RELU)
-activationtest(mode=CUDNN_ACTIVATION_TANH)
+(ax, ay) = randn.(BFloat16, (10, 10))
+(cx, cy) = CuArray.((ax, ay))
+activationtest(mode = CUDNN_ACTIVATION_SIGMOID)
+activationtest(mode = CUDNN_ACTIVATION_RELU)
+activationtest(mode = CUDNN_ACTIVATION_TANH)
diff --git a/lib/cudnn/test/softmax.jl b/lib/cudnn/test/softmax.jl
index 68967bc1d..ab446813c 100644
--- a/lib/cudnn/test/softmax.jl
+++ b/lib/cudnn/test/softmax.jl
@@ -46,7 +46,7 @@ softmaxtest(algo=CUDNN_SOFTMAX_ACCURATE)
softmaxtest(algo=CUDNN_SOFTMAX_LOG)
# BFloat16 tests
-ax,ay = randn(BFloat16,10,10),randn(BFloat16,10,10)
-cx,cy = CuArray.((ax,ay))
+ax, ay = randn(BFloat16, 10, 10), randn(BFloat16, 10, 10)
+cx, cy = CuArray.((ax, ay))
softmaxtest()
-softmaxtest(algo=CUDNN_SOFTMAX_LOG)
+softmaxtest(algo = CUDNN_SOFTMAX_LOG)
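For context, the util.jl mapping touched by this suggestion is what translates BFloat16 to cuDNN's data-type enum and back. A minimal round-trip check (module-qualified, since these helpers are internal and may not be exported) would look roughly like:

```julia
using BFloat16s: BFloat16
import cuDNN

# forward mapping added in lib/cudnn/src/util.jl by this PR
@assert cuDNN.cudnnDataType(BFloat16) == cuDNN.CUDNN_DATA_BFLOAT16

# reverse mapping, extended in the same file
@assert cuDNN.juliaDataType(cuDNN.CUDNN_DATA_BFLOAT16) === BFloat16
```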
Hm, duplicate of #1092? That one doesn't define the …
1.12 failure unrelated, retried CI |
Second CI fail also seems unrelated, rerunning. If that succeeds you should rebase on top of …
Codecov Report

❌ Patch coverage is …

Additional details and impacted files:

@@            Coverage Diff            @@
##           master    #2987       +/-  ##
==========================================
+ Coverage   76.53%   89.22%   +12.68%
==========================================
  Files         148      148
  Lines       12860    12950      +90
==========================================
+ Hits         9842    11554    +1712
+ Misses       3018     1396    -1622

☔ View full report in Codecov by Sentry.
Force-pushed from 7f8d47b to 6d2bea8.
Thanks, Katharine. I’ve updated the BFloat16s compat entry to align with CUDA.jl’s …
CUDA.jl Benchmarks
| Benchmark suite | Current: 6d2bea8 | Previous: 1af91be | Ratio |
|---|---|---|---|
| latency/precompile | 55259754867 ns | 55207707384.5 ns | 1.00 |
| latency/ttfp | 7903571593 ns | 7803466984 ns | 1.01 |
| latency/import | 4159302659.5 ns | 4119333235.5 ns | 1.01 |
| integration/volumerhs | 9624485 ns | 9616867 ns | 1.00 |
| integration/byval/slices=1 | 147379.5 ns | 147131 ns | 1.00 |
| integration/byval/slices=3 | 425954 ns | 426158 ns | 1.00 |
| integration/byval/reference | 145246 ns | 145358 ns | 1.00 |
| integration/byval/slices=2 | 286767 ns | 286555 ns | 1.00 |
| integration/cudadevrt | 103742 ns | 103753 ns | 1.00 |
| kernel/indexing | 14466 ns | 14494 ns | 1.00 |
| kernel/indexing_checked | 14918 ns | 15153 ns | 0.98 |
| kernel/occupancy | 673.8471337579617 ns | 670.8662420382166 ns | 1.00 |
| kernel/launch | 2243.6666666666665 ns | 2220.4444444444443 ns | 1.01 |
| kernel/rand | 15206 ns | 15661 ns | 0.97 |
| array/reverse/1d | 20126 ns | 20102.5 ns | 1.00 |
| array/reverse/2dL_inplace | 67043 ns | 67011 ns | 1.00 |
| array/reverse/1dL | 70416 ns | 70372 ns | 1.00 |
| array/reverse/2d | 22004 ns | 22007 ns | 1.00 |
| array/reverse/1d_inplace | 9714 ns | 9891 ns | 0.98 |
| array/reverse/2d_inplace | 13588 ns | 13546 ns | 1.00 |
| array/reverse/2dL | 74045 ns | 73927.5 ns | 1.00 |
| array/reverse/1dL_inplace | 66992 ns | 67056 ns | 1.00 |
| array/copy | 21226 ns | 20954 ns | 1.01 |
| array/iteration/findall/int | 157722.5 ns | 158738.5 ns | 0.99 |
| array/iteration/findall/bool | 139755 ns | 140481.5 ns | 0.99 |
| array/iteration/findfirst/int | 160903 ns | 161535 ns | 1.00 |
| array/iteration/findfirst/bool | 161653.5 ns | 162298 ns | 1.00 |
| array/iteration/scalar | 72877 ns | 73003.5 ns | 1.00 |
| array/iteration/logical | 215656 ns | 218149.5 ns | 0.99 |
| array/iteration/findmin/1d | 52799 ns | 53245 ns | 0.99 |
| array/iteration/findmin/2d | 96706 ns | 96825.5 ns | 1.00 |
| array/reductions/reduce/Int64/1d | 43181 ns | 43671 ns | 0.99 |
| array/reductions/reduce/Int64/dims=1 | 44959.5 ns | 44843.5 ns | 1.00 |
| array/reductions/reduce/Int64/dims=2 | 61617 ns | 61899 ns | 1.00 |
| array/reductions/reduce/Int64/dims=1L | 89151 ns | 89258 ns | 1.00 |
| array/reductions/reduce/Int64/dims=2L | 88131 ns | 88446 ns | 1.00 |
| array/reductions/reduce/Float32/1d | 37040 ns | 37742 ns | 0.98 |
| array/reductions/reduce/Float32/dims=1 | 52063.5 ns | 42436 ns | 1.23 |
| array/reductions/reduce/Float32/dims=2 | 60261 ns | 60098 ns | 1.00 |
| array/reductions/reduce/Float32/dims=1L | 52559 ns | 52602 ns | 1.00 |
| array/reductions/reduce/Float32/dims=2L | 72421 ns | 72179 ns | 1.00 |
| array/reductions/mapreduce/Int64/1d | 43557 ns | 43634 ns | 1.00 |
| array/reductions/mapreduce/Int64/dims=1 | 54292 ns | 46988 ns | 1.16 |
| array/reductions/mapreduce/Int64/dims=2 | 61725 ns | 61675 ns | 1.00 |
| array/reductions/mapreduce/Int64/dims=1L | 89262 ns | 89086 ns | 1.00 |
| array/reductions/mapreduce/Int64/dims=2L | 88213 ns | 88150 ns | 1.00 |
| array/reductions/mapreduce/Float32/1d | 36825 ns | 36978 ns | 1.00 |
| array/reductions/mapreduce/Float32/dims=1 | 42127 ns | 48419.5 ns | 0.87 |
| array/reductions/mapreduce/Float32/dims=2 | 60403 ns | 60111 ns | 1.00 |
| array/reductions/mapreduce/Float32/dims=1L | 52850 ns | 52878 ns | 1.00 |
| array/reductions/mapreduce/Float32/dims=2L | 72360 ns | 72374.5 ns | 1.00 |
| array/broadcast | 20240 ns | 20123 ns | 1.01 |
| array/copyto!/gpu_to_gpu | 13371 ns | 13003 ns | 1.03 |
| array/copyto!/cpu_to_gpu | 215905 ns | 217546 ns | 0.99 |
| array/copyto!/gpu_to_cpu | 282914 ns | 285690 ns | 0.99 |
| array/accumulate/Int64/1d | 124585.5 ns | 124863 ns | 1.00 |
| array/accumulate/Int64/dims=1 | 83607 ns | 83917 ns | 1.00 |
| array/accumulate/Int64/dims=2 | 157963.5 ns | 158224 ns | 1.00 |
| array/accumulate/Int64/dims=1L | 1710806 ns | 1710808 ns | 1.00 |
| array/accumulate/Int64/dims=2L | 966785 ns | 966620 ns | 1.00 |
| array/accumulate/Float32/1d | 109529 ns | 109551.5 ns | 1.00 |
| array/accumulate/Float32/dims=1 | 80750 ns | 80701.5 ns | 1.00 |
| array/accumulate/Float32/dims=2 | 147664 ns | 148055.5 ns | 1.00 |
| array/accumulate/Float32/dims=1L | 1619983 ns | 1619581.5 ns | 1.00 |
| array/accumulate/Float32/dims=2L | 698418 ns | 698770 ns | 1.00 |
| array/construct | 1271.15 ns | 1306.1 ns | 0.97 |
| array/random/randn/Float32 | 44845.5 ns | 45766 ns | 0.98 |
| array/random/randn!/Float32 | 25197 ns | 25261 ns | 1.00 |
| array/random/rand!/Int64 | 27627 ns | 27478 ns | 1.01 |
| array/random/rand!/Float32 | 8893.666666666666 ns | 8968 ns | 0.99 |
| array/random/rand/Int64 | 30486 ns | 30173.5 ns | 1.01 |
| array/random/rand/Float32 | 13553 ns | 13273 ns | 1.02 |
| array/permutedims/4d | 55253 ns | 56320.5 ns | 0.98 |
| array/permutedims/2d | 54398 ns | 54500 ns | 1.00 |
| array/permutedims/3d | 55036 ns | 55121.5 ns | 1.00 |
| array/sorting/1d | 2758745 ns | 2758806 ns | 1.00 |
| array/sorting/by | 3345779 ns | 3345943 ns | 1.00 |
| array/sorting/2d | 1082150.5 ns | 1082452 ns | 1.00 |
| cuda/synchronization/stream/auto | 1048 ns | 1022.3076923076923 ns | 1.03 |
| cuda/synchronization/stream/nonblocking | 7331.6 ns | 7461.6 ns | 0.98 |
| cuda/synchronization/stream/blocking | 813.8105263157895 ns | 805.6082474226804 ns | 1.01 |
| cuda/synchronization/context/auto | 1219 ns | 1182.4 ns | 1.03 |
| cuda/synchronization/context/nonblocking | 8014.6 ns | 7352 ns | 1.09 |
| cuda/synchronization/context/blocking | 903.0208333333334 ns | 901.8522727272727 ns | 1.00 |
This comment was automatically generated by a workflow using github-action-benchmark.
The cuDNN run on CUDA 13 appears to fail due to running on SM75.
rebase on master? |
Force-pushed from e7e97ca to 6359bf0.
Done!
This PR defines methods for making cuDNN work with BFloat16s.BFloat16.

In the following example, I show how the new methods fix the BFloat16 backward pass of Flux.logitcrossentropy:

Before

Note: Core.BFloat16 === BFloat16s.BFloat16, but I didn't explicitly import it in this REPL session.
After defining cudnnDataType(::Type{BFloat16})
After defining scalingParameter(::Type{BFloat16}, val)
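For reference, the two added methods amount to roughly the following sketch. The cudnnDataType method comes from the util.jl diff in this PR; the scalingParameter method shown here is an assumption that BFloat16 follows the existing Float16 convention, where cuDNN's alpha/beta scaling factors are passed as Float32:

```julia
# inside the cuDNN module (lib/cudnn/src), where the CUDNN_* constants live
using BFloat16s: BFloat16

# map BFloat16 to cuDNN's data-type enum (from lib/cudnn/src/util.jl in this PR)
cudnnDataType(::Type{BFloat16}) = CUDNN_DATA_BFLOAT16

# sketch only: assumes BFloat16 tensors, like Float16, use Float32 (Cfloat)
# scaling factors when calling into cuDNN
scalingParameter(::Type{BFloat16}, val) = Ref{Cfloat}(Cfloat(val))
```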
I also define a cptr method for consistency, but it appears the function isn't used anywhere.

Tests are added for softmax, activations, and pooling. I initially also tested convolutions, normalization, RNNs, and MHA, but they don't appear to support BFloat16.
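As a rough usage sketch of what those tests exercise (assuming the high-level cudnnSoftmaxForward and cudnnActivationForward wrappers accept the same keyword options the tests pass through; names imported explicitly in case they are not exported):

```julia
using CUDA, BFloat16s
using cuDNN: cudnnSoftmaxForward, cudnnActivationForward,
             CUDNN_SOFTMAX_LOG, CUDNN_ACTIVATION_RELU

x = CuArray(randn(BFloat16, 10, 10))          # BFloat16 input on the GPU

y  = cudnnSoftmaxForward(x)                                    # cuDNN softmax in BFloat16
ly = cudnnSoftmaxForward(x; algo = CUDNN_SOFTMAX_LOG)          # log-softmax variant
r  = cudnnActivationForward(x; mode = CUDNN_ACTIVATION_RELU)   # cuDNN ReLU in BFloat16
```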
Adding BFloat16s.jl as a dependency does not affect compilation since it's already a dependency of CUDA.jl.
Along with my proposed fix in FluxML/Optimisers.jl#215, this has allowed me to train LLMs in BFloat16 with Flux.jl in Julia v1.12. I am still tinkering with Optimisers.jl, but these together would be a significant unlock for my lab.
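To illustrate the end-to-end effect, here is a sketch of the kind of call this unblocks, assuming Flux.logitcrossentropy routes its softmax through the cuDNN path for CuArrays as described above:

```julia
using CUDA, cuDNN, Flux, BFloat16s

# random BFloat16 logits and one-hot targets on the GPU
x = CuArray(randn(BFloat16, 10, 4))
labels = rand(1:10, 4)
y = zeros(BFloat16, 10, 4)
for (j, l) in enumerate(labels)
    y[l, j] = one(BFloat16)
end
cy = CuArray(y)

loss(x) = Flux.logitcrossentropy(x, cy)
loss(x)                         # forward pass in BFloat16
grad, = Flux.gradient(loss, x)  # backward pass, which previously errored for BFloat16
```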