1+ using BFloat16s: BFloat16
2+
# For low level cudnn functions that require a pointer to a number.
# Wraps the scalar `x` in a 1-element host Array whose eltype matches what
# cuDNN expects for the given array's element type: Float64 data takes a
# Float64 scalar, while Float32 and the 16-bit float types (Float16,
# BFloat16) all take Float32 scalars, as required by the cuDNN API.
cptr(x, a::DenseCuArray{Float64}) = Float64[x]
cptr(x, a::DenseCuArray{Float32}) = Float32[x]
cptr(x, a::DenseCuArray{<:Union{Float16,BFloat16}}) = Float32[x]
58
# Conversion between Julia and cuDNN datatypes:
# map a Julia element type to the corresponding `cudnnDataType_t` constant.
cudnnDataType(::Type{Float16})  = CUDNN_DATA_HALF
cudnnDataType(::Type{Float32})  = CUDNN_DATA_FLOAT
cudnnDataType(::Type{Float64})  = CUDNN_DATA_DOUBLE
cudnnDataType(::Type{BFloat16}) = CUDNN_DATA_BFLOAT16
cudnnDataType(::Type{Int8})     = CUDNN_DATA_INT8
cudnnDataType(::Type{UInt8})    = CUDNN_DATA_UINT8
cudnnDataType(::Type{Int32})    = CUDNN_DATA_INT32
@@ -17,6 +21,7 @@ cudnnDataType(::Type{Int32}) = CUDNN_DATA_INT32
# Inverse of `cudnnDataType`: recover the Julia element type from a
# `cudnnDataType_t` constant. Any constant outside the supported set
# raises an (empty) ErrorException, matching the original fallback.
function juliaDataType(a)
    if a == CUDNN_DATA_HALF
        return Float16
    elseif a == CUDNN_DATA_FLOAT
        return Float32
    elseif a == CUDNN_DATA_DOUBLE
        return Float64
    elseif a == CUDNN_DATA_BFLOAT16
        return BFloat16
    elseif a == CUDNN_DATA_INT8
        return Int8
    elseif a == CUDNN_DATA_UINT8
        return UInt8
    elseif a == CUDNN_DATA_INT32
        return Int32
    else
        error()
    end
end
@@ -35,6 +40,7 @@ scalingParameter(T, val) = error("Unknown tensor type $T")
# Host-side scaling factors (alpha/beta) passed by reference to cuDNN.
# cuDNN requires Float32 scaling factors for Float32 and the 16-bit float
# tensor types (Float16, BFloat16); Float64 tensors take Float64 factors.
scalingParameter(::Type{<:Union{Float16,Float32,BFloat16}}, val) = Ref{Float32}(val)
scalingParameter(::Type{Float64}, val) = Ref{Float64}(val)
3844
3945
4046# Create temporary reserveSpace. Use 128 to avoid alignment issues.
0 commit comments