diff --git a/include/flexflow/accessor.h b/include/flexflow/accessor.h index 65ab33b51..121140c92 100644 --- a/include/flexflow/accessor.h +++ b/include/flexflow/accessor.h @@ -5,16 +5,25 @@ #if defined(FF_USE_CUDA) #include +#include #elif defined(FF_USE_HIP_CUDA) #include +#include #elif defined(FF_USE_HIP_ROCM) #include +#include #endif // using namespace Legion; namespace FlexFlow { +#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA) +typedef __nv_bfloat16 __ff_bfloat16; +#elif defined(FF_USE_HIP_ROCM) +typedef hip_bfloat16 __ff_bfloat16; +#endif + template using AccessorRO = Legion::FieldAccessor>; @@ -61,6 +70,7 @@ class GenericTensorAccessorW { float *get_float_ptr() const; double *get_double_ptr() const; half *get_half_ptr() const; + __ff_bfloat16 *get_bfloat16_ptr() const; char *get_byte_ptr() const; DataType data_type; Legion::Domain domain; @@ -80,6 +90,7 @@ class GenericTensorAccessorR { float const *get_float_ptr() const; double const *get_double_ptr() const; half const *get_half_ptr() const; + __ff_bfloat16 const *get_bfloat16_ptr() const; char const *get_byte_ptr() const; DataType data_type; Legion::Domain domain; diff --git a/include/flexflow/ffconst.h b/include/flexflow/ffconst.h index 57854c722..5abcafabd 100644 --- a/include/flexflow/ffconst.h +++ b/include/flexflow/ffconst.h @@ -36,7 +36,7 @@ enum DataType { DT_INT4 = 46, DT_INT8 = 47, DT_NONE = 49, - DT_BFLOAT16 = 50, // placeholder for bfloat16 + DT_BFLOAT16 = 50, }; enum LossType { @@ -287,4 +287,11 @@ enum { PEFT_MODEL_ID_FIRST_VALID = 6000000, PEFT_MODEL_ID_LAST_VALID = 6999999 }; + +enum InferencePrecision { + INFERENCE_FLOAT = 800, + INFERENCE_HALF = 801, + INFERENCE_BFLOAT16 = 802, +}; + #endif // _FLEXFLOW_CONST_H_ diff --git a/include/flexflow/model.h b/include/flexflow/model.h index 874211172..0101553dc 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -442,74 +442,74 @@ class FFModel { bool cpu_offload; // C++ APIs for constructing models // Add an exp layer - Tensor exp(const Tensor x, char const *name = NULL); + Tensor exp(Tensor const x, char const *name = NULL); // Add an add layer - Tensor add(const Tensor x, - const Tensor y, + Tensor add(Tensor const x, + Tensor const y, bool inplace_a = false, char const *name = NULL); // Add a subtract layer - Tensor subtract(const Tensor x, - const Tensor y, + Tensor subtract(Tensor const x, + Tensor const y, bool inplace_a = false, char const *name = NULL); // Add a multiply layer - Tensor multiply(const Tensor x, - const Tensor y, + Tensor multiply(Tensor const x, + Tensor const y, bool inplace_a = false, char const *name = NULL); // Add a divide layer - Tensor divide(const Tensor x, - const Tensor y, + Tensor divide(Tensor const x, + Tensor const y, bool inplace_a = false, char const *name = NULL); // Add a max layer - Tensor max(const Tensor x, - const Tensor y, + Tensor max(Tensor const x, + Tensor const y, bool inplace_a = false, char const *name = NULL); // Add a min layer - Tensor min(const Tensor x, - const Tensor y, + Tensor min(Tensor const x, + Tensor const y, bool inplace_a = false, char const *name = NULL); // Add a rsqrt layer - Tensor rsqrt(const Tensor x, bool inplace = true, char const *name = NULL); + Tensor rsqrt(Tensor const x, bool inplace = true, char const *name = NULL); // Add a pow layer - Tensor pow(const Tensor x, + Tensor pow(Tensor const x, float const exponent, bool inplace = true, char const *name = NULL); // Add a scalar multiply layer - Tensor scalar_multiply(const Tensor x, + Tensor scalar_multiply(Tensor 
const x, float const scalar, bool inplace = true, char const *name = NULL); - Tensor scalar_add(const Tensor x, + Tensor scalar_add(Tensor const x, float const scalar, bool inplace = true, char const *name = NULL); - Tensor scalar_sub(const Tensor x, + Tensor scalar_sub(Tensor const x, float const scalar, bool inplace = true, char const *name = NULL); - Tensor scalar_truediv(const Tensor x, + Tensor scalar_truediv(Tensor const x, float const scalar, bool inplace = true, char const *name = NULL); // Add a sin layer - Tensor sin(const Tensor x, char const *name = NULL); + Tensor sin(Tensor const x, char const *name = NULL); // Add a cos layer - Tensor cos(const Tensor x, char const *name = NULL); + Tensor cos(Tensor const x, char const *name = NULL); // Add an activation layer - Tensor relu(const Tensor x, bool inplace = true, char const *name = NULL); - Tensor identity(const Tensor x, char const *name = NULL); - Tensor gelu(const Tensor x, char const *name = NULL); - Tensor sigmoid(const Tensor x, char const *name = NULL); - Tensor tanh(const Tensor x, char const *name = NULL); - Tensor elu(const Tensor x, bool inplace = true, char const *name = NULL); + Tensor relu(Tensor const x, bool inplace = true, char const *name = NULL); + Tensor identity(Tensor const x, char const *name = NULL); + Tensor gelu(Tensor const x, char const *name = NULL); + Tensor sigmoid(Tensor const x, char const *name = NULL); + Tensor tanh(Tensor const x, char const *name = NULL); + Tensor elu(Tensor const x, bool inplace = true, char const *name = NULL); // Add a 2D convolutional layer - Tensor conv2d(const Tensor input, + Tensor conv2d(Tensor const input, int outChannels, int kernelH, int kernelW, @@ -525,12 +525,12 @@ class FFModel { Initializer *bias_initializer = NULL, char const *name = NULL); // Add a dropout layer - Tensor dropout(const Tensor input, + Tensor dropout(Tensor const input, float rate, unsigned long long seed = 0, char const *name = NULL); // Add an embedding layer - Tensor embedding(const Tensor input, + Tensor embedding(Tensor const input, int num_entries, int outDim, AggrMode aggr, @@ -539,13 +539,13 @@ class FFModel { Initializer *kernel_initializer = NULL, char const *name = NULL); // Add a gather layer - Tensor gather(const Tensor input, - const Tensor index, + Tensor gather(Tensor const input, + Tensor const index, int dim, char const *name = NULL); // Add a group_by layer - void group_by(const Tensor data, - const Tensor assign, + void group_by(Tensor const data, + Tensor const assign, Tensor *outputs, int n, float alpha, @@ -567,7 +567,7 @@ class FFModel { float lambda_bal, char const *name = NULL); // Add a 2D pooling layer - Tensor pool2d(const Tensor input, + Tensor pool2d(Tensor const input, int kernelH, int kernelW, int strideH, @@ -578,7 +578,7 @@ class FFModel { ActiMode activation = AC_MODE_NONE, char const *name = NULL); // Add a layer_norm layer - Tensor layer_norm(const Tensor input, + Tensor layer_norm(Tensor const input, std::vector const &axes, bool elementwise_affine, float eps, @@ -586,9 +586,9 @@ class FFModel { DataType data_type = DT_NONE, char const *name = NULL); // Add a layer_norm layer with residual(s) - void residual_layer_norm(const Tensor input, - const Tensor residual1, - const Tensor residual2, + void residual_layer_norm(Tensor const input, + Tensor const residual1, + Tensor const residual2, Tensor *outputs, bool use_two_residuals, std::vector const &axes, @@ -599,8 +599,8 @@ class FFModel { DataType data_type = DT_NONE, char const *name = NULL); // Add a 
add_bias_residual_layer_norm layer - void add_bias_residual_layer_norm(const Tensor input, - const Tensor residual, + void add_bias_residual_layer_norm(Tensor const input, + Tensor const residual, Tensor *outputs, std::vector const &axes, bool elementwise_affine, @@ -610,28 +610,28 @@ class FFModel { DataType data_type = DT_NONE, char const *name = NULL); // Add a sigmoid_silu_multi layer - Tensor sigmoid_silu_multi(const Tensor input1, - const Tensor input2, + Tensor sigmoid_silu_multi(Tensor const input1, + Tensor const input2, DataType data_type = DT_NONE, char const *name = NULL); // Add a batch_norm layer Tensor - batch_norm(const Tensor input, bool relu = true, char const *name = NULL); + batch_norm(Tensor const input, bool relu = true, char const *name = NULL); // Add a batch_matmul layer - Tensor batch_matmul(const Tensor A, - const Tensor B, + Tensor batch_matmul(Tensor const A, + Tensor const B, int a_seq_length_dim = -1, int b_seq_length_dim = -1, char const *name = nullptr); // Add a root mean square layer - Tensor rms_norm(const Tensor input, + Tensor rms_norm(Tensor const input, float eps, int dim, DataType data_type = DT_NONE, char const *name = NULL); // Add a residual root mean square layer - void residual_rms_norm(const Tensor input1, - const Tensor input2, + void residual_rms_norm(Tensor const input1, + Tensor const input2, Tensor *outputs, float eps, int dim, @@ -639,13 +639,13 @@ class FFModel { DataType data_type = DT_NONE, char const *name = NULL); // Add a beam search top k layer - Tensor beam_top_k(const Tensor input, + Tensor beam_top_k(Tensor const input, int max_beam_size, bool sorted, char const *name = NULL); // Add a dense layer - Tensor dense(const Tensor input, + Tensor dense(Tensor const input, int outDim, ActiMode activation = AC_MODE_NONE, bool use_bias = true, @@ -657,7 +657,7 @@ class FFModel { float regularizer_lambda = 0.0, char const *name = NULL); // Add a cast layer - Tensor cast(const Tensor input, DataType dtype, char const *name = nullptr); + Tensor cast(Tensor const input, DataType dtype, char const *name = nullptr); // Add a concat layer Tensor concat(int n, Tensor const *tensors, int axis, char const *name = NULL); @@ -672,58 +672,58 @@ class FFModel { int experts_internal_dim_size = 0, // hidden dimension for internal layers char const *name = NULL); // Add a mean layer - Tensor mean(const Tensor input, + Tensor mean(Tensor const input, std::vector const &dims, bool keepdims, char const *name); // Add a moe layer (wrapping topk, group_by and aggregate operators) - Tensor moe(const Tensor input, + Tensor moe(Tensor const input, int num_exp, int num_select, int expert_hidden_size, float alpha, float lambda); // Add a split layer - void split(const Tensor input, + void split(Tensor const input, Tensor *outputs, std::vector const &split, int axis, char const *name = NULL); // Add a flat layer - Tensor flat(const Tensor input, char const *name = NULL); + Tensor flat(Tensor const input, char const *name = NULL); // Add a softmax layer - Tensor softmax(const Tensor input, + Tensor softmax(Tensor const input, int dim = -1, DataType data_type = DT_NONE, char const *name = NULL); // Create input tensors and constants - Tensor transpose(const Tensor input, + Tensor transpose(Tensor const input, std::vector const &perm, char const *name = NULL); - Tensor reduce_sum(const Tensor input, + Tensor reduce_sum(Tensor const input, std::vector const &axes, bool keepdims = false, char const *name = nullptr); - Tensor reshape(const Tensor input, + Tensor 
reshape(Tensor const input, std::vector const &shape, char const *name = NULL); - Tensor reverse(const Tensor input, int axis, char const *name = NULL); - void top_k(const Tensor input, + Tensor reverse(Tensor const input, int axis, char const *name = NULL); + void top_k(Tensor const input, Tensor *outputs, int k, bool sorted, char const *name = NULL); - Tensor arg_top_k(const Tensor input, + Tensor arg_top_k(Tensor const input, // Tensor *outputs, int k, bool sorted, bool speculative_decoding, char const *name = NULL); - Tensor argmax(const Tensor input, bool beam_search, char const *name = NULL); - Tensor sampling(const Tensor input, float top_p, char const *name = NULL); - Tensor multihead_attention(const Tensor query, - const Tensor key, - const Tensor value, + Tensor argmax(Tensor const input, bool beam_search, char const *name = NULL); + Tensor sampling(Tensor const input, float top_p, char const *name = NULL); + Tensor multihead_attention(Tensor const query, + Tensor const key, + Tensor const value, int embed_dim, int num_heads, int kdim = 0, @@ -736,7 +736,7 @@ class FFModel { Initializer *kernel_initializer = NULL, char const *name = NULL); Tensor inc_multihead_self_attention( - const Tensor input, + Tensor const input, int embed_dim, int num_q_heads, int num_kv_heads, @@ -753,7 +753,7 @@ class FFModel { bool position_bias = false, char const *name = NULL); Tensor spec_inc_multihead_self_attention( - const Tensor input, + Tensor const input, int embed_dim, int num_q_heads, int num_kv_heads, @@ -770,7 +770,7 @@ class FFModel { bool position_bias = false, char const *name = NULL); Tensor inc_multihead_self_attention_verify( - const Tensor input, + Tensor const input, int embed_dim, int num_q_heads, int num_kv_heads, @@ -808,7 +808,7 @@ class FFModel { bool create_grad = true); ParallelTensor create_parallel_tensor_legion_ordering(int num_dim, - const ParallelDim dims[], + ParallelDim const dims[], DataType data_type, Op const *owner_op = NULL, int owner_idx = 0, @@ -821,7 +821,7 @@ class FFModel { int owner_idx = 0, bool create_grad = true); ParallelTensor create_parallel_tensor(int num_dim, - const ParallelDim dims[], + ParallelDim const dims[], DataType data_type, Op const *owner_op = NULL, int owner_idx = 0, @@ -834,7 +834,7 @@ class FFModel { int owner_idx = 0, bool create_grad = true); template - ParallelTensor create_parallel_tensor(const ParallelDim dims[], + ParallelTensor create_parallel_tensor(ParallelDim const dims[], DataType data_type, Op const *owner_op = NULL, int owner_idx = 0, @@ -858,7 +858,7 @@ class FFModel { ParameterSyncType sync_type = ParameterSyncType::NONE); template ParallelParameter create_parallel_weight( - const ParallelDim dims[], + ParallelDim const dims[], DataType data_type, Op const *owner_op = NULL, bool create_grad = true, @@ -866,7 +866,7 @@ class FFModel { ParameterSyncType sync_type = ParameterSyncType::NONE); ParallelParameter create_parallel_weight( int numdim, - const ParallelDim dims[], + ParallelDim const dims[], DataType data_type, Op const *owner_op = NULL, bool create_grad = true, @@ -874,7 +874,7 @@ class FFModel { ParameterSyncType sync_type = ParameterSyncType::NONE); ParallelParameter create_parallel_weight_legion_ordering( int numdim, - const ParallelDim dims[], + ParallelDim const dims[], DataType data_type, Op const *owner_op = NULL, bool create_grad = true, @@ -883,7 +883,7 @@ class FFModel { void map_tensor(ParallelTensor tensor, Op const *parallel_op); void map_weight(ParallelTensor tensor, Op const *parallel_op); - bool 
get_parallel_tensor_from_tensor(const Tensor tensor, + bool get_parallel_tensor_from_tensor(Tensor const tensor, ParallelTensor ¶llel_tensor) const; template @@ -924,7 +924,7 @@ class FFModel { // Internal PCG::Node creation APIs // ======================================== template - PCG::Node get_or_create_node(const typename T::Input &input, + PCG::Node get_or_create_node(typename T::Input const &input, typename T::Params const ¶ms) { using Params = typename T::Params; @@ -954,50 +954,50 @@ class FFModel { return this->new_node(op); } - PCG::Node get_or_create_noop_node(const ParallelTensor input); + PCG::Node get_or_create_noop_node(ParallelTensor const input); PCG::Node get_or_create_input_node(ParallelTensorShape const &); PCG::Node get_or_create_fused_parallel_node( - const ParallelTensor input, + ParallelTensor const input, std::vector const ¶llel_ops); - PCG::Node get_or_create_parallel_op_node(const ParallelTensor input, + PCG::Node get_or_create_parallel_op_node(ParallelTensor const input, ParallelOpInfo const &); // ======================================== // Internal APIs that should not be invoked from applications // ======================================== void create_disjoint_partition(int num_dims, - const ParallelDim dims[], + ParallelDim const dims[], Legion::IndexSpace const &part_is, Legion::LogicalRegion const ®ion, Legion::LogicalPartition &part); template void create_disjoint_partition_with_dim2( - const ParallelDim dims[], + ParallelDim const dims[], Legion::IndexSpaceT const &part_is, Legion::LogicalRegion const ®ion, Legion::LogicalPartition &part); void create_aliased_partition(int num_dims, - const ParallelDim dims[], + ParallelDim const dims[], int aliased_dim, Legion::IndexSpace const &part_is, Legion::LogicalRegion const ®ion, Legion::LogicalPartition &part); template void create_aliased_partition_with_dim2( - const ParallelDim dims[], + ParallelDim const dims[], int aliased_dim, Legion::IndexSpaceT const &part_is, Legion::LogicalRegion const ®ion, Legion::LogicalPartition &part); template - void create_disjoint_partition(const ParallelTensor tensor, + void create_disjoint_partition(ParallelTensor const tensor, Legion::IndexSpaceT const &part_is, Legion::LogicalPartition &part_fwd, Legion::LogicalPartition &part_bwd); template void create_data_parallel_partition_with_diff_dims( - const ParallelTensor tensor, + ParallelTensor const tensor, Legion::IndexSpaceT const &task_is, Legion::LogicalPartition &part_fwd, Legion::LogicalPartition &part_bwd); @@ -1092,7 +1092,7 @@ class FFModel { Legion::IndexSpace get_or_create_task_is(ParallelConfig const &pc); Legion::IndexSpace get_or_create_task_is(MachineView const &view); Legion::IndexSpace get_or_create_task_is(Legion::Domain const &domain); - Legion::IndexSpace get_or_create_task_is(const ParallelTensor); + Legion::IndexSpace get_or_create_task_is(ParallelTensor const); Legion::IndexSpace get_task_is(Legion::Domain const &domain) const; Legion::IndexSpace get_task_is(ParallelConfig const &pc) const; Legion::IndexSpace get_task_is(MachineView const &view) const; diff --git a/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh b/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh index 3d122d4bc..5b1def7fb 100644 --- a/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh +++ b/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh @@ -34,6 +34,24 @@ struct float8 { float d; }; +struct __nv_bfloat164 { + __nv_bfloat16 x; + __nv_bfloat16 y; + __nv_bfloat16 
z; + __nv_bfloat16 w; +}; + +struct __nv_bfloat168 { + __nv_bfloat16 x; + __nv_bfloat16 y; + __nv_bfloat16 z; + __nv_bfloat16 w; + __nv_bfloat16 a; + __nv_bfloat16 b; + __nv_bfloat16 c; + __nv_bfloat16 d; +}; + ////////////////data type/////////////// template struct VEC_K {}; @@ -61,6 +79,18 @@ template <> struct VEC_K { using Type = half4; }; +template <> +struct VEC_K<__nv_bfloat16, 1> { + using Type = __nv_bfloat16; +}; +template <> +struct VEC_K<__nv_bfloat16, 2> { + using Type = __nv_bfloat162; +}; +template <> +struct VEC_K<__nv_bfloat16, 4> { + using Type = __nv_bfloat164; +}; // data type for QK production template @@ -94,6 +124,22 @@ template <> struct Vec_fp32_ { using Type = float8; }; +template <> +struct Vec_fp32_<__nv_bfloat16> { + using Type = float; +}; +template <> +struct Vec_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template <> +struct Vec_fp32_<__nv_bfloat164> { + using Type = float4; +}; +template <> +struct Vec_fp32_<__nv_bfloat168> { + using Type = float8; +}; template struct VEC_V {}; @@ -106,6 +152,11 @@ struct VEC_V { using Type = half8; }; +template <> +struct VEC_V<__nv_bfloat16> { + using Type = __nv_bfloat168; +}; + ////////////////data structures half/////////////// ////////////////////////////////////floating point @@ -331,6 +382,42 @@ inline __device__ float8 cast_to_float(half8 u) { return tmp; } +inline __device__ float cast_to_float(__nv_bfloat16 u) { + return __bfloat162float(u); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 cast_to_float(__nv_bfloat162 u) { + float2 tmp; + tmp.x = __bfloat162float(u.x); + tmp.y = __bfloat162float(u.y); + return tmp; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 cast_to_float(__nv_bfloat164 u) { + float4 tmp; + tmp.x = __bfloat162float(u.x); + tmp.y = __bfloat162float(u.y); + tmp.z = __bfloat162float(u.z); + tmp.w = __bfloat162float(u.w); + return tmp; +} +inline __device__ float8 cast_to_float(__nv_bfloat168 u) { + float8 tmp; + tmp.x = __bfloat162float(u.x); + tmp.y = __bfloat162float(u.y); + tmp.z = __bfloat162float(u.z); + tmp.w = __bfloat162float(u.w); + tmp.a = __bfloat162float(u.a); + tmp.b = __bfloat162float(u.b); + tmp.c = __bfloat162float(u.c); + tmp.d = __bfloat162float(u.d); + return tmp; +} + inline __device__ void convert_from_float(float4 &dst, float4 src) { dst = src; } @@ -369,6 +456,31 @@ inline __device__ void convert_from_float(half &dst, float src) { dst = __float2half(src); } +inline __device__ void convert_from_float(__nv_bfloat164 &dst, float4 src) { + dst.x = __float2bfloat16(src.x); + dst.y = __float2bfloat16(src.y); + dst.z = __float2bfloat16(src.z); + dst.w = __float2bfloat16(src.w); +} + +inline __device__ void convert_from_float(__nv_bfloat168 &dst, float8 src) { + dst.x = __float2bfloat16(src.x); + dst.y = __float2bfloat16(src.y); + dst.z = __float2bfloat16(src.z); + dst.w = __float2bfloat16(src.w); + dst.a = __float2bfloat16(src.a); + dst.b = __float2bfloat16(src.b); + dst.c = __float2bfloat16(src.c); + dst.d = __float2bfloat16(src.d); +} +inline __device__ void convert_from_float(__nv_bfloat162 &dst, float2 src) { + dst.x = __float2bfloat16(src.x); + dst.y = __float2bfloat16(src.y); +} +inline __device__ void convert_from_float(__nv_bfloat16 &dst, float src) { + dst = __float2bfloat16(src); +} + //////////////////////////////////////utils/////////////////////////////////////////////// template diff 
--git a/include/flexflow/ops/kernels/lora_linear_kernels.h b/include/flexflow/ops/kernels/lora_linear_kernels.h index c35419d33..3ed8e8da0 100644 --- a/include/flexflow/ops/kernels/lora_linear_kernels.h +++ b/include/flexflow/ops/kernels/lora_linear_kernels.h @@ -47,22 +47,22 @@ void save_peft_weights_if_needed(LoraLinearMeta *m, namespace Internal { // template // void init_kernel(LoraLinearMeta *m, int seed, ffStream_t stream); -template +template void inference_kernel(LoraLinearMeta *m, BatchConfig const *bc, - DT const *input_ptr, - DT *output_ptr, + DATA_DT const *input_ptr, + DATA_DT *output_ptr, int in_dim, int out_dim, ffStream_t stream); -template +template void peft_bwd_kernel(Context ctx, Runtime *runtime, LoraLinearMeta *m, BatchConfig const *bc, int shard_id, - DT *input_grad_ptr, - DT const *output_grad_ptr, + DATA_DT *input_grad_ptr, + DATA_DT const *output_grad_ptr, int in_dim, int out_dim, ffStream_t stream); diff --git a/include/flexflow/utils/cuda_helper.h b/include/flexflow/utils/cuda_helper.h index eb9b1c542..26db225f0 100644 --- a/include/flexflow/utils/cuda_helper.h +++ b/include/flexflow/utils/cuda_helper.h @@ -184,6 +184,19 @@ cudnnStatus_t Legion::Domain domain, DataType data_type = DT_FLOAT); +template +struct cublasAlphaBetaType { + using type = float; // default +}; +template <> +struct cublasAlphaBetaType { + using type = half; +}; +template <> +struct cublasAlphaBetaType<__nv_bfloat16> { + using type = __nv_bfloat16; +}; + cudaDataType_t ff_to_cuda_datatype(DataType type); cudnnDataType_t ff_to_cudnn_datatype(DataType type); #ifdef FF_USE_NCCL diff --git a/include/flexflow/utils/file_loader.h b/include/flexflow/utils/file_loader.h index 8ad0f1d14..44d018ff6 100644 --- a/include/flexflow/utils/file_loader.h +++ b/include/flexflow/utils/file_loader.h @@ -32,7 +32,7 @@ class FileDataLoader { size_t _hidden_dim, size_t _qkv_inner_dim, int _tensor_parallelism_degree, - bool _use_full_precision); + DataType _data_type); BatchConfig::TokenId *generate_requests(int num, int length); @@ -44,7 +44,7 @@ class FileDataLoader { size_t num_replicas, DT *weight, Domain weight_domain); - + void load_quantization_weight(FFModel *ff, Layer *l, int weight_idx, @@ -64,7 +64,7 @@ class FileDataLoader { Legion::Runtime *runtime); void load_positions(FFModel *ff, - Tensor pt, + Tensor pt, ParallelTensor position_pt, int max_seq_length, int offset); @@ -74,7 +74,7 @@ class FileDataLoader { size_t hidden_dim, qkv_inner_dim; std::string prompts_filepath; std::string weights_folder; - bool use_full_precision; + DataType data_type; }; struct WeightLoadTaskArgs { diff --git a/inference/flexllm/peft_train.cc b/inference/flexllm/peft_train.cc index 88dd01cec..3f2aff733 100644 --- a/inference/flexllm/peft_train.cc +++ b/inference/flexllm/peft_train.cc @@ -43,7 +43,9 @@ void parse_input_args(char **argv, FilePaths &paths, std::string &llm_model_name, std::string &peft_model_name, + DataType &data_type, bool &use_full_precision, + bool &use_bf16_precision, bool &verbose, bool &do_sample, bool &enable_peft, @@ -102,6 +104,12 @@ void parse_input_args(char **argv, } if (!strcmp(argv[i], "--use-full-precision")) { use_full_precision = true; + data_type = DT_FLOAT; + continue; + } + if (!strcmp(argv[i], "--use-bf16-precision")) { + use_bf16_precision = true; + data_type = DT_BFLOAT16; continue; } if (!strcmp(argv[i], "--warmup")) { @@ -274,7 +282,9 @@ void FlexFlow::top_level_task(Task const *task, } FilePaths file_paths; std::string llm_model_name, peft_model_name; + DataType data_type = DT_HALF; 
bool use_full_precision = false; + bool use_bf16_precision = false; bool verbose = false; bool do_sample = false; bool enable_peft = false; @@ -297,7 +307,9 @@ void FlexFlow::top_level_task(Task const *task, file_paths, llm_model_name, peft_model_name, + data_type, use_full_precision, + use_bf16_precision, verbose, do_sample, enable_peft, @@ -332,11 +344,13 @@ void FlexFlow::top_level_task(Task const *task, {file_paths.cache_folder_path, "configs", llm_model_name, "config.json"}); std::string tokenizer_filepath = join_path({file_paths.cache_folder_path, "tokenizers", llm_model_name}); + // bfloat16 shares same weight file with float32 std::string weights_filepath = join_path({file_paths.cache_folder_path, "weights", llm_model_name, - use_full_precision ? "full-precision" : "half-precision"}); + use_full_precision ? "full-precision" + : "half-precision"}); std::ifstream config_file_handle(config_filepath); if (!config_file_handle.good()) { std::cout << "Model config file " << config_filepath << " not found." @@ -400,7 +414,7 @@ void FlexFlow::top_level_task(Task const *task, optim_config, true /*init_lora_weights*/, llm_model_name, - use_full_precision ? "fp32" : "fp16", + use_full_precision ? "fp32" : (use_bf16_precision ? "bf16" : "fp16"), rank, (float)rank, 0.0f, @@ -428,33 +442,27 @@ void FlexFlow::top_level_task(Task const *task, weights_filepath, INC_DECODING_MODE, generationConfig, - use_full_precision); + data_type); } else if (model_type == ModelType::OPT) { - OPT::create_opt_model(model, - config_filepath, - weights_filepath, - INC_DECODING_MODE, - use_full_precision); + OPT::create_opt_model( + model, config_filepath, weights_filepath, INC_DECODING_MODE, data_type); } else if (model_type == ModelType::FALCON) { - FALCON::create_falcon_model(model, - config_filepath, - weights_filepath, - INC_DECODING_MODE, - use_full_precision); + FALCON::create_falcon_model( + model, config_filepath, weights_filepath, INC_DECODING_MODE, data_type); } else if (model_type == ModelType::STARCODER) { STARCODER::create_starcoder_model(model, config_filepath, weights_filepath, INC_DECODING_MODE, generationConfig, - use_full_precision); + data_type); } else if (model_type == ModelType::MPT) { MPT::create_mpt_model(model, config_filepath, weights_filepath, INC_DECODING_MODE, generationConfig, - use_full_precision); + data_type); } else { assert(false && "unknow model type"); } diff --git a/inference/incr_decoding/incr_decoding.cc b/inference/incr_decoding/incr_decoding.cc index cd5c68253..27fedd2b6 100644 --- a/inference/incr_decoding/incr_decoding.cc +++ b/inference/incr_decoding/incr_decoding.cc @@ -41,7 +41,9 @@ void parse_input_args(char **argv, int argc, FilePaths &paths, std::string &llm_model_name, + DataType &data_type, bool &use_full_precision, + bool &use_bf16, bool &verbose, bool &do_sample, float &temperature, @@ -84,6 +86,12 @@ void parse_input_args(char **argv, } if (!strcmp(argv[i], "--use-full-precision")) { use_full_precision = true; + data_type = DT_FLOAT; + continue; + } + if (!strcmp(argv[i], "--use-bf16")) { + use_bf16 = true; + data_type = DT_BFLOAT16; continue; } // verbose logging to stdout @@ -261,7 +269,9 @@ void FlexFlow::top_level_task(Task const *task, } FilePaths file_paths; std::string llm_model_name; - bool use_full_precision = false; + DataType data_type = DT_HALF; // default to half precision + bool use_full_precision = false; // if true, use full precision + bool use_bf16_precision = false; // if true, use bfloat16 precision bool verbose = false; bool do_sample = 
false; float temperature = 0.0f; @@ -281,7 +291,9 @@ void FlexFlow::top_level_task(Task const *task, argc, file_paths, llm_model_name, + data_type, use_full_precision, + use_bf16_precision, verbose, do_sample, temperature, @@ -306,11 +318,13 @@ void FlexFlow::top_level_task(Task const *task, {file_paths.cache_folder_path, "configs", llm_model_name, "config.json"}); std::string tokenizer_filepath = join_path({file_paths.cache_folder_path, "tokenizers", llm_model_name}); - std::string weights_filepath = - join_path({file_paths.cache_folder_path, - "weights", - llm_model_name, - use_full_precision ? "full-precision" : "half-precision"}); + // same weights file for fp32 and bf16 precision + std::string weights_filepath = join_path( + {file_paths.cache_folder_path, + "weights", + llm_model_name, + use_full_precision || use_bf16_precision ? "full-precision" + : "half-precision"}); std::ifstream config_file_handle(config_filepath); if (!config_file_handle.good()) { std::cout << "Model config file " << config_filepath << " not found." @@ -382,33 +396,27 @@ void FlexFlow::top_level_task(Task const *task, weights_filepath, INC_DECODING_MODE, generationConfig, - use_full_precision); + data_type); } else if (model_type == ModelType::OPT) { - OPT::create_opt_model(model, - config_filepath, - weights_filepath, - INC_DECODING_MODE, - use_full_precision); + OPT::create_opt_model( + model, config_filepath, weights_filepath, INC_DECODING_MODE, data_type); } else if (model_type == ModelType::FALCON) { - FALCON::create_falcon_model(model, - config_filepath, - weights_filepath, - INC_DECODING_MODE, - use_full_precision); + FALCON::create_falcon_model( + model, config_filepath, weights_filepath, INC_DECODING_MODE, data_type); } else if (model_type == ModelType::STARCODER) { STARCODER::create_starcoder_model(model, config_filepath, weights_filepath, INC_DECODING_MODE, generationConfig, - use_full_precision); + data_type); } else if (model_type == ModelType::MPT) { MPT::create_mpt_model(model, config_filepath, weights_filepath, INC_DECODING_MODE, generationConfig, - use_full_precision); + data_type); } else { assert(false && "unknow model type"); } diff --git a/inference/models/falcon.cc b/inference/models/falcon.cc index 3a345b15c..6fad88364 100644 --- a/inference/models/falcon.cc +++ b/inference/models/falcon.cc @@ -24,7 +24,7 @@ void FALCON::create_falcon_model(FFModel &ff, std::string const &model_config_file_path, std::string const &weight_file_path, InferenceMode mode, - bool use_full_precision) { + DataType data_type) { FalconConfig falcon_config(model_config_file_path); falcon_config.print(); @@ -56,7 +56,7 @@ void FALCON::create_falcon_model(FFModel &ff, falcon_config.vocab_size, falcon_config.hidden_size, AGGR_MODE_NONE, - use_full_precision ? 
DT_FLOAT : DT_HALF, + data_type, // fp32, bf16, or fp16 NULL, embed_init, "word_embeddings"); @@ -285,7 +285,7 @@ void FALCON::create_falcon_model(FFModel &ff, falcon_config.hidden_size, head_dim, ff.config.tensor_parallelism_degree, - use_full_precision); + data_type); InferenceManager *im = InferenceManager::get_inference_manager(); im->register_model_weights_loader(&ff, fileloader); diff --git a/inference/models/falcon.h b/inference/models/falcon.h index 996cfb838..f1a96a6ce 100644 --- a/inference/models/falcon.h +++ b/inference/models/falcon.h @@ -113,7 +113,7 @@ class FALCON { std::string const &model_config_file_path, std::string const &weight_file_path, InferenceMode mode, - bool use_full_precision = false); + DataType data_type); }; }; // namespace FlexFlow diff --git a/inference/models/llama.cc b/inference/models/llama.cc index 37d6f444b..39a5c629a 100644 --- a/inference/models/llama.cc +++ b/inference/models/llama.cc @@ -25,7 +25,7 @@ void LLAMA::create_llama_model(FFModel &ff, std::string const &weight_file_path, InferenceMode mode, GenerationConfig generation_config, - bool use_full_precision) { + DataType data_type) { // do not apply cpu offload in beam search model. LLAMAConfig llama_config(model_config_file_path); llama_config.print(); @@ -59,7 +59,7 @@ void LLAMA::create_llama_model(FFModel &ff, llama_config.vocab_size, llama_config.hidden_size, AGGR_MODE_NONE, - use_full_precision ? DT_FLOAT : DT_HALF, + data_type, NULL, embed_init, "embed_tokens"); @@ -313,7 +313,7 @@ void LLAMA::create_llama_model(FFModel &ff, llama_config.hidden_size, head_dim, ff.config.tensor_parallelism_degree, - use_full_precision); + data_type); InferenceManager *im = InferenceManager::get_inference_manager(); im->register_model_weights_loader(&ff, fileloader); diff --git a/inference/models/llama.h b/inference/models/llama.h index e74f8d52b..0246a4216 100644 --- a/inference/models/llama.h +++ b/inference/models/llama.h @@ -109,7 +109,7 @@ class LLAMA { std::string const &weight_file_path, InferenceMode mode, GenerationConfig generation_config, - bool use_full_precision = false); + DataType data_type); }; }; // namespace FlexFlow diff --git a/inference/models/mpt.cc b/inference/models/mpt.cc index db7c2f6f3..026a7b07d 100644 --- a/inference/models/mpt.cc +++ b/inference/models/mpt.cc @@ -25,7 +25,7 @@ void MPT::create_mpt_model(FFModel &ff, std::string const &weight_file_path, InferenceMode mode, GenerationConfig generationConfig, - bool use_full_precision) { + DataType data_type) { MPTConfig mpt_config(model_config_file_path); mpt_config.print(); @@ -53,7 +53,7 @@ void MPT::create_mpt_model(FFModel &ff, mpt_config.vocab_size, mpt_config.hidden_size, AGGR_MODE_NONE, - use_full_precision ? 
DT_FLOAT : DT_HALF, + data_type, NULL, embed_init, "wte"); @@ -289,7 +289,7 @@ void MPT::create_mpt_model(FFModel &ff, mpt_config.hidden_size, mpt_config.hidden_size / mpt_config.n_heads, ff.config.tensor_parallelism_degree, - use_full_precision); + data_type); InferenceManager *im = InferenceManager::get_inference_manager(); im->register_model_weights_loader(&ff, fileloader); diff --git a/inference/models/mpt.h b/inference/models/mpt.h index a3b4c663f..2a9b986ab 100644 --- a/inference/models/mpt.h +++ b/inference/models/mpt.h @@ -74,7 +74,7 @@ class MPT { std::string const &weight_file_path, InferenceMode mode, GenerationConfig generationConfig, - bool use_full_precision = false); + DataType data_type); }; }; // namespace FlexFlow diff --git a/inference/models/opt.cc b/inference/models/opt.cc index da7bc6ab8..5751ee8bd 100644 --- a/inference/models/opt.cc +++ b/inference/models/opt.cc @@ -24,7 +24,7 @@ void OPT::create_opt_model(FFModel &ff, std::string const &model_config_file_path, std::string const &weight_file_path, InferenceMode mode, - bool use_full_precision) { + DataType data_type) { OPTConfig opt_config(model_config_file_path); opt_config.print(); @@ -57,7 +57,7 @@ void OPT::create_opt_model(FFModel &ff, opt_config.vocab_size, opt_config.word_embed_proj_dim, AGGR_MODE_NONE, - use_full_precision ? DT_FLOAT : DT_HALF, + data_type, NULL, embed_init, "embed_tokens"); @@ -67,7 +67,7 @@ void OPT::create_opt_model(FFModel &ff, opt_config.max_position_embeddings, opt_config.hidden_size, AGGR_MODE_NONE, - use_full_precision ? DT_FLOAT : DT_HALF, + data_type, NULL, embed_init, "embed_positions"); @@ -300,7 +300,8 @@ void OPT::create_opt_model(FFModel &ff, opt_config.hidden_size, opt_config.hidden_size / opt_config.num_attention_heads, ff.config.tensor_parallelism_degree, - use_full_precision); + data_type); + InferenceManager *im = InferenceManager::get_inference_manager(); im->register_model_weights_loader(&ff, fileloader); } diff --git a/inference/models/opt.h b/inference/models/opt.h index 73a76b917..7baf0164d 100644 --- a/inference/models/opt.h +++ b/inference/models/opt.h @@ -98,7 +98,7 @@ class OPT { std::string const &model_config_file_path, std::string const &weight_file_path, InferenceMode mode, - bool use_full_precision = false); + DataType data_type); }; }; // namespace FlexFlow diff --git a/inference/models/starcoder.cc b/inference/models/starcoder.cc index 7c24b9614..3b7f26a30 100644 --- a/inference/models/starcoder.cc +++ b/inference/models/starcoder.cc @@ -26,7 +26,7 @@ void STARCODER::create_starcoder_model( std::string const &weight_file_path, InferenceMode mode, GenerationConfig generationConfig, - bool use_full_precision) { + DataType data_type) { // do not apply cpu offload in beam search model. STARCODERConfig startcoder_config(model_config_file_path); startcoder_config.print(); @@ -70,7 +70,7 @@ void STARCODER::create_starcoder_model( startcoder_config.vocab_size, startcoder_config.hidden_size, AGGR_MODE_NONE, - use_full_precision ? DT_FLOAT : DT_HALF, + data_type, NULL, embed_init, "wte"); @@ -80,7 +80,7 @@ void STARCODER::create_starcoder_model( startcoder_config.max_position_embeddings, startcoder_config.hidden_size, AGGR_MODE_NONE, - use_full_precision ? 
DT_FLOAT : DT_HALF, + data_type, NULL, embed_init, "wpe"); @@ -272,7 +272,8 @@ void STARCODER::create_starcoder_model( startcoder_config.hidden_size, head_dim, ff.config.tensor_parallelism_degree, - use_full_precision); + data_type); + im->register_model_weights_loader(&ff, fileloader); } diff --git a/inference/models/starcoder.h b/inference/models/starcoder.h index 89897652f..cc9206c09 100644 --- a/inference/models/starcoder.h +++ b/inference/models/starcoder.h @@ -73,7 +73,7 @@ class STARCODER { std::string const &weight_file_path, InferenceMode mode, GenerationConfig generationConfig, - bool use_full_precision = false); + DataType data_type); }; }; // namespace FlexFlow diff --git a/inference/peft/peft.cc b/inference/peft/peft.cc index 1b3bbfc0d..42728b4b3 100644 --- a/inference/peft/peft.cc +++ b/inference/peft/peft.cc @@ -43,7 +43,9 @@ void parse_input_args(char **argv, FilePaths &paths, std::string &llm_model_name, std::string &peft_model_name, + DataType &data_type, bool &use_full_precision, + bool &use_bf16_precision, bool &verbose, bool &do_sample, bool &enable_peft, @@ -102,6 +104,12 @@ void parse_input_args(char **argv, } if (!strcmp(argv[i], "--use-full-precision")) { use_full_precision = true; + data_type = DT_FLOAT; + continue; + } + if (!strcmp(argv[i], "--use-bf16-precision")) { + use_bf16_precision = true; + data_type = DT_BFLOAT16; continue; } if (!strcmp(argv[i], "--warmup")) { @@ -199,7 +207,9 @@ void FlexFlow::top_level_task(Task const *task, } FilePaths file_paths; std::string llm_model_name, peft_model_name; + DataType data_type = DT_HALF; bool use_full_precision = false; + bool use_bf16_precision = false; bool verbose = false; bool do_sample = false; bool enable_peft = false; @@ -222,7 +232,9 @@ void FlexFlow::top_level_task(Task const *task, file_paths, llm_model_name, peft_model_name, + data_type, use_full_precision, + use_bf16_precision, verbose, do_sample, enable_peft, @@ -250,11 +262,13 @@ void FlexFlow::top_level_task(Task const *task, {file_paths.cache_folder_path, "configs", llm_model_name, "config.json"}); std::string tokenizer_filepath = join_path({file_paths.cache_folder_path, "tokenizers", llm_model_name}); + // bfloat16 shares same weight file with float32? std::string weights_filepath = join_path({file_paths.cache_folder_path, "weights", llm_model_name, - use_full_precision ? "full-precision" : "half-precision"}); + use_full_precision ? "full-precision" + : "half-precision"}); std::ifstream config_file_handle(config_filepath); if (!config_file_handle.good()) { std::cout << "Model config file " << config_filepath << " not found." @@ -340,7 +354,7 @@ void FlexFlow::top_level_task(Task const *task, optim_config, false /*init_lora_weights*/, llm_model_name, - use_full_precision ? "fp32" : "fp16"); + use_full_precision ? "fp32" : (use_bf16_precision ? 
"bf16" : "fp16")); GenerationConfig generationConfig(do_sample, temperature, topp); RequestManager *rm = RequestManager::get_request_manager(); @@ -367,33 +381,33 @@ void FlexFlow::top_level_task(Task const *task, weights_filepath, INC_DECODING_MODE, generationConfig, - use_full_precision); + data_type); } else if (model_type == ModelType::OPT) { OPT::create_opt_model(model, config_filepath, weights_filepath, INC_DECODING_MODE, - use_full_precision); + data_type); } else if (model_type == ModelType::FALCON) { FALCON::create_falcon_model(model, config_filepath, weights_filepath, INC_DECODING_MODE, - use_full_precision); + data_type); } else if (model_type == ModelType::STARCODER) { STARCODER::create_starcoder_model(model, config_filepath, weights_filepath, INC_DECODING_MODE, generationConfig, - use_full_precision); + data_type); } else if (model_type == ModelType::MPT) { MPT::create_mpt_model(model, config_filepath, weights_filepath, INC_DECODING_MODE, generationConfig, - use_full_precision); + data_type); } else { assert(false && "unknow model type"); } diff --git a/inference/python/chat.py b/inference/python/chat.py index 7c7454375..b96df5332 100644 --- a/inference/python/chat.py +++ b/inference/python/chat.py @@ -47,6 +47,7 @@ def get_configs(): "cache_path": os.environ.get("FF_CACHE_PATH", ""), "refresh_cache": False, "full_precision": False, + "bfloat16_precision": False, "max_requests_per_batch": 4, "max_seq_length": 2048, "max_tokens_per_batch": 256, @@ -67,7 +68,11 @@ def main(): # Create the FlexFlow LLM ff_data_type = ( - ff.DataType.DT_FLOAT if configs.full_precision else ff.DataType.DT_HALF + ff.DataType.DT_FLOAT + if configs.full_precision + else ff.DataType.DT_BFLOAT16 + if configs.bfloat16_precision + else ff.DataType.DT_HALF ) llm = ff.LLM( configs.llm_model, diff --git a/inference/spec_infer/spec_infer.cc b/inference/spec_infer/spec_infer.cc index 380722927..98efe5f15 100644 --- a/inference/spec_infer/spec_infer.cc +++ b/inference/spec_infer/spec_infer.cc @@ -59,7 +59,9 @@ void parse_input_args(char **argv, int argc, FilePaths &paths, ModelNames &model_names, + DataType &data_type, bool &use_full_precision, + bool &use_bf16_precision, bool &verbose, int &max_requests_per_batch, int &max_tokens_per_batch, @@ -101,6 +103,12 @@ void parse_input_args(char **argv, } if (!strcmp(argv[i], "--use-full-precision")) { use_full_precision = true; + data_type = DT_FLOAT; + continue; + } + if (!strcmp(argv[i], "--use-bf16-precision")) { + use_bf16_precision = true; + data_type = DT_BFLOAT16; continue; } // verbose logging to stdout @@ -145,7 +153,8 @@ void parse_input_args(char **argv, void get_model_meta(FilePaths &file_paths, ModelMeta &model_metadata, - bool use_full_precision) { + bool use_full_precision, + bool use_bf16_precision) { if (model_metadata.model_names.llm_model_name.empty() || model_metadata.model_names.ssm_model_names.size() == 0) { assert(false && "SpecInfer needs at least one LLM and one SSM for " @@ -160,6 +169,7 @@ void get_model_meta(FilePaths &file_paths, join_path({file_paths.cache_folder_path, "tokenizers", model_metadata.model_names.llm_model_name}); + // bfloat16 shares same weight file with float32 model_metadata.llm_weights_path = join_path({file_paths.cache_folder_path, "weights", @@ -222,11 +232,12 @@ void get_model_meta(FilePaths &file_paths, "config.json"}); std::string ssm_tokenizer_path = join_path({file_paths.cache_folder_path, "tokenizers", ssm_model_name}); + // bfloat16 shares same weight file with float32 std::string ssm_weights_path = 
join_path({file_paths.cache_folder_path, "weights", ssm_model_name, - use_full_precision ? "full-precision" : "half-precision"}); + use_full_precision || use_bf16_precision ? "full-precision" : "half-precision"}); std::ifstream ssm_config_file_handle(ssm_config_path); if (!ssm_config_file_handle.good()) { @@ -292,7 +303,9 @@ void FlexFlow::top_level_task(Task const *task, FFConfig ffconfig; FilePaths file_paths; ModelMeta model_metadata; + DataType data_type = DT_HALF; bool use_full_precision = false; + bool use_bf16_precision = false; bool verbose = false; int max_requests_per_batch = 16; int max_tokens_per_batch = 256; @@ -308,7 +321,9 @@ void FlexFlow::top_level_task(Task const *task, argc, file_paths, model_metadata.model_names, + data_type, use_full_precision, + use_bf16_precision, verbose, max_requests_per_batch, max_tokens_per_batch, @@ -316,7 +331,7 @@ void FlexFlow::top_level_task(Task const *task, expansion_degree, max_length); - get_model_meta(file_paths, model_metadata, use_full_precision); + get_model_meta(file_paths, model_metadata, use_full_precision, use_bf16_precision); assert(ffconfig.data_parallelism_degree * ffconfig.tensor_parallelism_degree * ffconfig.pipeline_parallelism_degree == @@ -356,26 +371,26 @@ void FlexFlow::top_level_task(Task const *task, model_metadata.llm_weights_path, TREE_VERIFY_MODE, generationConfig, - use_full_precision); + data_type); } else if (model_metadata.llm_model_type == ModelType::OPT) { OPT::create_opt_model(tree_model, model_metadata.llm_model_config_path, model_metadata.llm_weights_path, TREE_VERIFY_MODE, - use_full_precision); + data_type); } else if (model_metadata.llm_model_type == ModelType::FALCON) { FALCON::create_falcon_model(tree_model, model_metadata.llm_model_config_path, model_metadata.llm_weights_path, TREE_VERIFY_MODE, - use_full_precision); + data_type); } else if (model_metadata.llm_model_type == ModelType::MPT) { MPT::create_mpt_model(tree_model, model_metadata.llm_model_config_path, model_metadata.llm_weights_path, TREE_VERIFY_MODE, generationConfig, - use_full_precision); + data_type); } else { assert(false && "Invalid LLM model type passed (or no type was passed)."); } @@ -404,27 +419,27 @@ void FlexFlow::top_level_task(Task const *task, model_metadata.ssm_model_weights_paths[ssm_id], BEAM_SEARCH_MODE, generationConfig, - use_full_precision); + data_type); } else if (model_metadata.ssm_model_types[ssm_id] == ModelType::OPT) { OPT::create_opt_model(beam_model, model_metadata.ssm_model_config_paths[ssm_id], model_metadata.ssm_model_weights_paths[ssm_id], BEAM_SEARCH_MODE, - use_full_precision); + data_type); } else if (model_metadata.ssm_model_types[ssm_id] == ModelType::FALCON) { FALCON::create_falcon_model( beam_model, model_metadata.ssm_model_config_paths[ssm_id], model_metadata.ssm_model_weights_paths[ssm_id], BEAM_SEARCH_MODE, - use_full_precision); + data_type); } else if (model_metadata.ssm_model_types[ssm_id] == ModelType::MPT) { MPT::create_mpt_model(beam_model, model_metadata.ssm_model_config_paths[ssm_id], model_metadata.ssm_model_weights_paths[ssm_id], BEAM_SEARCH_MODE, generationConfig, - use_full_precision); + data_type); } else { assert(false && "Invalid SSM model type passed."); } diff --git a/inference/utils/download_hf_model.py b/inference/utils/download_hf_model.py index 7b4f4d6fb..0a2027169 100644 --- a/inference/utils/download_hf_model.py +++ b/inference/utils/download_hf_model.py @@ -30,6 +30,11 @@ def parse_args(): action="store_true", help="Only download the half precision version of the weights", 
) + group.add_argument( + "--bfloat16-precision-only", + action="store_true", + help="Only download the bfloat16 precision version of the weights", + ) args = parser.parse_args() return args @@ -39,8 +44,10 @@ def main(args): data_types = (ff.DataType.DT_FLOAT,) elif args.half_precision_only: data_types = (ff.DataType.DT_HALF,) + elif args.bfloat16_precision_only: + data_types = (ff.DataType.DT_BFLOAT16,) else: - data_types = (ff.DataType.DT_FLOAT, ff.DataType.DT_HALF) + data_types = (ff.DataType.DT_FLOAT, ff.DataType.DT_HALF, ff.DataType.DT_BFLOAT16) for model_name in args.model_names: for data_type in data_types: diff --git a/python/flexflow/core/flexflow_cffi.py b/python/flexflow/core/flexflow_cffi.py index 443cb9bfe..1361762d8 100644 --- a/python/flexflow/core/flexflow_cffi.py +++ b/python/flexflow/core/flexflow_cffi.py @@ -1759,7 +1759,7 @@ def __init__( hidden_dim, qkv_inner_dim, tensor_parallelism_degree, - use_full_precision, + data_type, ): c_weight_file_path = get_c_name(weight_file_path) self.handle = ffc().flexflow_file_data_loader_create( @@ -1769,7 +1769,7 @@ def __init__( hidden_dim, qkv_inner_dim, tensor_parallelism_degree, - use_full_precision, + data_type, ) self._handle = ffi.gc(self.handle, ffc().flexflow_file_data_loader_destroy) diff --git a/python/flexflow/serve/serve.py b/python/flexflow/serve/serve.py index ca467a8f6..52c9fe74d 100644 --- a/python/flexflow/serve/serve.py +++ b/python/flexflow/serve/serve.py @@ -90,7 +90,7 @@ def __init__( :param model_name: The name of the HuggingFace model to use. E.g. 'meta-llama/Llama-2-7b-hf' :type model_name: str - :param data_type: The data type to use for the tensors (e.g. DataType.DT_FLOAT for full precision, or DataType.DT_HALF for half precision), defaults to DataType.DT_HALF + :param data_type: The data type to use for the tensors (e.g. 
DataType.DT_FLOAT for full precision, DataType.DT_HALF for half precision, or DataType.DT_BFLOAT16 for bfloat16 precision), defaults to DataType.DT_HALF :type data_type: DataType, optional :param cache_path: Path to the folder (which will be created if it does not yet exist) to use for the FlexFlow weights/tokenizers cache, defaults to "~/.cache/flexflow" :type tokenizer_path: str, optional @@ -108,7 +108,7 @@ def __init__( self.config_class, ) = self.supported_models.get_ff_model_type(self.hf_config) self.data_type = data_type - assert self.data_type == DataType.DT_HALF or self.data_type == DataType.DT_FLOAT + assert self.data_type in (DataType.DT_HALF, DataType.DT_FLOAT, DataType.DT_BFLOAT16) self.cache_path = cache_path if len(cache_path) > 0 else "~/.cache/flexflow" self.refresh_cache = refresh_cache self.output_file = output_file @@ -251,6 +251,7 @@ def __get_resource_path( if self.data_type == DataType.DT_FLOAT else "half-precision" ), + # temp: bfloat16 precision shares the same folder as half precision ) elif resource_type == CachedResourceType.TOKENIZER: return os.path.join( @@ -323,7 +324,8 @@ def download_and_convert_llm_weights(model_name): torch_dtype=( torch.float32 if self.data_type == DataType.DT_FLOAT - else torch.float16 + else torch.float16 if self.data_type == DataType.DT_HALF + else torch.bfloat16 ), ) # Convert the model to FlexFlow format @@ -393,10 +395,11 @@ def download_peft_adapter_if_needed(self, hf_peft_model_id: str): def download_and_convert_peft_model(hf_peft_model_id: str): if ( self.data_type != DataType.DT_FLOAT - and self.data_type != DataType.DT_HALF + and self.data_type != DataType.DT_HALF + and self.data_type != DataType.DT_BFLOAT16 ): raise ValueError( - "data_type must be either DataType.DT_FLOAT or DataType.DT_HALF" + "data_type must be one of DataType.DT_FLOAT, DataType.DT_HALF, or DataType.DT_BFLOAT16" ) # Save peft config to file @@ -574,7 +577,7 @@ def compile( model_configs.hidden_size, model_configs.hidden_size // model_configs.num_attention_heads, self.ffconfig.tensor_parallelism_degree, - self.data_type == DataType.DT_FLOAT, + self.data_type, ) # Register weights file loader diff --git a/python/flexflow/torch/model.py b/python/flexflow/torch/model.py index df4042748..88287d2ad 100644 --- a/python/flexflow/torch/model.py +++ b/python/flexflow/torch/model.py @@ -171,7 +171,9 @@ def string_to_node_class(string): @staticmethod def torch_to_ff_dtype(torch_dtype): - if torch_dtype in (torch.float32, torch.float, "float32", "float"): + if torch_dtype in (torch.bfloat16, "bfloat16"): + return DataType.DT_BFLOAT16 + elif torch_dtype in (torch.float32, torch.float, "float32", "float"): return DataType.DT_FLOAT elif torch_dtype in (torch.float64, torch.double, "float64", "double"): return DataType.DT_DOUBLE @@ -184,7 +186,9 @@ def torch_to_ff_dtype(torch_dtype): @staticmethod def numpy_to_ff_dtype(numpy_dtype): - if numpy_dtype in (np.float32, np.float, "float32", "float"): + if str(numpy_dtype) == "bfloat16": # numpy has no built-in bfloat16 dtype + return DataType.DT_BFLOAT16 + elif numpy_dtype in (np.float32, np.float, "float32", "float"): return DataType.DT_FLOAT elif numpy_dtype in (np.float64, np.double, "float64", "double"): return DataType.DT_DOUBLE diff --git a/python/flexflow/type.py b/python/flexflow/type.py index c2eebe899..0060caa12 100644 --- a/python/flexflow/type.py +++ b/python/flexflow/type.py @@ -36,6 +36,7 @@ class DataType(Enum): DT_FLOAT = 44 DT_DOUBLE = 45 DT_NONE = 49 + DT_BFLOAT16 = 50 class LossType(Enum): @@ 
-204,6 +205,8 @@ def data_type_size(value: DataType): return 8 elif value == DataType.DT_HALF: return 2 + elif value == DataType.DT_BFLOAT16: + return 2 elif value == DataType.DT_FLOAT: return 4 elif value == DataType.DT_DOUBLE: diff --git a/src/c/flexflow_c.cc b/src/c/flexflow_c.cc index 23439b1fe..19a94b7e5 100644 --- a/src/c/flexflow_c.cc +++ b/src/c/flexflow_c.cc @@ -2836,7 +2836,7 @@ flexflow_file_data_loader_t int hidden_dim, int qkv_inner_dim, int tensor_parallelism_degree, - bool use_full_precision) { + DataType data_type) { assert(weight_file_path != nullptr && "Cannot convert nullptr char * to std::string"); std::string const weight_file_path_str(weight_file_path); @@ -2847,7 +2847,7 @@ flexflow_file_data_loader_t hidden_dim, qkv_inner_dim, tensor_parallelism_degree, - use_full_precision); + data_type); DEBUG_PRINT("[FileDataLoader] new %p", handle); return FFCObjectWrapper::wrap(handle); } diff --git a/src/ops/add_bias_residual_layer_norm.cu b/src/ops/add_bias_residual_layer_norm.cu index 16629f493..d3fd39a16 100644 --- a/src/ops/add_bias_residual_layer_norm.cu +++ b/src/ops/add_bias_residual_layer_norm.cu @@ -237,6 +237,13 @@ void AddBiasResidualLayerNorm::inference_kernel_wrapper( data_type_size(m->input_type[0]) * num_peft_tokens * in_dim, cudaMemcpyDeviceToDevice, stream)); + } else if (m->input_type[0] == DT_BFLOAT16) { + checkCUDA(cudaMemcpyAsync( + m->input_activation, + added_output.get_bfloat16_ptr() + first_token_offset * in_dim, + data_type_size(m->input_type[0]) * num_peft_tokens * in_dim, + cudaMemcpyDeviceToDevice, + stream)); } else { assert(false && "unsupport datatype in layernorm"); } @@ -271,6 +278,19 @@ void AddBiasResidualLayerNorm::inference_kernel_wrapper( m->elementwise_affine ? gamma.get_half_ptr() : nullptr, (m->elementwise_affine && m->use_bias) ? beta.get_half_ptr() : nullptr, stream); + } else if (m->input_type[0] == DT_BFLOAT16) { + AddBiasResidualLayerNorm::inference_kernel<__ff_bfloat16>( + m, + attn_bias_dim, + residual_volume, + input.get_bfloat16_ptr(), + attn_bias.get_bfloat16_ptr(), + residual.get_bfloat16_ptr(), + added_output.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + m->elementwise_affine ? gamma.get_bfloat16_ptr() : nullptr, + (m->elementwise_affine && m->use_bias) ? beta.get_bfloat16_ptr() : nullptr, + stream); } else { assert(false && "unsupport datatype in layernorm"); } @@ -719,6 +739,19 @@ void AddBiasResidualLayerNorm::backward_kernel_wrapper( (m->elementwise_affine && m->use_bias) ? beta_grad.get_half_ptr() : nullptr, stream); + } else if (m->output_type[0] == DT_BFLOAT16){ + AddBiasResidualLayerNorm::backward_kernel( + m, + output_grad.get_bfloat16_ptr(), + added_output.get_bfloat16_ptr(), + input_grad.get_bfloat16_ptr(), + residual_grad.get_bfloat16_ptr(), + attn_bias_grad.get_bfloat16_ptr(), + m->elementwise_affine ? gamma.get_bfloat16_ptr() : nullptr, + m->elementwise_affine ? gamma_grad.get_bfloat16_ptr() : nullptr, + (m->elementwise_affine && m->use_bias) ? beta_grad.get_bfloat16_ptr() + : nullptr, + stream); } else { assert(false && "Unsupported data type"); } @@ -793,6 +826,13 @@ void AddBiasResidualLayerNorm::peft_bwd_kernel_wrapper( residual_grad.get_half_ptr(), m->elementwise_affine ? gamma.get_half_ptr() : nullptr, stream); + } else if (m->output_type[0] == DT_BFLOAT16) { + peft_bwd_kernel(m, + output_grad.get_bfloat16_ptr(), + input_grad.get_bfloat16_ptr(), + residual_grad.get_bfloat16_ptr(), + m->elementwise_affine ? 
gamma.get_bfloat16_ptr() : nullptr, + stream); } else { assert(false && "Unsupported data type"); } diff --git a/src/ops/arg_topk.cu b/src/ops/arg_topk.cu index 5b7978812..83c05164a 100644 --- a/src/ops/arg_topk.cu +++ b/src/ops/arg_topk.cu @@ -87,8 +87,7 @@ struct StridedData { // A heap of Entry that can either work as a min-heap or as a max-heap. template - class Data, + template class Data, typename T> struct IndexedHeap { typedef typename Data::Entry Entry; @@ -195,8 +194,7 @@ struct IndexedHeap { template - class Data, + template class Data, typename T> __device__ IndexedHeap make_indexed_heap(typename Data::Entry *data) { @@ -520,6 +518,18 @@ void ArgTopK::forward_kernel_wrapper(ArgTopKMeta const *m, m->sorted, m->speculative_decoding ? bc : nullptr, stream); + } else if (input.data_type == DT_BFLOAT16) { + ArgTopK::forward_kernel(m, + input.get_bfloat16_ptr(), + m->speculative_decoding ? probs.get_float_ptr() + : nullptr, + indices.get_int32_ptr(), + batch_size, + length, + k, + m->sorted, + m->speculative_decoding ? bc : nullptr, + stream); } else { assert(false && "Unsupported data type"); } diff --git a/src/ops/argmax.cpp b/src/ops/argmax.cpp index 9c8ba7c5f..847a8efd1 100644 --- a/src/ops/argmax.cpp +++ b/src/ops/argmax.cpp @@ -524,6 +524,18 @@ void ArgMax::forward_kernel_wrapper(ArgMaxMeta const *m, batch_size, loss, stream); + } else if (input.data_type == DT_BFLOAT16) { + ArgMax::forward_kernel<__ff_bfloat16>(m, + bc, + input.get_bfloat16_ptr(), + indices.get_int32_ptr(), + m->probs, + m->beam_search ? parent.get_int32_ptr() + : nullptr, + length, + batch_size, + loss, + stream); } else { assert(false && "Unsupported data type"); } @@ -549,7 +561,7 @@ ArgMaxMeta::ArgMaxMeta(FFHandler handler, : OpMeta(handler, op) { DataType data_type = op->data_type; size_t prob_size = batch_size; - assert(data_type == DT_FLOAT || data_type == DT_HALF); + assert(data_type == DT_FLOAT || data_type == DT_HALF || data_type == DT_BFLOAT16); size_t total_size = prob_size * sizeof(float); gpu_mem_allocator.create_legion_instance( reserveInst, total_size, "ArgMaxMeta"); diff --git a/src/ops/argmax.cu b/src/ops/argmax.cu index 0dc714c26..ced3bbbe7 100644 --- a/src/ops/argmax.cu +++ b/src/ops/argmax.cu @@ -44,6 +44,10 @@ __device__ __forceinline__ float toFloat(T x) { return x; } else if constexpr (std::is_same::value) { return __half2float(x); + } else if constexpr (std::is_same::value) { + return __bfloat162float(x); + } else { + assert(false && "Unsupported data type"); } } @@ -233,6 +237,18 @@ void ArgMax::forward_kernel_wrapper(ArgMaxMeta const *m, batch_size, loss, stream); + } else if (input.data_type == DT_BFLOAT16) { + ArgMax::forward_kernel<__ff_bfloat16>(m, + bc, + input.get_bfloat16_ptr(), + indices.get_int32_ptr(), + m->beam_search ? m->probs : nullptr, + m->beam_search ? 
parent.get_int32_ptr() + : nullptr, + length, + batch_size, + loss, + stream); } else { assert(false && "Unsupported data type"); } @@ -263,7 +279,7 @@ ArgMaxMeta::ArgMaxMeta(FFHandler handler, // size_t d_offsets_size = batch_size; size_t prob_size = batch_size; - assert(data_type == DT_FLOAT || data_type == DT_HALF); + assert(data_type == DT_FLOAT || data_type == DT_HALF || data_type == DT_BFLOAT16); size_t total_size = prob_size * sizeof(float); gpu_mem_allocator.create_legion_instance( reserveInst, total_size, "ArgMaxMeta"); diff --git a/src/ops/beam_topk.cu b/src/ops/beam_topk.cu index 7bfb6510d..a490b6c1e 100644 --- a/src/ops/beam_topk.cu +++ b/src/ops/beam_topk.cu @@ -91,8 +91,7 @@ struct StridedData { // A heap of Entry that can either work as a min-heap or as a max-heap. template - class Data, + template class Data, typename T> struct IndexedHeap { typedef typename Data::Entry Entry; @@ -199,8 +198,7 @@ struct IndexedHeap { template - class Data, + template class Data, typename T> __device__ IndexedHeap make_indexed_heap(typename Data::Entry *data) { @@ -343,7 +341,11 @@ __device__ void mergeBeamShards(int num_shards, int const last_k = k - 1; for (int rank = 0; rank < last_k; rank++) { Entry const &max_element = max_heap.root(); - top_k_values[rank] = __half2float(max_element.value); + if constexpr (std::is_same::value) { + top_k_values[rank] = __bfloat162float(max_element.value); + } else { + top_k_values[rank] = __half2float(max_element.value); + } int shard_index = max_element.index; top_k_indices[rank] = entries[shard_index].index; top_k_parents[rank] = @@ -366,7 +368,11 @@ __device__ void mergeBeamShards(int num_shards, // rank == last_k. Entry const &max_element = max_heap.root(); - top_k_values[last_k] = __half2float(max_element.value); + if constexpr (std::is_same::value) { + top_k_values[last_k] = __bfloat162float(max_element.value); + } else { + top_k_values[last_k] = __half2float(max_element.value); + } int shard_index = max_element.index; top_k_indices[last_k] = entries[shard_index].index; top_k_parents[last_k] = @@ -379,9 +385,9 @@ template __global__ void mergeSubRequestsKernel(int64_t N, T const *X, T const *rstd, T *Y) { using T_ACC = T; - const int64_t i = blockIdx.x; + int64_t const i = blockIdx.x; for (int64_t j = threadIdx.x; j < N; j += blockDim.x) { - const int64_t index = i * N + j; + int64_t const index = i * N + j; Y[index] = static_cast(X[index]) * static_cast(rstd[i]); } } @@ -707,6 +713,19 @@ void BeamTopK::forward_kernel_wrapper(BeamTopKMeta const *m, length, sorted, stream); + } else if (input.data_type == DT_BFLOAT16) { + BeamTopK::forward_kernel(m, + bc, + input.get_bfloat16_ptr(), + output_ptr, + indices_ptr, + parent_ptr, + batch_size, + length, + sorted, + stream); + } else { + assert(false && "Unsupported data type"); } if (m->profiling) { @@ -748,6 +767,9 @@ BeamTopKMeta::BeamTopKMeta(FFHandler handler, acc_probs = gpu_mem_allocator.allocate_instance(acc_probs_size); } else if (data_type == DT_HALF) { acc_probs = gpu_mem_allocator.allocate_instance(acc_probs_size); + } else if (data_type == DT_BFLOAT16) { + acc_probs = + gpu_mem_allocator.allocate_instance<__ff_bfloat16>(acc_probs_size); } else { assert(false); } diff --git a/src/ops/inc_multihead_self_attention.cpp b/src/ops/inc_multihead_self_attention.cpp index 5951dc3e6..2dafcb1f1 100644 --- a/src/ops/inc_multihead_self_attention.cpp +++ b/src/ops/inc_multihead_self_attention.cpp @@ -1389,6 +1389,18 @@ void transposeAdd(float *out, out, in, width, height, alpha, beta); } +template <> +void 
transposeAdd<__ff_bfloat16>(__ff_bfloat16 *out, + __ff_bfloat16 const *in, + int width, + int height, + float alpha, + float beta, + hipStream_t stream) { + transposeAdd_float_kernel<<<4, 1024, 0, stream>>>( + out, in, width, height, __float2bfloat16(alpha), __float2bfloat16(beta)); +} + template <> void transposeAdd(half *out, half const *in, @@ -2306,6 +2318,37 @@ template void Kernels::IncMultiHeadAttention::run_batched_matmul( int batch_ratio_c, bool bwd); +template void Kernels::IncMultiHeadAttention::run_batched_matmul<__ff_bfloat16>( + IncMultiHeadSelfAttentionMeta const *meta, + hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, + int n, + int k, + void const *alpha, + __ff_bfloat16 const *A, + hipblasDatatype_t Atype, + int lda, + long long int strideA, + __ff_bfloat16 const *B, + hipblasDatatype_t Btype, + int ldb, + long long int strideB, + void const *beta, + __ff_bfloat16 *C, + hipblasDatatype_t Ctype, + int ldc, + long long int strideC, + int batchCount, + hipblasDatatype_t computeType, + hipblasGemmAlgo_t algo, + hipStream_t stream, + int batch_ratio_a, + int batch_ratio_b, + int batch_ratio_c, + bool bwd); + template void Kernels::IncMultiHeadAttention::compute_attention_kernel_generation( IncMultiHeadSelfAttentionMeta const *m, @@ -2313,6 +2356,13 @@ template void float *output_ptr, hipStream_t stream); +template void + Kernels::IncMultiHeadAttention::compute_attention_kernel_generation<__ff_bfloat16>( + IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + __ff_bfloat16 *output_ptr, + hipStream_t stream); + template void Kernels::IncMultiHeadAttention::compute_attention_kernel_generation( IncMultiHeadSelfAttentionMeta const *m, @@ -2320,6 +2370,13 @@ template void half *output_ptr, hipStream_t stream); +template void + Kernels::IncMultiHeadAttention::compute_attention_kernel_generation<__ff_bfloat16>( + IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + __ff_bfloat16 *output_ptr, + hipStream_t stream); + template void Kernels::IncMultiHeadAttention::apply_scaling_and_rotary( IncMultiHeadSelfAttentionMeta const *m, BatchConfig const *bc, @@ -2334,4 +2391,11 @@ template void Kernels::IncMultiHeadAttention::apply_scaling_and_rotary( half *output_ptr, hipStream_t stream); +template void Kernels::IncMultiHeadAttention::apply_scaling_and_rotary<__ff_bfloat16>( + IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + int shard_id, + __ff_bfloat16 *output_ptr, + hipStream_t stream); + }; // namespace FlexFlow diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index 343f6becd..14d9556d3 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -2174,17 +2174,15 @@ void IncMultiHeadSelfAttention::inference_kernel_wrapper( output.get_half_ptr(), inf_stream, peft_stream); - } - // else if (input.data_type == DT_BFLOAT16) { - // Kernels::IncMultiHeadAttention::inference_kernel(m, - // bc, - // shard_id, - // input.get_bfloat16_ptr(), - // output.get_bfloat16_ptr(), - // inf_stream, - // peft_stream); - // } - else { + } else if (input.data_type == DT_BFLOAT16) { + Kernels::IncMultiHeadAttention::inference_kernel(m, + bc, + shard_id, + input.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + inf_stream, + peft_stream); + } else { assert(false && "Unspported data type"); } @@ -2227,9 +2225,19 @@ void IncMultiHeadSelfAttention::peft_bwd_kernel_wrapper( input_grad.get_half_ptr(), output_grad.get_half_ptr(), stream); + } 
else if (input_grad.data_type == DT_BFLOAT16) { + assert(!m->offload); + Kernels::IncMultiHeadAttention::flash_peft_bwd_kernel( + m, + bc, + shard_id, + input_grad.get_bfloat16_ptr(), + output_grad.get_bfloat16_ptr(), + stream); } else { assert(false && "Unspported data type"); } + if (m->profiling) { cudaEventRecord(t_end, stream); checkCUDA(cudaEventSynchronize(t_end)); @@ -2344,8 +2352,9 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( if (num_q_heads > num_kv_heads && (infer_mode == BEAM_SEARCH_MODE)) { assert(num_q_heads % num_kv_heads == 0 && "num_q_heads must be divisible by num_kv_heads"); - assert(attn->data_type == DT_FLOAT || - attn->data_type == DT_HALF && "Unsupported data type"); + assert((attn->data_type == DT_FLOAT || attn->data_type == DT_HALF || + attn->data_type == DT_BFLOAT16) && + "Unsupported data type"); gqa_ptr_array_size = num_q_heads * sizeof(void *); inf_instance_size += 3 * gqa_ptr_array_size; // fwd } @@ -2624,6 +2633,37 @@ template void Kernels::IncMultiHeadAttention::run_batched_matmul( int batch_ratio_c, bool bwd); +template void Kernels::IncMultiHeadAttention::run_batched_matmul<__ff_bfloat16>( + IncMultiHeadSelfAttentionMeta const *meta, + cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + void const *alpha, + __ff_bfloat16 const *A, + cudaDataType Atype, + int lda, + long long int strideA, + __ff_bfloat16 const *B, + cudaDataType Btype, + int ldb, + long long int strideB, + void const *beta, + __ff_bfloat16 *C, + cudaDataType Ctype, + int ldc, + long long int strideC, + int batchCount, + cudaDataType computeType, + cublasGemmAlgo_t algo, + cudaStream_t stream, + int batch_ratio_a, + int batch_ratio_b, + int batch_ratio_c, + bool bwd); + template void Kernels::IncMultiHeadAttention::apply_scaling_and_rotary( IncMultiHeadSelfAttentionMeta const *m, BatchConfig const *bc, @@ -2638,6 +2678,13 @@ template void Kernels::IncMultiHeadAttention::apply_scaling_and_rotary( half *output_ptr, cudaStream_t inf_stream); +template void Kernels::IncMultiHeadAttention::apply_scaling_and_rotary<__ff_bfloat16>( + IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + int shard_id, + __ff_bfloat16 *output_ptr, + cudaStream_t inf_stream); + template __global__ void Kernels::IncMultiHeadAttention::apply_position_bias_qkprd( float *input_ptr, @@ -2655,4 +2702,13 @@ template __global__ void int num_heads, int global_num_q_heads, int shard_id); + +template __global__ void + Kernels::IncMultiHeadAttention::apply_position_bias_qkprd<__ff_bfloat16>( + __ff_bfloat16 *input_ptr, + int num_tokens, + int num_total_tokens, + int num_heads, + int global_num_q_heads, + int shard_id); }; // namespace FlexFlow diff --git a/src/ops/kernels/decompress_kernels.cpp b/src/ops/kernels/decompress_kernels.cpp index 22bf93d44..19bdcde41 100644 --- a/src/ops/kernels/decompress_kernels.cpp +++ b/src/ops/kernels/decompress_kernels.cpp @@ -54,10 +54,20 @@ template __global__ void decompress_int4_general_weights( char const *input_weight_ptr, float *weight_ptr, int in_dim, int valueSize); template __global__ void decompress_int4_general_weights( char const *input_weight_ptr, half *weight_ptr, int in_dim, int valueSize); +template __global__ void + decompress_int4_general_weights<__ff_bfloat16>(char const *input_weight_ptr, + __ff_bfloat16 *weight_ptr, + int in_dim, + int valueSize); template __global__ void decompress_int8_general_weights( char const *input_weight_ptr, float *weight_ptr, int in_dim, int valueSize); 
template __global__ void decompress_int8_general_weights( char const *input_weight_ptr, half *weight_ptr, int in_dim, int valueSize); +template __global__ void + decompress_int8_general_weights<__ff_bfloat16>(char const *input_weight_ptr, + __ff_bfloat16 *weight_ptr, + int in_dim, + int valueSize); template __global__ void decompress_int4_attention_weights(char *input_weight_ptr, float *weight_ptr, diff --git a/src/ops/kernels/decompress_kernels.cu b/src/ops/kernels/decompress_kernels.cu index 2e02ce1ee..574ef9175 100644 --- a/src/ops/kernels/decompress_kernels.cu +++ b/src/ops/kernels/decompress_kernels.cu @@ -209,10 +209,14 @@ template __global__ void decompress_int4_general_weights( char const *input_weight_ptr, float *weight_ptr, int in_dim, int valueSize); template __global__ void decompress_int4_general_weights( char const *input_weight_ptr, half *weight_ptr, int in_dim, int valueSize); +template __global__ void decompress_int4_general_weights<__ff_bfloat16>( + char const *input_weight_ptr, __ff_bfloat16 *weight_ptr, int in_dim, int valueSize); template __global__ void decompress_int8_general_weights( char const *input_weight_ptr, float *weight_ptr, int in_dim, int valueSize); template __global__ void decompress_int8_general_weights( char const *input_weight_ptr, half *weight_ptr, int in_dim, int valueSize); +template __global__ void decompress_int8_general_weights<__ff_bfloat16>( + char const *input_weight_ptr, __ff_bfloat16 *weight_ptr, int in_dim, int valueSize); template __global__ void decompress_int4_attention_weights(char *input_weight_ptr, float *weight_ptr, diff --git a/src/ops/kernels/element_binary_kernels.cu b/src/ops/kernels/element_binary_kernels.cu index 91475c6a8..4cf44b6c7 100644 --- a/src/ops/kernels/element_binary_kernels.cu +++ b/src/ops/kernels/element_binary_kernels.cu @@ -106,6 +106,12 @@ void forward_kernel_wrapper(ElementBinaryMeta const *m, in2.get_float_ptr(), out.get_float_ptr(), stream); + } else if (out.data_type == DT_BFLOAT16) { + Internal::forward_kernel(m, + in1.get_bfloat16_ptr(), + in2.get_bfloat16_ptr(), + out.get_bfloat16_ptr(), + stream); } else { assert(false && "Unsupported data type"); } diff --git a/src/ops/kernels/embedding_kernels.cu b/src/ops/kernels/embedding_kernels.cu index 60e6da8d8..86d299896 100644 --- a/src/ops/kernels/embedding_kernels.cu +++ b/src/ops/kernels/embedding_kernels.cu @@ -50,6 +50,16 @@ void forward_kernel_wrapper(EmbeddingMeta const *m, m->aggr, output.domain.get_volume(), stream); + } else if (weight.data_type == DT_BFLOAT16) { + Internal::forward_kernel(input.get_int32_ptr(), + output.get_bfloat16_ptr(), + weight.get_bfloat16_ptr(), + in_dim, + out_dim, + batch_size, + m->aggr, + output.domain.get_volume(), + stream); } else if (weight.data_type == DT_FLOAT) { Internal::forward_kernel(input.get_int32_ptr(), output.get_float_ptr(), @@ -84,10 +94,10 @@ void forward_kernel_wrapper(EmbeddingMeta const *m, m->aggr, output.domain.get_volume(), stream); - } else if (weight.data_type == DT_FLOAT) { + } else if (weight.data_type == DT_BFLOAT16) { Internal::forward_kernel(input.get_int64_ptr(), - output.get_float_ptr(), - weight.get_float_ptr(), + output.get_bfloat16_ptr(), + weight.get_bfloat16_ptr(), in_dim, out_dim, batch_size, @@ -110,14 +120,30 @@ void forward_kernel_wrapper(EmbeddingMeta const *m, } else { assert(false && "Unsupported DataType in Embedding"); } - if (m->profiling) { - checkCUDA(cudaDeviceSynchronize()); - // print_tensor(input_ptr, input_domain.get_volume(), - // "[Embedding:forward:input]"); 
print_tensor(kernel_ptr, - // kernel_domain.get_volume(), "[Embedding:forward:weight]"); - // print_tensor(output_ptr, output_domain.get_volume(), - // "[Embedding:forward:output]"); + // if (m->profiling) { + checkCUDA(cudaDeviceSynchronize()); + print_tensor( + input.get_int32_ptr(), 32, "[Embedding:forward:input]"); + + if (weight.data_type == DT_HALF) { + print_tensor( + weight.get_half_ptr(), 32, "[Embedding:forward:weight]"); + } else if (weight.data_type == DT_FLOAT) { + print_tensor( + weight.get_float_ptr(), 32, "[Embedding:forward:weight]"); + } else { + assert(false && "Unsupported DataType in Embedding"); + } + if (output.data_type == DT_FLOAT) { + print_tensor( + output.get_float_ptr(), 32, "[Embedding:forward:output]"); + } else if (output.data_type == DT_HALF) { + print_tensor( + output.get_half_ptr(), 32, "[Embedding:forward:output]"); + } else { + assert(false && "Unsupported DataType in Embedding"); } + // } } /*static*/ diff --git a/src/ops/kernels/linear_kernels.cpp b/src/ops/kernels/linear_kernels.cpp index 3de30070f..b8d980a5c 100644 --- a/src/ops/kernels/linear_kernels.cpp +++ b/src/ops/kernels/linear_kernels.cpp @@ -218,6 +218,16 @@ void inference_kernel_wrapper(LinearMeta *m, Internal::store_peft_activations( m, bc, out_dim, static_cast(output_ptr), stream); } + } else if (m->input_type[0] == DT_BFLOAT16) { + Internal::inference_kernel(m, + input_ptr, + output_ptr, + weight_ptr, + bias_ptr, + in_dim, + out_dim, + batch_size, + stream); } if (m->profiling) { @@ -264,6 +274,15 @@ void peft_bwd_kernel_wrapper(LinearMeta const *m, in_dim, out_dim, stream); + } else if (m->input_type[0] == DT_BFLOAT16) { + Internal::peft_bwd_kernel(m, + bc, + input_grad_ptr, + output_grad_ptr, + weight_ptr, + in_dim, + out_dim, + stream); } if (m->profiling) { @@ -351,7 +370,6 @@ void inference_kernel(LinearMeta const *m, in_dim, in_dim * out_dim); } - } else { checkCUDA(hipMemcpyAsync(m->weight_ptr, weight_ptr, diff --git a/src/ops/kernels/linear_kernels.cu b/src/ops/kernels/linear_kernels.cu index f40c31433..f14c6af96 100644 --- a/src/ops/kernels/linear_kernels.cu +++ b/src/ops/kernels/linear_kernels.cu @@ -83,6 +83,14 @@ LinearMeta::LinearMeta(FFHandler handler, min(CUDA_NUM_THREADS, parallelism), 0, stream>>>((half *)one_ptr, batch_size); + } else if (data_type == DT_BFLOAT16) { + Kernels::Linear::Internal:: + build_one_ptr<<>>((__ff_bfloat16 *)one_ptr, batch_size); + } else { + assert(false && "Unsupported data type"); } } else { one_ptr = nullptr; @@ -219,6 +227,23 @@ void inference_kernel_wrapper(LinearMeta *m, Internal::store_peft_activations( m, bc, out_dim, static_cast(output_ptr), stream); } + } else if (m->input_type[0] == DT_BFLOAT16) { + Internal::inference_kernel(m, + input_ptr, + output_ptr, + weight_ptr, + bias_ptr, + in_dim, + out_dim, + batch_size, + stream); + if ((m->activation == AC_MODE_RELU || m->activation == AC_MODE_SIGMOID) && + bc->num_finetuning_fwd_requests() > 0) { + Internal::store_peft_activations<__ff_bfloat16>( + m, bc, out_dim, static_cast<__ff_bfloat16 *>(output_ptr), stream); + } + } else { + assert(false && "Unsupported data type"); } if (m->profiling) { @@ -265,6 +290,18 @@ void peft_bwd_kernel_wrapper(LinearMeta const *m, in_dim, out_dim, stream); + } else if (m->input_type[0] == DT_BFLOAT16) { + // cublas scale type: https://docs.nvidia.com/cuda/cublas/index.html?highlight=cublasGemmEx#cublasgemmex + Internal::peft_bwd_kernel(m, + bc, + input_grad_ptr, + output_grad_ptr, + weight_ptr, + in_dim, + out_dim, + stream); + } else { + assert(false && 
"Unsupported data type"); } if (m->profiling) { @@ -347,7 +384,6 @@ void inference_kernel(LinearMeta const *m, in_dim, in_dim * out_dim); } - } else { cudaMemcpyAsync(m->weight_ptr, weight_ptr, @@ -358,6 +394,8 @@ void inference_kernel(LinearMeta const *m, } checkCUDA(cublasSetStream(m->handle.blas, stream)); checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); + // CUDA_R_32F for bf16: + // https://docs.nvidia.com/cuda/cublas/index.html?highlight=cublasGemmEx#cublasgemmex DT alpha = 1.0f, beta = 0.0f; cudaDataType_t input_type = ff_to_cuda_datatype(m->input_type[0]); cudaDataType_t weight_type = m->offload @@ -365,7 +403,8 @@ void inference_kernel(LinearMeta const *m, : ff_to_cuda_datatype(m->weight_type[0]); cudaDataType_t output_type = ff_to_cuda_datatype(m->output_type[0]); assert(input_type == weight_type && weight_type == output_type); - cudaDataType_t compute_type = output_type; + cudaDataType_t compute_type = + output_type == CUDA_R_16BF ? CUDA_R_32F : output_type; checkCUDA(cublasGemmEx(m->handle.blas, CUBLAS_OP_T, CUBLAS_OP_N, @@ -492,7 +531,7 @@ void peft_bwd_kernel(LinearMeta const *m, input_grad_ptr = static_cast
<DT *>(input_grad_ptr); output_grad_ptr = static_cast<DT const *>
(output_grad_ptr); - cudaDataType_t compute_type = output_type; + cudaDataType_t compute_type = output_type == CUDA_R_16BF ? CUDA_R_32F : output_type; int output_size = out_dim * num_peft_tokens; if (m->activation == AC_MODE_RELU) { relu_backward_kernel(m->output_type[0], diff --git a/src/ops/kernels/lora_linear_kernels.cpp b/src/ops/kernels/lora_linear_kernels.cpp index 40f836a63..0fb3c7ec6 100644 --- a/src/ops/kernels/lora_linear_kernels.cpp +++ b/src/ops/kernels/lora_linear_kernels.cpp @@ -66,23 +66,30 @@ void inference_kernel_wrapper(LoraLinearMeta *m, checkCUDA(hipEventRecord(t_start, stream)); } if (m->input_type[0] == DT_FLOAT) { - Internal::inference_kernel(m, - bc, - input.get_float_ptr(), - output.get_float_ptr(), - in_dim, - out_dim, - stream); + Internal::inference_kernel(m, + bc, + input.get_float_ptr(), + output.get_float_ptr(), + in_dim, + out_dim, + stream); } else if (m->input_type[0] == DT_HALF) { - Internal::inference_kernel(m, - bc, - input.get_half_ptr(), - output.get_half_ptr(), - in_dim, - out_dim, - stream); + Internal::inference_kernel(m, + bc, + input.get_half_ptr(), + output.get_half_ptr(), + in_dim, + out_dim, + stream); + } else if (m->input_type[0] == DT_BFLOAT16) { + Internal::inference_kernel(m, + bc, + input.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + in_dim, + out_dim, + stream); } - if (m->profiling) { checkCUDA(hipEventRecord(t_end, stream)); checkCUDA(hipEventSynchronize(t_end)); diff --git a/src/ops/kernels/lora_linear_kernels.cu b/src/ops/kernels/lora_linear_kernels.cu index 7336a0947..661c86f28 100644 --- a/src/ops/kernels/lora_linear_kernels.cu +++ b/src/ops/kernels/lora_linear_kernels.cu @@ -65,21 +65,31 @@ void inference_kernel_wrapper(LoraLinearMeta *m, cudaEventRecord(t_start, stream); } if (m->input_type[0] == DT_FLOAT) { - Internal::inference_kernel(m, - bc, - input.get_float_ptr(), - output.get_float_ptr(), - in_dim, - out_dim, - stream); + Internal::inference_kernel(m, + bc, + input.get_float_ptr(), + output.get_float_ptr(), + in_dim, + out_dim, + stream); } else if (m->input_type[0] == DT_HALF) { - Internal::inference_kernel(m, - bc, - input.get_half_ptr(), - output.get_half_ptr(), - in_dim, - out_dim, - stream); + Internal::inference_kernel(m, + bc, + input.get_half_ptr(), + output.get_half_ptr(), + in_dim, + out_dim, + stream); + } else if (m->input_type[0] == DT_BFLOAT16) { + Internal::inference_kernel(m, + bc, + input.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + in_dim, + out_dim, + stream); + } else { + assert(false && "Unsupported data type"); } if (m->profiling) { @@ -117,27 +127,41 @@ void peft_bwd_kernel_wrapper(Context ctx, int in_dim = input_grad.domain.hi()[0] - input_grad.domain.lo()[0] + 1; int out_dim = output_grad.domain.hi()[0] - output_grad.domain.lo()[0] + 1; if (m->input_type[0] == DT_FLOAT) { - Internal::peft_bwd_kernel(ctx, - runtime, - m, - bc, - shard_id, - input_grad.get_float_ptr(), - output_grad.get_float_ptr(), - in_dim, - out_dim, - stream); + Internal::peft_bwd_kernel(ctx, + runtime, + m, + bc, + shard_id, + input_grad.get_float_ptr(), + output_grad.get_float_ptr(), + in_dim, + out_dim, + stream); } else if (m->input_type[0] == DT_HALF) { - Internal::peft_bwd_kernel(ctx, - runtime, - m, - bc, - shard_id, - input_grad.get_half_ptr(), - output_grad.get_half_ptr(), - in_dim, - out_dim, - stream); + Internal::peft_bwd_kernel(ctx, + runtime, + m, + bc, + shard_id, + input_grad.get_half_ptr(), + output_grad.get_half_ptr(), + in_dim, + out_dim, + stream); + } else if (m->input_type[0] == DT_BFLOAT16) { + 
Internal::peft_bwd_kernel( + ctx, + runtime, + m, + bc, + shard_id, + input_grad.get_bfloat16_ptr(), + output_grad.get_bfloat16_ptr(), + in_dim, + out_dim, + stream); + } else { + assert(false && "Unsupported data type"); } if (m->profiling) { @@ -170,11 +194,11 @@ bool lora_applies_to_this_layer(LoraLinearMeta *m, namespace Internal { -template +template void inference_kernel(LoraLinearMeta *m, BatchConfig const *bc, - DT const *input_ptr, - DT *output_ptr, + DATA_DT const *input_ptr, + DATA_DT *output_ptr, int in_dim, int out_dim, ffStream_t stream) { @@ -185,7 +209,8 @@ void inference_kernel(LoraLinearMeta *m, cudaDataType_t lr_actv_type = output_type; assert(input_type == output_type); cudaDataType_t weight_type = output_type; - cudaDataType_t compute_type = output_type; + cudaDataType_t compute_type = + output_type == CUDA_R_16BF ? CUDA_R_32F : output_type; int num_peft_requests = 0; for (int i = 0; i < bc->max_requests_per_batch(); i++) { @@ -228,7 +253,7 @@ void inference_kernel(LoraLinearMeta *m, assert(m->handle.workSpaceSize >= data_type_size(m->input_type[1]) * num_peft_tokens * lora_config.rank); } - DT alpha = 1.0f, beta = 0.0f; + SCALE_DT alpha = 1.0f, beta = 0.0f; // buffer = weight_first * input // [rank, num_peft_tokens] = [in_dim, rank].T * [in_dim, num_peft_tokens] checkCUDA(cublasGemmEx(m->handle.blas, @@ -254,7 +279,8 @@ void inference_kernel(LoraLinearMeta *m, // [out_dim, num_peft_tokens] = [rank, out_dim].T * [rank, num_peft_tokens] // Note that we use alpha in both places since we do // an in-place update for LoraLinear - DT scaling_constant = (DT)(lora_config.lora_alpha / lora_config.rank); + SCALE_DT scaling_constant = + (SCALE_DT)(lora_config.lora_alpha / lora_config.rank); checkCUDA(cublasGemmEx(m->handle.blas, CUBLAS_OP_T, CUBLAS_OP_N, @@ -302,14 +328,14 @@ __global__ void sgd_update(size_t count, } } -template +template void peft_bwd_kernel(Context ctx, Runtime *runtime, LoraLinearMeta *m, BatchConfig const *bc, int shard_id, - DT *input_grad_ptr, - DT const *output_grad_ptr, + DATA_DT *input_grad_ptr, + DATA_DT const *output_grad_ptr, int in_dim, int out_dim, ffStream_t stream) { @@ -320,7 +346,11 @@ void peft_bwd_kernel(Context ctx, assert(input_type == output_type); cudaDataType_t weight_type = output_type; cudaDataType_t lr_actv_type = output_type; - cudaDataType_t compute_type = output_type; + cudaDataType_t compute_type = + (ff_to_cuda_datatype(m->output_type[0]) == CUDA_R_16BF) + ? CUDA_R_32F + : ff_to_cuda_datatype(m->output_type[0]); + ; assert( bc->peft_bwd_applies_to_this_layer(m->layer_guid.transformer_layer_id)); @@ -343,14 +373,15 @@ void peft_bwd_kernel(Context ctx, // int first_token_offset = bc->requestsInfo[i].first_token_offset_in_batch; LoraLinearWeight weight = m->peft_memory_manager->get_peft( bc->requestsInfo[i].peft_model_id, lora_config); - DT scaling_constant = (DT)(lora_config.lora_alpha / lora_config.rank); + SCALE_DT scaling_constant = + (SCALE_DT)(lora_config.lora_alpha / lora_config.rank); // Compute LORA_B weight's gradient if (bc->requestsInfo[i].optimizer_tasks.compute_gradients) { - DT alpha = 1.0f; - DT beta = (bc->requestsInfo[i].optimizer_tasks.reset_gradients_to_zero) - ? 0.0f - : 1.0f; + SCALE_DT alpha = 1.0f; + SCALE_DT beta = + (bc->requestsInfo[i].optimizer_tasks.reset_gradients_to_zero) ? 
0.0f + : 1.0f; // std::cout << "Lora B gradient computation, beta = " << (float) beta << // std::endl; if (m->inference_debugging) { @@ -359,7 +390,7 @@ void peft_bwd_kernel(Context ctx, get_peft_dbg_folder(m, shard_id, false) + ".low_rank_activation.pt"; std::cout << "Save low_rank_activation (" << lora_config.rank << ", " << num_peft_tokens << ") to " << filename << std::endl; - auto tensor = createTorchTensorFromCuda
( + auto tensor = createTorchTensorFromCuda( weight.low_rank_activation, {lora_config.rank, num_peft_tokens}); torch::save(tensor, filename); } @@ -387,7 +418,7 @@ void peft_bwd_kernel(Context ctx, // Compute LORA_B input's (and LORA_A output's) gradient inplace in // low_rank_activation { - DT alpha = 1.0f, beta = 0.0f; + SCALE_DT alpha = 1.0f, beta = 0.0f; checkCUDA(cublasGemmEx(m->handle.peft_blas, CUBLAS_OP_N, CUBLAS_OP_N, @@ -411,10 +442,10 @@ void peft_bwd_kernel(Context ctx, // Compute LORA_A weight's gradient if (bc->requestsInfo[i].optimizer_tasks.compute_gradients) { - DT alpha = 1.0f; - DT beta = (bc->requestsInfo[i].optimizer_tasks.reset_gradients_to_zero) - ? 0.0f - : 1.0f; + SCALE_DT alpha = 1.0f; + SCALE_DT beta = + (bc->requestsInfo[i].optimizer_tasks.reset_gradients_to_zero) ? 0.0f + : 1.0f; checkCUDA(cublasGemmEx(m->handle.peft_blas, CUBLAS_OP_N, CUBLAS_OP_T, @@ -438,8 +469,8 @@ void peft_bwd_kernel(Context ctx, // Compute input gradient // NOTE: we use beta=1 for input_grad to accumulate gradients when needed if (input_grad_ptr != nullptr) { - DT alpha = 1.0f; - DT beta = m->reset_input_grads[0] ? 0.0f : 1.0f; + SCALE_DT alpha = 1.0f; + SCALE_DT beta = m->reset_input_grads[0] ? 0.0f : 1.0f; checkCUDA(cublasGemmEx(m->handle.peft_blas, CUBLAS_OP_N, CUBLAS_OP_N, @@ -480,16 +511,16 @@ void peft_bwd_kernel(Context ctx, sgd_config->weight_decay, sgd_config->momentum, sgd_config->nesterov, - static_cast
<DT *>(weight.w0_grad_ptr), - static_cast
<DT *>(weight.w0_v_values_ptr), - static_cast<DT *>
(weight.w0_ptr)); + static_cast(weight.w0_grad_ptr), + static_cast(weight.w0_v_values_ptr), + static_cast(weight.w0_ptr)); // LoRA_B weight is replicated w tensor parallelism, so we need to sync // and sum first #ifdef FF_USE_NCCL ncclDataType_t nccl_data_type = ff_to_nccl_datatype(m->output_type[0]); runtime->concurrent_task_barrier(ctx); - checkNCCL(ncclAllReduce(static_cast
<DT *>(weight.w1_grad_ptr), - static_cast<DT *>
(weight.w1_grad_ptr), + checkNCCL(ncclAllReduce(static_cast(weight.w1_grad_ptr), + static_cast(weight.w1_grad_ptr), w1_num_elements, nccl_data_type, ncclSum, @@ -505,9 +536,9 @@ void peft_bwd_kernel(Context ctx, sgd_config->weight_decay, sgd_config->momentum, sgd_config->nesterov, - static_cast
<DT *>(weight.w1_grad_ptr), - static_cast
<DT *>(weight.w1_v_values_ptr), - static_cast<DT *>
(weight.w1_ptr)); + static_cast(weight.w1_grad_ptr), + static_cast(weight.w1_v_values_ptr), + static_cast(weight.w1_ptr)); } else if (lora_config.optimizer_config->getType() == "Adam") { assert(false && "Adam optimizer type not implemented yet"); } else { diff --git a/src/ops/kernels/residual_rms_norm_kernels.cu b/src/ops/kernels/residual_rms_norm_kernels.cu index 47a235fc9..3025922a6 100644 --- a/src/ops/kernels/residual_rms_norm_kernels.cu +++ b/src/ops/kernels/residual_rms_norm_kernels.cu @@ -274,6 +274,19 @@ void inference_kernel_wrapper(ResidualRMSNormMeta *m, store_peft_activations( m, bc, in_dim, residual_output.get_float_ptr(), stream); } + } else if (output.data_type == DT_BFLOAT16) { + inference_kernel(m, + bc, + input1.get_bfloat16_ptr(), + input2.get_bfloat16_ptr(), + weight.get_bfloat16_ptr(), + residual_output.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + stream); + if (bc->num_finetuning_fwd_requests() > 0) { + store_peft_activations( + m, bc, in_dim, residual_output.get_bfloat16_ptr(), stream); + } } else { assert(false && "Unsupported data type"); } @@ -433,6 +446,16 @@ void peft_bwd_kernel_wrapper(ResidualRMSNormMeta const *m, input_grad_1.get_float_ptr(), weight.get_float_ptr(), stream); + } else if (output_grad_1.data_type == DT_BFLOAT16) { + peft_bwd_kernel(m, + bc, + m->reset_input_grads[0] ? nullptr + : output_grad_0.get_bfloat16_ptr(), + output_grad_1.get_bfloat16_ptr(), + input_grad_0.get_bfloat16_ptr(), + input_grad_1.get_bfloat16_ptr(), + weight.get_bfloat16_ptr(), + stream); } else { assert(false && "Unsupported data type"); } diff --git a/src/ops/kernels/rms_norm_kernels.cu b/src/ops/kernels/rms_norm_kernels.cu index f9403bf86..5da071eb2 100644 --- a/src/ops/kernels/rms_norm_kernels.cu +++ b/src/ops/kernels/rms_norm_kernels.cu @@ -249,6 +249,16 @@ void inference_kernel_wrapper(RMSNormMeta *m, if (bc->num_finetuning_fwd_requests() > 0) { store_peft_activations(m, bc, m->in_dim, input.get_float_ptr(), stream); } + } else if (output.data_type == DT_BFLOAT16) { + inference_kernel(m, + bc, + input.get_bfloat16_ptr(), + weight.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + stream); + if (bc->num_finetuning_fwd_requests() > 0) { + store_peft_activations(m, bc, m->in_dim, input.get_bfloat16_ptr(), stream); + } } else { assert(false && "Unsupported data type"); } @@ -380,6 +390,13 @@ void peft_bwd_kernel_wrapper(RMSNormMeta const *m, input_grad.get_float_ptr(), weight.get_float_ptr(), stream); + } else if (output_grad.data_type == DT_BFLOAT16) { + peft_bwd_kernel(m, + bc, + output_grad.get_bfloat16_ptr(), + input_grad.get_bfloat16_ptr(), + weight.get_bfloat16_ptr(), + stream); } else { assert(false && "Unsupported data type"); } diff --git a/src/ops/kernels/softmax.cu b/src/ops/kernels/softmax.cu index 3a7864baf..90f2f0634 100644 --- a/src/ops/kernels/softmax.cu +++ b/src/ops/kernels/softmax.cu @@ -71,6 +71,9 @@ void forward_kernel_wrapper(SoftmaxMeta const *m, } else if (m->output_type[0] == DT_HALF) { Internal::forward_kernel( m, input.get_half_ptr(), output.get_half_ptr(), stream); + } else if (m->output_type[0] == DT_BFLOAT16) { + Internal::forward_kernel( + m, input.get_bfloat16_ptr(), output.get_bfloat16_ptr(), stream); } else { assert(false && "Unsupported data type"); } @@ -114,6 +117,12 @@ void backward_kernel_wrapper(SoftmaxMeta const *m, output_grad.get_half_ptr(), output_grad.domain.get_volume(), stream); + } else if (m->output_type[0] == DT_BFLOAT16) { + Internal::backward_kernel(m, + input_grad.get_bfloat16_ptr(), + output_grad.get_bfloat16_ptr(), 
+ output_grad.domain.get_volume(), + stream); } else { assert(false && "Unsupported data type"); } @@ -168,6 +177,17 @@ void inference_kernel_wrapper(SoftmaxMeta *m, Internal::store_peft_activations( m, bc, num_classes, output.get_half_ptr(), stream); } + } else if (m->output_type[0] == DT_BFLOAT16) { + Internal::inference_kernel(m, + bc, + input.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + num_classes, + stream); + if (is_last_op && bc->num_finetuning_fwd_requests() > 0) { + Internal::store_peft_activations( + m, bc, num_classes, output.get_bfloat16_ptr(), stream); + } } else { assert(false && "Unsupported data type"); } @@ -205,6 +225,9 @@ void peft_bwd_kernel_wrapper(SoftmaxMeta const *m, } else if (m->output_type[0] == DT_HALF) { Internal::peft_bwd_kernel( m, bc, input_grad.get_half_ptr(), num_classes, stream); + } else if (m->output_type[0] == DT_BFLOAT16) { + Internal::peft_bwd_kernel( + m, bc, input_grad.get_bfloat16_ptr(), num_classes, stream); } else { assert(false && "Unsupported data type"); } diff --git a/src/ops/layer_norm.cu b/src/ops/layer_norm.cu index fecaa067c..73a5f8c6c 100644 --- a/src/ops/layer_norm.cu +++ b/src/ops/layer_norm.cu @@ -119,11 +119,11 @@ __global__ void LayerNormFusedForwardKernel(int64_t N, T *Y) { __shared__ float m_shared[C10_WARP_SIZE]; __shared__ float v_shared[C10_WARP_SIZE]; - const int64_t i = blockIdx.x; + int64_t const i = blockIdx.x; float sum1 = 0.0f; float sum2 = 0.0f; for (int64_t j = threadIdx.x; j < N; j += blockDim.x) { - const int64_t index = i * N + j; + int64_t const index = i * N + j; sum1 += static_cast(X[index]); sum2 += static_cast(X[index]) * static_cast(X[index]); } @@ -141,7 +141,7 @@ __global__ void LayerNormFusedForwardKernel(int64_t N, using T_ACC = T; for (int64_t j = threadIdx.x; j < N; j += blockDim.x) { - const int64_t index = i * N + j; + int64_t const index = i * N + j; const T_ACC gamma_v = gamma == nullptr ? T_ACC(1) : static_cast(gamma[j]); const T_ACC beta_v = @@ -206,6 +206,15 @@ void LayerNorm::forward_kernel_wrapper(LayerNormMeta const *m, m->elementwise_affine ? gamma.get_half_ptr() : nullptr, (m->elementwise_affine && m->use_bias) ? beta.get_half_ptr() : nullptr, stream); + } else if (m->input_type[0] == DT_BFLOAT16) { + LayerNorm::forward_kernel<__ff_bfloat16>( + m, + input.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + m->elementwise_affine ? gamma.get_bfloat16_ptr() : nullptr, + (m->elementwise_affine && m->use_bias) ? beta.get_bfloat16_ptr() + : nullptr, + stream); } else { assert(false && "unsupport datatype in layernorm"); } @@ -269,6 +278,13 @@ void LayerNorm::inference_kernel_wrapper(LayerNormMeta *m, data_type_size(m->input_type[0]) * num_peft_tokens * in_dim, cudaMemcpyDeviceToDevice, stream)); + } else if (m->input_type[0] == DT_BFLOAT16) { + checkCUDA(cudaMemcpyAsync( + m->input_activation, + input.get_bfloat16_ptr() + first_token_offset * in_dim, + data_type_size(m->input_type[0]) * num_peft_tokens * in_dim, + cudaMemcpyDeviceToDevice, + stream)); } else { assert(false && "unsupport datatype in layernorm"); } @@ -290,6 +306,14 @@ void LayerNorm::inference_kernel_wrapper(LayerNormMeta *m, m->elementwise_affine ? gamma.get_half_ptr() : nullptr, (m->elementwise_affine && m->use_bias) ? beta.get_half_ptr() : nullptr, stream); + } else if (m->input_type[0] == DT_BFLOAT16) { + LayerNorm::forward_kernel<__ff_bfloat16>( + m, + input.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + m->elementwise_affine ? gamma.get_bfloat16_ptr() : nullptr, + (m->elementwise_affine && m->use_bias) ? 
beta.get_bfloat16_ptr() : nullptr, + stream); } else { assert(false && "unsupport datatype in layernorm"); } @@ -313,11 +337,11 @@ __global__ void ComputeInternalGradientsCUDAKernel( using T_ACC = T; __shared__ T_ACC ds_shared[C10_WARP_SIZE]; __shared__ T_ACC db_shared[C10_WARP_SIZE]; - const int64_t i = blockIdx.x; + int64_t const i = blockIdx.x; T_ACC sum1 = 0; T_ACC sum2 = 0; for (int64_t j = threadIdx.x; j < N; j += blockDim.x) { - const int64_t index = i * N + j; + int64_t const index = i * N + j; const T_ACC gamma_v = gamma == nullptr ? T_ACC(1) : static_cast(gamma[j]); sum1 += @@ -342,7 +366,7 @@ __global__ void ComputeGradientFusedParamsCUDAKernel(int64_t M, T *c1, T *c2) { using T_ACC = T; - const int64_t index = blockIdx.x * blockDim.x + threadIdx.x; + int64_t const index = blockIdx.x * blockDim.x + threadIdx.x; if (index < M) { const T_ACC s = T_ACC(1) / static_cast((int)N); const T_ACC a = (db[index] * static_cast(mean[index]) - ds[index]) * @@ -365,12 +389,12 @@ __global__ void GammaBetaBackwardSimpleCUDAKernel(int64_t M, T *dg, T *db) { using T_ACC = T; - const int64_t j = blockIdx.x * blockDim.x + threadIdx.x; + int64_t const j = blockIdx.x * blockDim.x + threadIdx.x; if (j < N) { T_ACC sum1 = 0; T_ACC sum2 = 0; for (int64_t i = 0; i < M; ++i) { - const int64_t index = i * N + j; + int64_t const index = i * N + j; sum1 += dg == nullptr ? T_ACC(0) : static_cast(dY[index]) * (static_cast(X[index]) - @@ -399,17 +423,17 @@ __global__ void GammaBetaBackwardCUDAKernel(int64_t M, using T_ACC = T; __shared__ T_ACC g_shared[kColwiseReduceTileSize][kColwiseReduceTileSize + 1]; __shared__ T_ACC b_shared[kColwiseReduceTileSize][kColwiseReduceTileSize + 1]; - const int64_t j = blockIdx.x * blockDim.x + threadIdx.x; + int64_t const j = blockIdx.x * blockDim.x + threadIdx.x; T_ACC dg_sum1 = 0; T_ACC dg_sum2 = 0; T_ACC db_sum1 = 0; T_ACC db_sum2 = 0; if (j < N) { for (int64_t i = threadIdx.y; i < M; i += blockDim.y * 2) { - const int64_t i1 = i; - const int64_t i2 = i + blockDim.y; - const int64_t index1 = i1 * N + j; - const int64_t index2 = i2 * N + j; + int64_t const i1 = i; + int64_t const i2 = i + blockDim.y; + int64_t const index1 = i1 * N + j; + int64_t const index2 = i2 * N + j; dg_sum1 += dg == nullptr ? T_ACC(0) : static_cast(dY[index1]) * (static_cast(X[index1]) - @@ -436,7 +460,7 @@ __global__ void GammaBetaBackwardCUDAKernel(int64_t M, sum1 = WarpReduceSum(sum1); sum2 = WarpReduceSum(sum2); if (threadIdx.x == 0) { - const int64_t j = blockIdx.x * blockDim.x + threadIdx.y; + int64_t const j = blockIdx.x * blockDim.x + threadIdx.y; if (j < N) { if (dg != nullptr) { dg[j] = sum1; @@ -451,7 +475,7 @@ __global__ void GammaBetaBackwardCUDAKernel(int64_t M, sum1 = WarpReduceSum(sum1); sum2 = WarpReduceSum(sum2); if (threadIdx.x == 0) { - const int64_t j = blockIdx.x * blockDim.x + threadIdx.y + blockDim.y; + int64_t const j = blockIdx.x * blockDim.x + threadIdx.y + blockDim.y; if (j < N) { if (dg != nullptr) { dg[j] = sum1; @@ -473,8 +497,8 @@ __device__ __inline__ void compute_gI(T const *__restrict__ dY, int const N, T *buf) { auto const i1 = blockIdx.x; - const T mean_val = mean[i1]; - const T rstd_val = rstd[i1]; + T const mean_val = mean[i1]; + T const rstd_val = rstd[i1]; T stats_x1{0}, stats_x2{0}; constexpr int unroll = 4; auto l = unroll * threadIdx.x; @@ -487,16 +511,16 @@ __device__ __inline__ void compute_gI(T const *__restrict__ dY, #pragma unroll for (int k = 0; k < unroll; k++) { T gamma_val = (gamma != nullptr) ? 
static_cast(gamma[l + k]) : T(1); - const T c_h = static_cast(X_i[l + k]); - const T c_loss = static_cast(dY_i[l + k]); + T const c_h = static_cast(X_i[l + k]); + T const c_loss = static_cast(dY_i[l + k]); stats_x1 += c_loss * gamma_val; stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; } } for (; l < N; l++) { T gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T(1); - const T c_h = static_cast(X_i[l]); - const T c_loss = static_cast(dY_i[l]); + T const c_h = static_cast(X_i[l]); + T const c_loss = static_cast(dY_i[l]); stats_x1 += c_loss * gamma_val; stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; } @@ -514,8 +538,8 @@ __device__ __inline__ void compute_gI(T const *__restrict__ dY, T term1 = (T(1) / fH) * rstd_val; for (int l = threadIdx.x; l < N; l += blockDim.x) { - const T x = X_i[l]; - const T dy = dY_i[l]; + T const x = X_i[l]; + T const dy = dY_i[l]; T gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T(1); T f_grad_input = fH * gamma_val * dy; f_grad_input -= (x - mean_val) * rstd_val * stats_x2; @@ -549,8 +573,8 @@ void LayerNorm::backward_kernel(LayerNormMeta const *m, T *gamma_grad_ptr, T *beta_grad_ptr, cudaStream_t stream) { - const int64_t M = m->effective_batch_size; - const int64_t N = m->effective_num_elements; + int64_t const M = m->effective_batch_size; + int64_t const N = m->effective_num_elements; ComputeInternalGradientsCUDAKernel <<>>( N, @@ -559,7 +583,7 @@ void LayerNorm::backward_kernel(LayerNormMeta const *m, gamma_ptr, static_cast(m->ds_ptr), static_cast(m->db_ptr)); - const int64_t B = (M + kCUDANumThreads - 1) / kCUDANumThreads; + int64_t const B = (M + kCUDANumThreads - 1) / kCUDANumThreads; ComputeGradientFusedParamsCUDAKernel <<>>(M, N, @@ -571,7 +595,7 @@ void LayerNorm::backward_kernel(LayerNormMeta const *m, static_cast(m->bias_ptr)); int const warp_size = C10_WARP_SIZE; int const num_threads = 128; - const dim3 blocks(M); + dim3 const blocks(M); int nshared = (num_threads / warp_size) * sizeof(T); layer_norm_grad_input_kernel<<>>( output_grad_ptr, @@ -585,7 +609,7 @@ void LayerNorm::backward_kernel(LayerNormMeta const *m, if (gamma_grad_ptr != NULL || beta_grad_ptr != NULL) { if (M < 512) { // For small batch size, do colwise reduce directly - const int64_t B = (N + kCUDANumThreads - 1) / kCUDANumThreads; + int64_t const B = (N + kCUDANumThreads - 1) / kCUDANumThreads; GammaBetaBackwardSimpleCUDAKernel <<>>(M, N, @@ -596,7 +620,7 @@ void LayerNorm::backward_kernel(LayerNormMeta const *m, gamma_grad_ptr, beta_grad_ptr); } else { - const int64_t B = + int64_t const B = (N + kColwiseReduceTileSize - 1) / kColwiseReduceTileSize; constexpr int kThreadX = kColwiseReduceTileSize; constexpr int kThreadY = kColwiseReduceTileSize / 2; @@ -621,11 +645,11 @@ void LayerNorm::peft_bwd_kernel(LayerNormMeta const *m, T *input_grad_ptr, T const *gamma_ptr, cudaStream_t stream) { - const int64_t M = m->effective_batch_size; - const int64_t N = m->effective_num_elements; + int64_t const M = m->effective_batch_size; + int64_t const N = m->effective_num_elements; int const warp_size = C10_WARP_SIZE; int const num_threads = 128; - const dim3 blocks(M); + dim3 const blocks(M); int nshared = (num_threads / warp_size) * sizeof(T); layer_norm_grad_input_kernel<<>>( output_grad_ptr, @@ -651,13 +675,20 @@ void LayerNorm::peft_bwd_kernel_wrapper( input_grad.get_float_ptr(), gamma.get_float_ptr(), stream); - } else { - assert(m->output_type[0] == DT_HALF); + } else if (m->output_type[0] == DT_BFLOAT16) { + LayerNorm::peft_bwd_kernel(m, + 
output_grad.get_bfloat16_ptr(), + input_grad.get_bfloat16_ptr(), + gamma.get_bfloat16_ptr(), + stream); + } else if (m->output_type[0] == DT_HALF) { LayerNorm::peft_bwd_kernel(m, output_grad.get_half_ptr(), input_grad.get_half_ptr(), gamma.get_half_ptr(), stream); + } else { + assert(false && "Unsupported data type"); } } @@ -690,6 +721,15 @@ void LayerNorm::backward_kernel_wrapper( gamma_grad.get_half_ptr(), beta_grad.get_half_ptr(), stream); + } else if (m->output_type[0] == DT_BFLOAT16) { + LayerNorm::backward_kernel(m, + output_grad.get_bfloat16_ptr(), + input.get_bfloat16_ptr(), + input_grad.get_bfloat16_ptr(), + gamma.get_bfloat16_ptr(), + gamma_grad.get_bfloat16_ptr(), + beta_grad.get_bfloat16_ptr(), + stream); } else { assert(false && "Unsupported data type"); } diff --git a/src/ops/linear.cc b/src/ops/linear.cc index 7d596ac55..7a91a35c8 100644 --- a/src/ops/linear.cc +++ b/src/ops/linear.cc @@ -33,7 +33,7 @@ using namespace FlexFlow::Kernels::Linear; static constexpr int KERNEL_IDX = 0; static constexpr int BIAS_IDX = 1; -Tensor FFModel::dense(const Tensor input, +Tensor FFModel::dense(Tensor const input, int outDim, ActiMode activation, bool use_bias, @@ -157,7 +157,7 @@ Op *Linear::create_operator_from_layer( Linear::Linear(FFModel &model, Linear const &other, - const ParallelTensor input, + ParallelTensor const input, bool allocate_weights) : Linear(model, other.layer_guid, @@ -194,7 +194,7 @@ Linear::Linear(FFModel &model, Linear::Linear(FFModel &model, LayerID const &_layer_guid, - const ParallelTensor _input, + ParallelTensor const _input, int out_dim, ActiMode _activation, RegularizerMode _kernel_reg_type, @@ -437,6 +437,14 @@ OpMeta *Linear::init_task(Task const *task, return init_task_with_dim( \ task, regions, ctx, runtime); \ } \ + } else if (output.data_type == DT_BFLOAT16) { \ + if (linear->quantization_type != DT_NONE) { \ + return init_task_with_dim<__ff_bfloat16, char, DIM>( \ + task, regions, ctx, runtime); \ + } else { \ + return init_task_with_dim<__ff_bfloat16, __ff_bfloat16, DIM>( \ + task, regions, ctx, runtime); \ + } \ } else { \ assert(false && "Unsupported data type"); \ } @@ -795,6 +803,14 @@ void Linear::forward_task(Task const *task, return forward_task_with_dim( \ task, regions, ctx, runtime); \ } \ + } else if (m->output_type[0] == DT_BFLOAT16) { \ + if (m->quantization_type != DT_NONE) { \ + return forward_task_with_dim<__ff_bfloat16, char, DIM>( \ + task, regions, ctx, runtime); \ + } else { \ + return forward_task_with_dim<__ff_bfloat16, __ff_bfloat16, DIM>( \ + task, regions, ctx, runtime); \ + } \ } else { \ assert(false && "Unsupported data type"); \ } @@ -950,6 +966,8 @@ void Linear::backward_task(Task const *task, return backward_task_with_dim(task, regions, ctx, runtime); \ } else if (m->output_type[0] == DT_FLOAT) { \ return backward_task_with_dim(task, regions, ctx, runtime); \ + } else if (m->output_type[0] == DT_BFLOAT16) { \ + return backward_task_with_dim<__ff_bfloat16, DIM>(task, regions, ctx, runtime); \ } else { \ assert(false && "Unsupported data type"); \ } @@ -1445,7 +1463,7 @@ bool LinearParams::is_valid(ParallelTensorShape const &input_shape) const { * It takes a the input tensor as a parameter, instead of the input's * ParallelTensorShape. 
*/ -void LinearParams::solve_dims(const ParallelTensor input, +void LinearParams::solve_dims(ParallelTensor const input, ParallelDim output_dims[MAX_TENSOR_DIM], int *output_ndims, ParallelDim kernel_dims[MAX_TENSOR_DIM], diff --git a/src/ops/lora_linear.cc b/src/ops/lora_linear.cc index baa1a765a..b0dac9e75 100644 --- a/src/ops/lora_linear.cc +++ b/src/ops/lora_linear.cc @@ -504,6 +504,13 @@ void LoraLinear::inference_task(Task const *task, at::Tensor tensor2 = createTorchTensorFromCuda( weight.w0_ptr, {lora_config.rank, out_dim}); torch::save(tensor2, filenameB.c_str()); + } else if (m->input_type[0] == DT_BFLOAT16) { + at::Tensor tensor1 = createTorchTensorFromCuda<__ff_bfloat16>( + weight.w0_ptr, {lora_config.rank, in_dim}); + torch::save(tensor1, filenameA.c_str()); + at::Tensor tensor2 = createTorchTensorFromCuda<__ff_bfloat16>( + weight.w0_ptr, {lora_config.rank, out_dim}); + torch::save(tensor2, filenameB.c_str()); } else { assert(false && "Data type not supported"); } @@ -520,6 +527,10 @@ void LoraLinear::inference_task(Task const *task, at::Tensor tensor = createTorchTensorFromCuda( weight.low_rank_activation, {lora_config.rank, num_tokens}); torch::save(tensor, filename.c_str()); + } else if (output.data_type == DT_BFLOAT16) { + at::Tensor tensor = createTorchTensorFromCuda<__ff_bfloat16>( + weight.low_rank_activation, {lora_config.rank, num_tokens}); + torch::save(tensor, filename.c_str()); } else { assert(false); } @@ -678,6 +689,23 @@ void lora_inference_debugging(LoraLinearMeta *m, at::Tensor tensorGradB = createTorchTensorFromCuda( weight.w1_grad_ptr, {lora_config.rank, out_dim}); torch::save(tensorGradB, filename_grad_B.c_str()); + } else if (m->input_type[0] == DT_BFLOAT16) { + // weight A + at::Tensor tensorA = createTorchTensorFromCuda<__ff_bfloat16>( + weight.w0_ptr, {in_dim, lora_config.rank}); + torch::save(tensorA, filename_weight_A.c_str()); + // weight grad A + at::Tensor tensorGradA = createTorchTensorFromCuda<__ff_bfloat16>( + weight.w0_grad_ptr, {in_dim, lora_config.rank}); + torch::save(tensorGradA, filename_grad_A.c_str()); + // weight B + at::Tensor tensorB = createTorchTensorFromCuda<__ff_bfloat16>( + weight.w1_ptr, {lora_config.rank, out_dim}); + torch::save(tensorB, filename_weight_B.c_str()); + // weight grad B + at::Tensor tensorGradB = createTorchTensorFromCuda<__ff_bfloat16>( + weight.w1_grad_ptr, {lora_config.rank, out_dim}); + torch::save(tensorGradB, filename_grad_B.c_str()); } else { assert(false && "Data type not supported"); } @@ -789,6 +817,13 @@ void Kernels::LoraLinear::save_peft_weights_if_needed(LoraLinearMeta *m, if (shard_id == 0) { save_peft_to_file((half *)weight.w1_ptr, w1_num_elements, w1_filepath); } + } else if (m->input_type[0] == DT_BFLOAT16) { + save_peft_to_file( + (__ff_bfloat16 *)weight.w0_ptr, w0_num_elements, w0_filepath); + if (shard_id == 0) { + save_peft_to_file( + (__ff_bfloat16 *)weight.w1_ptr, w1_num_elements, w1_filepath); + } } else { assert(false && "Data type not supported"); } diff --git a/src/ops/residual_layer_norm.cu b/src/ops/residual_layer_norm.cu index a258ff294..0ce58c2ad 100644 --- a/src/ops/residual_layer_norm.cu +++ b/src/ops/residual_layer_norm.cu @@ -118,12 +118,12 @@ __global__ void ResidualLayerNormKernel(int64_t N, T *Y) { __shared__ float m_shared[C10_WARP_SIZE]; __shared__ float v_shared[C10_WARP_SIZE]; - const int64_t i = blockIdx.x; + int64_t const i = blockIdx.x; float sum1 = 0.0f; float sum2 = 0.0f; for (int64_t j = threadIdx.x; j < N; j += blockDim.x) { - const int64_t index = i * N + j; - 
const T residual2_val = (residual2_ptr == nullptr) + int64_t const index = i * N + j; + T const residual2_val = (residual2_ptr == nullptr) ? T(0) : static_cast(residual2_ptr[index]); X[index] = input_ptr[index] + residual1_ptr[index] + residual2_val; @@ -146,7 +146,7 @@ __global__ void ResidualLayerNormKernel(int64_t N, using T_ACC = T; for (int64_t j = threadIdx.x; j < N; j += blockDim.x) { - const int64_t index = i * N + j; + int64_t const index = i * N + j; const T_ACC gamma_v = gamma == nullptr ? T_ACC(1) : static_cast(gamma[j]); const T_ACC beta_v = @@ -229,6 +229,18 @@ void ResidualLayerNorm::inference_kernel_wrapper( m->elementwise_affine ? gamma.get_half_ptr() : nullptr, (m->elementwise_affine && m->use_bias) ? beta.get_half_ptr() : nullptr, stream); + } else if (m->input_type[0] == DT_BFLOAT16) { + ResidualLayerNorm::inference_kernel<__ff_bfloat16>( + m, + input.get_bfloat16_ptr(), + residual1.get_bfloat16_ptr(), + m->use_two_residuals ? residual2.get_bfloat16_ptr() : nullptr, + added_output.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + m->elementwise_affine ? gamma.get_bfloat16_ptr() : nullptr, + (m->elementwise_affine && m->use_bias) ? beta.get_bfloat16_ptr() + : nullptr, + stream); } else { assert(false && "unsupport datatype in layernorm"); } @@ -262,6 +274,13 @@ void ResidualLayerNorm::inference_kernel_wrapper( data_type_size(m->input_type[0]) * num_peft_tokens * in_dim, cudaMemcpyDeviceToDevice, stream)); + } else if (m->input_type[0] == DT_BFLOAT16) { + checkCUDA(cudaMemcpyAsync( + m->input_activation, + added_output.get_bfloat16_ptr() + first_token_offset * in_dim, + data_type_size(m->input_type[0]) * num_peft_tokens * in_dim, + cudaMemcpyDeviceToDevice, + stream)); } else { assert(false && "unsupport datatype in layernorm"); } @@ -294,11 +313,11 @@ __global__ void ComputeInternalGradientsCUDAKernel( using T_ACC = T; __shared__ T_ACC ds_shared[C10_WARP_SIZE]; __shared__ T_ACC db_shared[C10_WARP_SIZE]; - const int64_t i = blockIdx.x; + int64_t const i = blockIdx.x; T_ACC sum1 = 0; T_ACC sum2 = 0; for (int64_t j = threadIdx.x; j < N; j += blockDim.x) { - const int64_t index = i * N + j; + int64_t const index = i * N + j; const T_ACC gamma_v = gamma == nullptr ? T_ACC(1) : static_cast(gamma[j]); sum1 += @@ -323,7 +342,7 @@ __global__ void ComputeGradientFusedParamsCUDAKernel(int64_t M, T *c1, T *c2) { using T_ACC = T; - const int64_t index = blockIdx.x * blockDim.x + threadIdx.x; + int64_t const index = blockIdx.x * blockDim.x + threadIdx.x; if (index < M) { const T_ACC s = T_ACC(1) / static_cast((int)N); const T_ACC a = (db[index] * static_cast(mean[index]) - ds[index]) * @@ -346,12 +365,12 @@ __global__ void GammaBetaBackwardSimpleCUDAKernel(int64_t M, T *dg, T *db) { using T_ACC = T; - const int64_t j = blockIdx.x * blockDim.x + threadIdx.x; + int64_t const j = blockIdx.x * blockDim.x + threadIdx.x; if (j < N) { T_ACC sum1 = 0; T_ACC sum2 = 0; for (int64_t i = 0; i < M; ++i) { - const int64_t index = i * N + j; + int64_t const index = i * N + j; sum1 += dg == nullptr ? 
T_ACC(0) : static_cast(dY[index]) * (static_cast(X[index]) - @@ -380,17 +399,17 @@ __global__ void GammaBetaBackwardCUDAKernel(int64_t M, using T_ACC = T; __shared__ T_ACC g_shared[kColwiseReduceTileSize][kColwiseReduceTileSize + 1]; __shared__ T_ACC b_shared[kColwiseReduceTileSize][kColwiseReduceTileSize + 1]; - const int64_t j = blockIdx.x * blockDim.x + threadIdx.x; + int64_t const j = blockIdx.x * blockDim.x + threadIdx.x; T_ACC dg_sum1 = 0; T_ACC dg_sum2 = 0; T_ACC db_sum1 = 0; T_ACC db_sum2 = 0; if (j < N) { for (int64_t i = threadIdx.y; i < M; i += blockDim.y * 2) { - const int64_t i1 = i; - const int64_t i2 = i + blockDim.y; - const int64_t index1 = i1 * N + j; - const int64_t index2 = i2 * N + j; + int64_t const i1 = i; + int64_t const i2 = i + blockDim.y; + int64_t const index1 = i1 * N + j; + int64_t const index2 = i2 * N + j; dg_sum1 += dg == nullptr ? T_ACC(0) : static_cast(dY[index1]) * (static_cast(X[index1]) - @@ -417,7 +436,7 @@ __global__ void GammaBetaBackwardCUDAKernel(int64_t M, sum1 = WarpReduceSum(sum1); sum2 = WarpReduceSum(sum2); if (threadIdx.x == 0) { - const int64_t j = blockIdx.x * blockDim.x + threadIdx.y; + int64_t const j = blockIdx.x * blockDim.x + threadIdx.y; if (j < N) { if (dg != nullptr) { dg[j] = sum1; @@ -432,7 +451,7 @@ __global__ void GammaBetaBackwardCUDAKernel(int64_t M, sum1 = WarpReduceSum(sum1); sum2 = WarpReduceSum(sum2); if (threadIdx.x == 0) { - const int64_t j = blockIdx.x * blockDim.x + threadIdx.y + blockDim.y; + int64_t const j = blockIdx.x * blockDim.x + threadIdx.y + blockDim.y; if (j < N) { if (dg != nullptr) { dg[j] = sum1; @@ -459,8 +478,8 @@ __device__ __inline__ void compute_gI(T const *__restrict__ dY, int const N, T *buf) { auto const i1 = blockIdx.x; - const T mean_val = mean[i1]; - const T rstd_val = rstd[i1]; + T const mean_val = mean[i1]; + T const rstd_val = rstd[i1]; T stats_x1{0}, stats_x2{0}; constexpr int unroll = 4; auto l = unroll * threadIdx.x; @@ -476,16 +495,16 @@ __device__ __inline__ void compute_gI(T const *__restrict__ dY, #pragma unroll for (int k = 0; k < unroll; k++) { T gamma_val = (gamma != nullptr) ? static_cast(gamma[l + k]) : T(1); - const T c_h = static_cast(X_i[l + k]); - const T c_loss = static_cast(dY_i[l + k]); + T const c_h = static_cast(X_i[l + k]); + T const c_loss = static_cast(dY_i[l + k]); stats_x1 += c_loss * gamma_val; stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; } } for (; l < N; l++) { T gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T(1); - const T c_h = static_cast(X_i[l]); - const T c_loss = static_cast(dY_i[l]); + T const c_h = static_cast(X_i[l]); + T const c_loss = static_cast(dY_i[l]); stats_x1 += c_loss * gamma_val; stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; } @@ -503,8 +522,8 @@ __device__ __inline__ void compute_gI(T const *__restrict__ dY, T term1 = (T(1) / fH) * rstd_val; for (int l = threadIdx.x; l < N; l += blockDim.x) { - const T x = X_i[l]; - const T dy = dY_i[l]; + T const x = X_i[l]; + T const dy = dY_i[l]; T gamma_val = (gamma != nullptr) ? 
static_cast(gamma[l]) : T(1); T f_grad_input = fH * gamma_val * dy; f_grad_input -= (x - mean_val) * rstd_val * stats_x2; @@ -572,8 +591,8 @@ void backward_kernel(ResidualLayerNormMeta const *m, T *gamma_grad_ptr, T *beta_grad_ptr, cudaStream_t stream) { - const int64_t M = m->effective_batch_size; - const int64_t N = m->effective_num_elements; + int64_t const M = m->effective_batch_size; + int64_t const N = m->effective_num_elements; ComputeInternalGradientsCUDAKernel <<>>( N, @@ -582,7 +601,7 @@ void backward_kernel(ResidualLayerNormMeta const *m, gamma_ptr, static_cast(m->ds_ptr), static_cast(m->db_ptr)); - const int64_t B = (M + kCUDANumThreads - 1) / kCUDANumThreads; + int64_t const B = (M + kCUDANumThreads - 1) / kCUDANumThreads; ComputeGradientFusedParamsCUDAKernel <<>>(M, N, @@ -594,7 +613,7 @@ void backward_kernel(ResidualLayerNormMeta const *m, static_cast(m->bias_ptr)); int const warp_size = C10_WARP_SIZE; int const num_threads = 128; - const dim3 blocks(M); + dim3 const blocks(M); int nshared = (num_threads / warp_size) * sizeof(T); layer_norm_grad_input_kernel<<>>( output_grad_ptr, @@ -613,7 +632,7 @@ void backward_kernel(ResidualLayerNormMeta const *m, if (gamma_grad_ptr != NULL || beta_grad_ptr != NULL) { if (M < 512) { // For small batch size, do colwise reduce directly - const int64_t B = (N + kCUDANumThreads - 1) / kCUDANumThreads; + int64_t const B = (N + kCUDANumThreads - 1) / kCUDANumThreads; GammaBetaBackwardSimpleCUDAKernel <<>>(M, N, @@ -624,7 +643,7 @@ void backward_kernel(ResidualLayerNormMeta const *m, gamma_grad_ptr, beta_grad_ptr); } else { - const int64_t B = + int64_t const B = (N + kColwiseReduceTileSize - 1) / kColwiseReduceTileSize; constexpr int kThreadX = kColwiseReduceTileSize; constexpr int kThreadY = kColwiseReduceTileSize / 2; @@ -688,6 +707,19 @@ void ResidualLayerNorm::backward_kernel_wrapper( (m->elementwise_affine && m->use_bias) ? beta_grad.get_half_ptr() : nullptr, stream); + }else if (m->output_type[0] == DT_BFLOAT16) { + backward_kernel( + m, + output_grad.get_bfloat16_ptr(), + added_output.get_bfloat16_ptr(), + input_grad.get_bfloat16_ptr(), + residual1_grad.get_bfloat16_ptr(), + m->use_two_residuals ? residual2_grad.get_bfloat16_ptr() : nullptr, + m->elementwise_affine ? gamma.get_bfloat16_ptr() : nullptr, + m->elementwise_affine ? gamma_grad.get_bfloat16_ptr() : nullptr, + (m->elementwise_affine && m->use_bias) ? beta_grad.get_bfloat16_ptr() + : nullptr, + stream); } else { assert(false && "Unsupported data type"); } @@ -712,8 +744,8 @@ void peft_bwd_kernel(ResidualLayerNormMeta const *m, T *residual2_grad_ptr, T const *gamma_ptr, cudaStream_t stream) { - const int64_t M = m->effective_batch_size; - const int64_t N = m->effective_num_elements; + int64_t const M = m->effective_batch_size; + int64_t const N = m->effective_num_elements; if (m->inference_debugging) { // save stuff here @@ -748,7 +780,7 @@ void peft_bwd_kernel(ResidualLayerNormMeta const *m, int const warp_size = C10_WARP_SIZE; int const num_threads = 128; - const dim3 blocks(M); + dim3 const blocks(M); int nshared = (num_threads / warp_size) * sizeof(T); layer_norm_grad_input_kernel<<>>( @@ -801,6 +833,15 @@ void ResidualLayerNorm::peft_bwd_kernel_wrapper( : nullptr, m->elementwise_affine ? gamma.get_half_ptr() : nullptr, stream); + } else if (m->output_type[0] == DT_BFLOAT16) { + peft_bwd_kernel(m, + output_grad.get_bfloat16_ptr(), + input_grad.get_bfloat16_ptr(), + residual1_grad.get_bfloat16_ptr(), + m->use_two_residuals ? 
residual2_grad.get_bfloat16_ptr() + : nullptr, + m->elementwise_affine ? gamma.get_bfloat16_ptr() : nullptr, + stream); } else { assert(false && "Unsupported data type"); } diff --git a/src/ops/sampling.cu b/src/ops/sampling.cu index b5d5b8ae1..2cd16a8b3 100644 --- a/src/ops/sampling.cu +++ b/src/ops/sampling.cu @@ -182,6 +182,14 @@ void Sampling::forward_kernel_wrapper(SamplingMeta const *m, length, batch_size, stream); + } else if (input.data_type == DT_BFLOAT16) { + Sampling::forward_kernel<__ff_bfloat16>(m, + input.get_bfloat16_ptr(), + indices.get_int32_ptr(), + m->top_p, + length, + batch_size, + stream); } else { assert(false && "Unsupported data type"); } @@ -271,6 +279,21 @@ SamplingMeta::SamplingMeta(FFHandler handler, 0, // begin_bit data_type_size(data_type) * 8, // end_bit = sizeof(KeyT) * 8 stream)); + } else if (data_type == DT_BFLOAT16) { + checkCUDA(cub::DeviceSegmentedRadixSort::SortPairsDescending( + d_temp_storage, + temp_storage_bytes, + input.get_bfloat16_ptr(), + input.get_bfloat16_ptr(), + idx, + idx, + total_ele, + batch_size, + begin_offset, + end_offset + 1, + 0, // begin_bit + data_type_size(data_type) * 8, // end_bit = sizeof(KeyT) * 8 + stream)); } else { assert(false && "input type in float and half"); } diff --git a/src/ops/sigmoid_silu_multi.cu b/src/ops/sigmoid_silu_multi.cu index 22c4d79cc..8ccde6d1b 100644 --- a/src/ops/sigmoid_silu_multi.cu +++ b/src/ops/sigmoid_silu_multi.cu @@ -154,6 +154,19 @@ void SigmoidSiluMulti::inference_kernel_wrapper( input_tensor_size, cudaMemcpyDeviceToDevice, stream)); + } else if (m->input_type[0] == DT_BFLOAT16) { + checkCUDA( + cudaMemcpyAsync(m->input_activation, + input1.get_bfloat16_ptr() + first_token_offset * in_dim, + input_tensor_size, + cudaMemcpyDeviceToDevice, + stream)); + checkCUDA(cudaMemcpyAsync( + (void *)((char *)m->input_activation + input_tensor_size), + input2.get_bfloat16_ptr() + first_token_offset * in_dim, + input_tensor_size, + cudaMemcpyDeviceToDevice, + stream)); } else { assert(false && "unsupport datatype in layernorm"); } @@ -175,6 +188,14 @@ void SigmoidSiluMulti::inference_kernel_wrapper( input1.get_half_ptr(), input2.get_half_ptr(), output.get_half_ptr()); + } else if (m->input_type[0] == DT_BFLOAT16) { + SigmoidSiluMultiKernel<<>>(input1.domain.get_volume(), + input1.get_bfloat16_ptr(), + input2.get_bfloat16_ptr(), + output.get_bfloat16_ptr()); } else { assert(false && "unsupport datatype in SigmoidSiluMulti"); } @@ -237,6 +258,18 @@ void SigmoidSiluMulti::backward_kernel_wrapper( input2_grad.get_half_ptr(), m->reset_input_grads[0], m->reset_input_grads[1]); + }else if (m->input_type[0] == DT_BFLOAT16) { + SigmoidSiluMultiBackwardKernel<<>>(output_grad.domain.get_volume(), + output_grad.get_bfloat16_ptr(), + input1.get_bfloat16_ptr(), + input2.get_bfloat16_ptr(), + input1_grad.get_bfloat16_ptr(), + input2_grad.get_bfloat16_ptr(), + m->reset_input_grads[0], + m->reset_input_grads[1]); } else { assert(false && "unsupport datatype in SigmoidSiluMulti"); } @@ -307,6 +340,20 @@ void SigmoidSiluMulti::peft_bwd_kernel_wrapper( input2_grad.get_half_ptr(), m->reset_input_grads[0], m->reset_input_grads[1]); + } else if (m->input_type[0] == DT_BFLOAT16) { + SigmoidSiluMultiBackwardKernel<<>>( + num_elements, + output_grad.get_bfloat16_ptr(), + static_cast<__ff_bfloat16 const *>(m->input_activation), + static_cast<__ff_bfloat16 const *>(m->input_activation) + + num_peft_tokens * in_dim, + input1_grad.get_bfloat16_ptr(), + input2_grad.get_bfloat16_ptr(), + m->reset_input_grads[0], + 
m->reset_input_grads[1]); } else { assert(false && "unsupport datatype in SigmoidSiluMulti"); } diff --git a/src/ops/spec_inc_multihead_self_attention.cu b/src/ops/spec_inc_multihead_self_attention.cu index 80d781f4e..ec368a71e 100644 --- a/src/ops/spec_inc_multihead_self_attention.cu +++ b/src/ops/spec_inc_multihead_self_attention.cu @@ -797,6 +797,9 @@ void SpecIncMultiHeadSelfAttention::inference_kernel_wrapper( } else if (input.data_type == DT_FLOAT) { Kernels::SpecIncMultiHeadSelfAttention::inference_kernel( m, bc, shard_id, input.get_float_ptr(), output.get_float_ptr(), stream); + } else if (input.data_type == DT_BFLOAT16) { + Kernels::SpecIncMultiHeadSelfAttention::inference_kernel( + m, bc, shard_id, input.get_bfloat16_ptr(), output.get_bfloat16_ptr(), stream); } else { assert(false && "Unspported data type"); } diff --git a/src/ops/tree_inc_multihead_self_attention.cu b/src/ops/tree_inc_multihead_self_attention.cu index ebd32b8ca..b4db2e5dc 100644 --- a/src/ops/tree_inc_multihead_self_attention.cu +++ b/src/ops/tree_inc_multihead_self_attention.cu @@ -614,6 +614,9 @@ void TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( } else if (input.data_type == DT_FLOAT) { Kernels::TreeIncMultiHeadAttention::inference_kernel( m, bc, shard_id, input.get_float_ptr(), output.get_float_ptr(), stream); + } else if (input.data_type == DT_BFLOAT16) { + Kernels::TreeIncMultiHeadAttention::inference_kernel( + m, bc, shard_id, input.get_bfloat16_ptr(), output.get_bfloat16_ptr(), stream); } else { assert(false && "Unspported data type"); } diff --git a/src/parallel_ops/combine.cc b/src/parallel_ops/combine.cc index a47835c30..90a18c0a0 100644 --- a/src/parallel_ops/combine.cc +++ b/src/parallel_ops/combine.cc @@ -76,7 +76,7 @@ Combine::Combine(FFModel &model, params.name) {} Combine::Combine(FFModel &model, - const ParallelTensor _input, + ParallelTensor const _input, int _combine_legion_dim, int _combine_degree, char const *name) @@ -437,6 +437,8 @@ void Combine::inference_task(Task const *task, forward_task_with_type(task, regions, ctx, runtime); } else if (data_type == DT_FLOAT) { forward_task_with_type(task, regions, ctx, runtime); + } else if (data_type == DT_BFLOAT16) { + forward_task_with_type<__ff_bfloat16>(task, regions, ctx, runtime); } else if (data_type == DT_DOUBLE) { forward_task_with_type(task, regions, ctx, runtime); } else if (data_type == DT_INT32) { @@ -461,6 +463,8 @@ void Combine::forward_task(Task const *task, forward_task_with_type(task, regions, ctx, runtime); } else if (data_type == DT_FLOAT) { forward_task_with_type(task, regions, ctx, runtime); + } else if (data_type == DT_BFLOAT16) { + forward_task_with_type<__ff_bfloat16>(task, regions, ctx, runtime); } else if (data_type == DT_DOUBLE) { forward_task_with_type(task, regions, ctx, runtime); } else if (data_type == DT_INT32) { @@ -525,6 +529,10 @@ bool Combine::peft_bwd_task(Task const *task, backward_kernel(output_grad.get_float_ptr(), input_grad.get_float_ptr(), output_grad.domain.get_volume()); + } else if (data_type == DT_BFLOAT16) { + backward_kernel<__ff_bfloat16>(output_grad.get_bfloat16_ptr(), + input_grad.get_bfloat16_ptr(), + output_grad.domain.get_volume()); } else if (data_type == DT_DOUBLE) { backward_kernel(output_grad.get_double_ptr(), input_grad.get_double_ptr(), diff --git a/src/parallel_ops/kernels/combine_kernels.cpp b/src/parallel_ops/kernels/combine_kernels.cpp index 2a29be1ad..f1f0eb25b 100644 --- a/src/parallel_ops/kernels/combine_kernels.cpp +++ b/src/parallel_ops/kernels/combine_kernels.cpp @@ 
-74,6 +74,10 @@ template void backward_kernel(half const *output_grad_ptr, template void backward_kernel(float const *output_grad_ptr, float *input_grad_ptr, size_t num_elements); +template void + backward_kernel<__ff_bfloat16>(__ff_bfloat16 const *output_grad_ptr, + __ff_bfloat16 *input_grad_ptr, + size_t num_elements); template void backward_kernel(double const *output_grad_ptr, double *input_grad_ptr, size_t num_elements); diff --git a/src/parallel_ops/kernels/combine_kernels.cu b/src/parallel_ops/kernels/combine_kernels.cu index 5809e2d4f..aeb6e4d1e 100644 --- a/src/parallel_ops/kernels/combine_kernels.cu +++ b/src/parallel_ops/kernels/combine_kernels.cu @@ -52,6 +52,9 @@ template void forward_kernel(half const *input_ptr, template void forward_kernel(float const *input_ptr, float *output_ptr, size_t num_elements); +template void forward_kernel<__ff_bfloat16>(__ff_bfloat16 const *input_ptr, + __ff_bfloat16 *output_ptr, + size_t num_elements); template void forward_kernel(double const *input_ptr, double *output_ptr, size_t num_elements); @@ -67,6 +70,9 @@ template void backward_kernel(half const *output_grad_ptr, template void backward_kernel(float const *output_grad_ptr, float *input_grad_ptr, size_t num_elements); +template void backward_kernel<__ff_bfloat16>(__ff_bfloat16 const *output_grad_ptr, + __ff_bfloat16 *input_grad_ptr, + size_t num_elements); template void backward_kernel(double const *output_grad_ptr, double *input_grad_ptr, size_t num_elements); diff --git a/src/parallel_ops/kernels/replicate_kernels.cpp b/src/parallel_ops/kernels/replicate_kernels.cpp index f49e0d4eb..548b8cc04 100644 --- a/src/parallel_ops/kernels/replicate_kernels.cpp +++ b/src/parallel_ops/kernels/replicate_kernels.cpp @@ -73,6 +73,9 @@ template void forward_kernel(float const *input_ptr, template void forward_kernel(half const *input_ptr, half *output_ptr, size_t num_elements); +template void forward_kernel<__ff_bfloat16>(__ff_bfloat16 const *input_ptr, + __ff_bfloat16 *output_ptr, + size_t num_elements); template __global__ void replicate_backward_kernel(float const *input_ptr, float *output_ptr, diff --git a/src/parallel_ops/kernels/replicate_kernels.cu b/src/parallel_ops/kernels/replicate_kernels.cu index 0b5c434aa..27392e5dd 100644 --- a/src/parallel_ops/kernels/replicate_kernels.cu +++ b/src/parallel_ops/kernels/replicate_kernels.cu @@ -66,6 +66,9 @@ template void forward_kernel(float const *input_ptr, template void forward_kernel(half const *input_ptr, half *output_ptr, size_t num_elements); +template void forward_kernel<__ff_bfloat16>(__ff_bfloat16 const *input_ptr, + __ff_bfloat16 *output_ptr, + size_t num_elements); template __global__ void replicate_backward_kernel(float const *input_ptr, float *output_ptr, diff --git a/src/runtime/accessor.cc b/src/runtime/accessor.cc index d3b94bf14..3eeccb90b 100644 --- a/src/runtime/accessor.cc +++ b/src/runtime/accessor.cc @@ -77,6 +77,15 @@ half const *GenericTensorAccessorR::get_half_ptr() const { } } +__ff_bfloat16 const *GenericTensorAccessorR::get_bfloat16_ptr() const { + if (data_type == DT_BFLOAT16) { + return static_cast<__ff_bfloat16 const *>(ptr); + } else { + assert(false && "Invalid Accessor Type"); + return static_cast<__ff_bfloat16 const *>(nullptr); + } +} + char const *GenericTensorAccessorR::get_byte_ptr() const { if (data_type == DT_INT4 || data_type == DT_INT8) { return static_cast(ptr); @@ -165,6 +174,15 @@ half *GenericTensorAccessorW::get_half_ptr() const { } } +__ff_bfloat16 *GenericTensorAccessorW::get_bfloat16_ptr() const { + if 
(data_type == DT_BFLOAT16) { + return static_cast<__ff_bfloat16 *>(ptr); + } else { + assert(false && "Invalid Accessor Type"); + return static_cast<__ff_bfloat16 *>(nullptr); + } +} + char *GenericTensorAccessorW::get_byte_ptr() const { if (data_type == DT_INT4 || data_type == DT_INT8) { return static_cast(ptr); @@ -271,6 +289,11 @@ GenericTensorAccessorR ptr = helperGetTensorPointerRO(region, req, fid, ctx, runtime); break; } + case DT_BFLOAT16: { + ptr = helperGetTensorPointerRO<__ff_bfloat16>( + region, req, fid, ctx, runtime); + break; + } case DT_FLOAT: { ptr = helperGetTensorPointerRO(region, req, fid, ctx, runtime); break; @@ -317,6 +340,11 @@ GenericTensorAccessorW ptr = helperGetTensorPointerWO(region, req, fid, ctx, runtime); break; } + case DT_BFLOAT16: { + ptr = helperGetTensorPointerWO<__ff_bfloat16>( + region, req, fid, ctx, runtime); + break; + } case DT_FLOAT: { ptr = helperGetTensorPointerWO(region, req, fid, ctx, runtime); break; @@ -363,6 +391,11 @@ GenericTensorAccessorW ptr = helperGetTensorPointerRW(region, req, fid, ctx, runtime); break; } + case DT_BFLOAT16: { + ptr = helperGetTensorPointerRW<__ff_bfloat16>( + region, req, fid, ctx, runtime); + break; + } case DT_FLOAT: { ptr = helperGetTensorPointerRW(region, req, fid, ctx, runtime); break; diff --git a/src/runtime/cuda_helper.cu b/src/runtime/cuda_helper.cu index 9e4df18ff..ff5205c7d 100644 --- a/src/runtime/cuda_helper.cu +++ b/src/runtime/cuda_helper.cu @@ -12,6 +12,16 @@ using Legion::coord_t; using Legion::Domain; using Legion::Rect; +#ifdef FF_USE_CUDA +#include // for __nv_bfloat16 +typedef __nv_bfloat16 __ff_bfloat16; +#elif FF_USE_HIP_CUDA +#include +typedef hip_bfloat16 __ff_bfloat16; +#else +#error "Unknown device, please make sure if CUDA is enabled" +#endif + namespace FlexFlow { #ifdef FF_USE_CUDA @@ -87,6 +97,12 @@ __host__ void relu_backward_kernel(DataType data_type, reluBackward <<>>( (half *)output_grad_ptr, (half const *)output_ptr, output_size); + } else if (data_type == DT_BFLOAT16) { + reluBackward + <<>>( + (nv_bfloat16 *)output_grad_ptr, + (nv_bfloat16 const *)output_ptr, + output_size); } else if (data_type == DT_FLOAT) { reluBackward <<>>( @@ -105,7 +121,7 @@ template __global__ void sigmoid_backward_function(DT *grad_ptr, const DT *output, size_t n) { CUDA_KERNEL_LOOP(i, n) { - grad_ptr[i] = grad_ptr[i] * output[i] * (1.0f - output[i]); + grad_ptr[i] = grad_ptr[i] * output[i] * (static_cast
(1.0f) - output[i]); } } @@ -118,6 +134,12 @@ __host__ void sigmoid_backward_kernel(DataType data_type, sigmoid_backward_function <<>>( (float *)output_grad_ptr, (float const *)output_ptr, output_size); + } else if (data_type == DT_BFLOAT16) { + sigmoid_backward_function + <<>>( + (nv_bfloat16 *)output_grad_ptr, + (nv_bfloat16 const *)output_ptr, + output_size); } else if (data_type == DT_DOUBLE) { sigmoid_backward_function <<>>( @@ -437,6 +459,8 @@ torch::Tensor return createTorchTensorFromCuda(tensor.ptr, dims); } else if (tensor.data_type == DT_HALF) { return createTorchTensorFromCuda(tensor.ptr, dims); + } else if (tensor.data_type == DT_BFLOAT16) { + return createTorchTensorFromCuda(tensor.ptr, dims); } else if (tensor.data_type == DT_DOUBLE) { return createTorchTensorFromCuda(tensor.ptr, dims); } else if (tensor.data_type == DT_INT32) { @@ -611,6 +635,8 @@ cudnnDataType_t ff_to_cudnn_datatype(DataType type) { return CUDNN_DATA_DOUBLE; case DT_INT32: return CUDNN_DATA_INT32; + case DT_BFLOAT16: + return CUDNN_DATA_BFLOAT16; default: assert(false && "Unsupported cudnn data type"); } @@ -627,6 +653,8 @@ cudaDataType_t ff_to_cuda_datatype(DataType type) { return CUDA_R_64F; case DT_INT32: return CUDA_R_32I; + case DT_BFLOAT16: + return CUDA_R_16BF; default: assert(false && "Unspoorted cuda data type"); } @@ -640,6 +668,8 @@ ncclDataType_t ff_to_nccl_datatype(DataType type) { return ncclHalf; case DT_FLOAT: return ncclFloat; + case DT_BFLOAT16: + return ncclBfloat16; case DT_DOUBLE: return ncclDouble; case DT_INT32: @@ -659,6 +689,8 @@ cudaDataType_t cudnn_to_cuda_datatype(cudnnDataType_t type) { return CUDA_R_64F; case CUDNN_DATA_INT32: return CUDA_R_32I; + case CUDNN_DATA_BFLOAT16: + return CUDA_R_16BF; default: assert(false && "Unsupported cuda data type"); } @@ -673,6 +705,8 @@ cudnnDataType_t cuda_to_cudnn_datatype(cudaDataType_t type) { return CUDNN_DATA_DOUBLE; case CUDA_R_32I: return CUDNN_DATA_INT32; + case CUDA_R_16BF: + return CUDNN_DATA_BFLOAT16; default: assert(false && "Unsupported cudnn data type"); } @@ -725,6 +759,9 @@ template __global__ void assign_kernel(half *ptr, coord_t size, half value); template __global__ void assign_kernel(float *ptr, coord_t size, float value); +template __global__ void assign_kernel<__ff_bfloat16>(__ff_bfloat16 *ptr, + coord_t size, + __ff_bfloat16 value); template __global__ void assign_kernel(double *ptr, coord_t size, double value); template __global__ void @@ -736,6 +773,10 @@ template __global__ void scale_kernel(half *ptr, coord_t size, half a, half b); template __global__ void scale_kernel(float *ptr, coord_t size, float a, float b); +template __global__ void scale_kernel<__ff_bfloat16>(__ff_bfloat16 *ptr, + coord_t size, + __ff_bfloat16 a, + __ff_bfloat16 b); template __global__ void scale_kernel(double *ptr, coord_t size, double a, double b); @@ -743,6 +784,9 @@ template __global__ void add_kernel(half *dst, half const *src, size_t size); template __global__ void add_kernel(float *dst, float const *src, size_t size); +template __global__ void add_kernel<__ff_bfloat16>(__ff_bfloat16 *dst, + __ff_bfloat16 const *src, + size_t size); template __global__ void add_kernel(double *dst, double const *src, size_t size); template __global__ void @@ -754,6 +798,9 @@ template __global__ void copy_kernel(half *dst, half const *src, coord_t size); template __global__ void copy_kernel(float *dst, float const *src, coord_t size); +template __global__ void copy_kernel<__ff_bfloat16>(__ff_bfloat16 *dst, + __ff_bfloat16 const *src, + coord_t size); 
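Note on the cuda_helper.cu hunks above: the recipe is uniform — each backend-enum converter (ff_to_cuda_datatype, ff_to_cudnn_datatype, ff_to_nccl_datatype) gains a DT_BFLOAT16 case mapping to CUDA_R_16BF / CUDNN_DATA_BFLOAT16 / ncclBfloat16, and every kernel template defined in the .cu file gains an explicit __ff_bfloat16 instantiation so that other translation units can link against it. The sketch below illustrates only that instantiation pattern; fill_kernel and the <cuda_bf16.h>/<cuda_fp16.h> includes are assumptions for a CUDA build, not FlexFlow's actual helpers.

// Sketch only: stand-in for the assign_kernel/copy_kernel instantiation pattern,
// assuming the CUDA path where __ff_bfloat16 is __nv_bfloat16.
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cstddef>

typedef __nv_bfloat16 __ff_bfloat16; // mirrors the typedef added in accessor.h

// Kernel template defined in a .cu file: other translation units only see a
// declaration, so each supported element type needs an explicit instantiation.
template <typename DT>
__global__ void fill_kernel(DT *ptr, size_t size, DT value) {
  size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
  if (i < size) {
    ptr[i] = value;
  }
}

// Explicit instantiations: bfloat16 support costs one extra line per kernel,
// matching the assign_kernel / scale_kernel / add_kernel / copy_kernel lines above.
template __global__ void fill_kernel<half>(half *ptr, size_t size, half value);
template __global__ void fill_kernel<float>(float *ptr, size_t size, float value);
template __global__ void fill_kernel<__ff_bfloat16>(__ff_bfloat16 *ptr,
                                                    size_t size,
                                                    __ff_bfloat16 value);

A missing instantiation surfaces as an undefined-symbol link error rather than a compile error, which is why the patch touches every instantiation list in combine_kernels, replicate_kernels, and cuda_helper at once.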
template __global__ void copy_kernel(double *dst, double const *src, coord_t size); template __global__ void @@ -838,6 +885,9 @@ template __host__ void template torch::Tensor createTorchTensorFromCuda(void *cudaData, std::vector const &dims); +template torch::Tensor + createTorchTensorFromCuda<__ff_bfloat16>(void *cudaData, + std::vector const &dims); template torch::Tensor createTorchTensorFromCuda(void *cudaData, std::vector const &dims); @@ -853,6 +903,9 @@ template torch::Tensor template torch::Tensor createTorchTensorFromCuda(void const *cudaData, std::vector const &dims); +template torch::Tensor + createTorchTensorFromCuda<__ff_bfloat16>(void const *cudaData, + std::vector const &dims); template torch::Tensor createTorchTensorFromCuda(void const *cudaData, std::vector const &dims); @@ -870,6 +923,8 @@ template __host__ float *copy_tensor_dev_to_host(float const *ptr, size_t num_elements); template __host__ half *copy_tensor_dev_to_host(half const *ptr, size_t num_elements); +template __host__ __ff_bfloat16 *copy_tensor_dev_to_host<__ff_bfloat16>(__ff_bfloat16 const *ptr, + size_t num_elements); template __host__ double *copy_tensor_dev_to_host(double const *ptr, size_t num_elements); template __host__ int32_t * @@ -882,6 +937,9 @@ template __host__ void copy_tensor_dev_to_host(float const *ptr, template __host__ void copy_tensor_dev_to_host(half const *ptr, half *dst, size_t num_elements); +template __host__ void copy_tensor_dev_to_host<__ff_bfloat16>(__ff_bfloat16 const *ptr, + __ff_bfloat16 *dst, + size_t num_elements); template __host__ void copy_tensor_dev_to_host(double const *ptr, double *dst, size_t num_elements); @@ -897,6 +955,9 @@ template __host__ void copy_tensor_host_to_dev(float *dst, template __host__ void copy_tensor_host_to_dev(half *dst, half const *src, size_t num_elements); +template __host__ void copy_tensor_host_to_dev<__ff_bfloat16>(__ff_bfloat16 *dst, + __ff_bfloat16 const *src, + size_t num_elements); template __host__ void copy_tensor_host_to_dev(double *dst, double const *src, size_t num_elements); diff --git a/src/runtime/ffconst_utils.cc b/src/runtime/ffconst_utils.cc index 5a7d98b4d..45467f1ee 100644 --- a/src/runtime/ffconst_utils.cc +++ b/src/runtime/ffconst_utils.cc @@ -218,6 +218,8 @@ size_t data_type_size(DataType type) { switch (type) { case DT_HALF: return sizeof(half); + case DT_BFLOAT16: + return sizeof(__ff_bfloat16); case DT_FLOAT: return sizeof(float); case DT_DOUBLE: diff --git a/src/runtime/file_loader.cc b/src/runtime/file_loader.cc index e87594b00..d056c1a92 100644 --- a/src/runtime/file_loader.cc +++ b/src/runtime/file_loader.cc @@ -30,12 +30,12 @@ FileDataLoader::FileDataLoader(std::string _prompts_filepath, size_t _hidden_dim, size_t _qkv_inner_dim, int _tensor_parallelism_degree, - bool _use_full_precision) + DataType _data_type) : prompts_filepath(_prompts_filepath), weights_folder(_weights_folder), num_heads(_num_heads), num_kv_heads(_num_kv_heads), hidden_dim(_hidden_dim), qkv_inner_dim(_qkv_inner_dim), tensor_parallelism_degree(_tensor_parallelism_degree), - use_full_precision(_use_full_precision){}; + data_type(_data_type){}; BatchConfig::TokenId *FileDataLoader::generate_requests(int num, int length) { @@ -366,8 +366,8 @@ void load_attention_weights_quantized(char *ptr, size_t qkv_inner_dim, std::string layer_name, std::string weights_folder, - DataType data_type, - bool use_full_precision) { + DataType quantized_data_type, + DataType dequantized_data_type) { std::string q_file = layer_name + ".q_proj.weight"; std::string k_file 
= layer_name + ".k_proj.weight"; std::string v_file = layer_name + ".v_proj.weight"; @@ -406,7 +406,7 @@ void load_attention_weights_quantized(char *ptr, } assert(partial_size == host_array.size()); - size_t one_head_size = data_type == DT_INT8 + size_t one_head_size = quantized_data_type == DT_INT8 ? hidden_dim * (hidden_dim / num_heads) : hidden_dim * (hidden_dim / num_heads) / 2; @@ -414,7 +414,7 @@ void load_attention_weights_quantized(char *ptr, for (int i = 0; i < num_heads; i++) { size_t start_index = i * one_head_size * 4 + file_index * one_head_size; for (size_t j = start_index; j < start_index + one_head_size; j++) { - if (data_type == DT_INT4) { + if (quantized_data_type == DT_INT4) { char v1 = host_array.at(data_index); char v2 = host_array.at(data_index + 1); ptr[j] = (v2 & 0XF) | (v1 << 4); @@ -431,7 +431,7 @@ void load_attention_weights_quantized(char *ptr, // load scale and offset to the end of weight tensor // the layout is like |values * 32 heads|offset|scale| - size_t offset = data_type == DT_INT8 ? one_weight_file_size * 4 + size_t offset = quantized_data_type == DT_INT8 ? one_weight_file_size * 4 : (one_weight_file_size * 4) / 2; for (auto filename : weight_filenames) { std::cout << "Loading weight file " << filename << std::endl; @@ -448,7 +448,7 @@ void load_attention_weights_quantized(char *ptr, } assert(in.good() && "incorrect weight file path"); - if (use_full_precision) { + if (dequantized_data_type == DT_FLOAT) { // float std::vector host_array(partial_size); size_t loaded_data_size = sizeof(float) * partial_size; @@ -467,7 +467,7 @@ void load_attention_weights_quantized(char *ptr, *(float *)(ptr + offset) = v; offset += sizeof(float); } - } else { + } else if (dequantized_data_type == DT_HALF) { // half std::vector host_array(partial_size); size_t loaded_data_size = sizeof(half) * partial_size; @@ -485,6 +485,27 @@ void load_attention_weights_quantized(char *ptr, *(half *)(ptr + offset) = v; offset += sizeof(half); } + } else if (dequantized_data_type == DT_BFLOAT16) { + // bfloat16 + std::vector<__ff_bfloat16> host_array(partial_size); + size_t loaded_data_size = sizeof(__ff_bfloat16) * partial_size; + in.seekg(0, in.end); + in.seekg(0, in.beg); + in.read((char *)host_array.data(), loaded_data_size); + size_t in_get_size = in.gcount(); + + if (in_get_size != loaded_data_size) { + std::cout << "load data error"; + return; + } + assert(partial_size == host_array.size()); + + for (auto v : host_array) { + *(__ff_bfloat16 *)(ptr + offset) = v; + offset += sizeof(__ff_bfloat16); + } + } else { + assert(false && "Unsupported data type"); } } } @@ -493,27 +514,27 @@ void load_attention_weights_quantized(char *ptr, void load_from_quantized_file(char *ptr, size_t size, std::string filename, - DataType data_type, - bool use_full_precision) { - assert(data_type == DT_INT4 || data_type == DT_INT8); + DataType quantized_data_type, + DataType dequantized_data_type) { + assert(quantized_data_type == DT_INT4 || quantized_data_type == DT_INT8); std::string value_file = filename; std::string offset_file = filename + "_offset"; std::string scaling_file = filename + "_scale"; size_t value_size = 0, offset_size = 0, scaling_size = 0; - if (data_type == DT_INT4) { + if (quantized_data_type == DT_INT4) { // float/half + 4bit quantization // size1 = volume / 2, size2 = volume / 32 * (sizeof(DT)), size3 = size2 - value_size = 2 * (use_full_precision ? (size * 2 / 3) : (size * 4 / 5)); - offset_size = use_full_precision ? (size / 6) : (size / 10); - scaling_size = use_full_precision ? 
(size / 6) : (size / 10); - } else if (data_type == DT_INT8) { + value_size = 2 * (dequantized_data_type == DT_FLOAT ? (size * 2 / 3) : (size * 4 / 5)); + offset_size = dequantized_data_type == DT_FLOAT ? (size / 6) : (size / 10); + scaling_size = dequantized_data_type == DT_FLOAT ? (size / 6) : (size / 10); + } else if (quantized_data_type == DT_INT8) { // float/half + 8bit quantization // size1 = volume * 1, size2 = volume / 32 * (sizeof(DT)), size3 = size2 - value_size = use_full_precision ? (size * 4 / 5) : (size * 8 / 9); - offset_size = use_full_precision ? (size / 10) : (size / 18); - scaling_size = use_full_precision ? (size / 10) : (size / 18); + value_size = dequantized_data_type == DT_FLOAT ? (size * 4 / 5) : (size * 8 / 9); + offset_size = dequantized_data_type == DT_FLOAT ? (size / 10) : (size / 18); + scaling_size = dequantized_data_type == DT_FLOAT ? (size / 10) : (size / 18); } std::vector quantized_files = { @@ -549,7 +570,7 @@ void load_from_quantized_file(char *ptr, // normal size_t idx = 0; while (idx < host_array.size()) { - if (data_type == DT_INT4) { + if (quantized_data_type == DT_INT4) { // pack 2 elements into one byte char v1 = host_array.at(idx); char v2 = host_array.at(idx + 1); @@ -560,7 +581,7 @@ void load_from_quantized_file(char *ptr, ptr[data_index++] = host_array.at(idx++); } } - } else if (use_full_precision) { + } else if (dequantized_data_type == DT_FLOAT) { // load offset/scale in float type; size = quantized_sizes.at(file_idx); std::vector host_array(size / sizeof(float)); @@ -581,8 +602,7 @@ void load_from_quantized_file(char *ptr, *(float *)(ptr + data_index) = v; data_index += sizeof(float); } - - } else { + } else if (dequantized_data_type == DT_HALF) { // load offset/scale in half type; size = quantized_sizes.at(file_idx); std::vector host_array(size / sizeof(half)); @@ -603,6 +623,29 @@ void load_from_quantized_file(char *ptr, *(half *)(ptr + data_index) = v; data_index += sizeof(half); } + } else if (dequantized_data_type == DT_BFLOAT16) { + // load offset/scale in bfloat16 type; + size = quantized_sizes.at(file_idx); + std::vector<__ff_bfloat16> host_array(size / sizeof(__ff_bfloat16)); + size_t loaded_data_size = size; + in.seekg(0, in.end); + in.seekg(0, in.beg); + in.read((char *)host_array.data(), loaded_data_size); + + size_t in_get_size = in.gcount(); + if (in_get_size != loaded_data_size) { + std::cout << "load weight data error " << in_get_size << ", " + << loaded_data_size << ", " << sizeof(__ff_bfloat16) << std::endl; + return; + } + assert(size / sizeof(__ff_bfloat16) == host_array.size()); + // normal + for (auto v : host_array) { + *(__ff_bfloat16 *)(ptr + data_index) = v; + data_index += sizeof(__ff_bfloat16); + } + } else { + assert(false && "Unsupported data type"); } in.close(); file_idx++; @@ -615,7 +658,7 @@ void FileDataLoader::load_quantization_weight(FFModel *ff, size_t volume, size_t num_replicas, char *weight, - DataType data_type, + DataType quantized_data_type, Domain weight_domain) { size_t volume_ = 1; std::vector dims_vec; @@ -639,8 +682,8 @@ void FileDataLoader::load_quantization_weight(FFModel *ff, qkv_inner_dim, weight_filename, weights_folder, - data_type, - use_full_precision); + quantized_data_type, + data_type); } // else { // load_attention_bias_quantized(data, @@ -661,8 +704,8 @@ void FileDataLoader::load_quantization_weight(FFModel *ff, load_from_quantized_file(data, volume, join_path({weights_folder, weight_filename}), - data_type, - use_full_precision); + quantized_data_type, + data_type); } char *ptr 
= weight; @@ -831,6 +874,16 @@ void FileDataLoader::load_weight_task( weight_domain); break; } + case DT_BFLOAT16: { + args->loader->load_single_weight_tensor<__ff_bfloat16>(args->ff, + args->layer, + args->weight_idx, + args->volume, + args->num_replicas, + weight.get_bfloat16_ptr(), + weight_domain); + break; + } case DT_INT4: case DT_INT8: { args->loader->load_quantization_weight(args->ff, @@ -865,7 +918,8 @@ void FileDataLoader::load_weights_parallel(FFModel *ff, } if (weight->data_type != DT_FLOAT && weight->data_type != DT_HALF && - weight->data_type != DT_INT4 && weight->data_type != DT_INT8) { + weight->data_type != DT_BFLOAT16 && weight->data_type != DT_INT4 && + weight->data_type != DT_INT8) { assert(false && "Unsupported data type"); } diff --git a/src/runtime/hip_helper.cpp b/src/runtime/hip_helper.cpp index 057be8f44..bb58d7cd4 100644 --- a/src/runtime/hip_helper.cpp +++ b/src/runtime/hip_helper.cpp @@ -562,6 +562,8 @@ miopenDataType_t ff_to_cudnn_datatype(DataType type) { return miopenFloat; case DT_INT32: return miopenInt32; + case DT_BFLOAT16: + return miopenBFloat16; default: assert(false && "Unsupported cudnn data type"); } @@ -576,6 +578,8 @@ hipblasDatatype_t ff_to_cuda_datatype(DataType type) { return HIPBLAS_R_64F; case DT_INT32: return HIPBLAS_R_32I; + case DT_BFLOAT16: + return HIPBLAS_R_16BF; default: assert(false && "Unspoorted cuda data type"); } @@ -592,6 +596,8 @@ ncclDataType_t ff_to_nccl_datatype(DataType type) { return ncclDouble; case DT_INT32: return ncclInt; + case DT_BFLOAT16: + return ncclBFloat16; default: assert(false && "Unspoorted nccl data type"); } @@ -765,6 +771,9 @@ template __host__ float *copy_tensor_dev_to_host(float const *ptr, size_t num_elements); template __host__ half *copy_tensor_dev_to_host(half const *ptr, size_t num_elements); +template __host__ __ff_bfloat16 * + copy_tensor_dev_to_host<__ff_bfloat16>(__ff_bfloat16 const *ptr, + size_t num_elements); template __host__ double *copy_tensor_dev_to_host(double const *ptr, size_t num_elements); template __host__ int32_t * @@ -777,6 +786,8 @@ template __host__ void copy_tensor_dev_to_host(float const *ptr, template __host__ void copy_tensor_dev_to_host(half const *ptr, half *dst, size_t num_elements); +template __host__ void copy_tensor_dev_to_host<__ff_bfloat16>( + __ff_bfloat16 const *ptr, __ff_bfloat16 *dst, size_t num_elements); template __host__ void copy_tensor_dev_to_host(double const *ptr, double *dst, size_t num_elements); @@ -792,6 +803,8 @@ template __host__ void copy_tensor_host_to_dev(float *dst, template __host__ void copy_tensor_host_to_dev(half *dst, half const *src, size_t num_elements); +template __host__ void copy_tensor_host_to_dev<__ff_bfloat16>( + __ff_bfloat16 *dst, __ff_bfloat16 const *src, size_t num_elements); template __host__ void copy_tensor_host_to_dev(double *dst, double const *src, size_t num_elements); diff --git a/src/runtime/initializer_kernel.cu b/src/runtime/initializer_kernel.cu index b6629ec90..c9cf4a97e 100644 --- a/src/runtime/initializer_kernel.cu +++ b/src/runtime/initializer_kernel.cu @@ -241,6 +241,12 @@ void ZeroInitializer::init_task(Task const *task, assign_kernel <<>>( w, domain.get_volume(), 0.0f); + } else if (meta->data_types[i] == DT_BFLOAT16) { + __ff_bfloat16 *w = helperGetTensorPointerWO<__ff_bfloat16>( + regions[i], task->regions[i], FID_DATA, ctx, runtime); + assign_kernel<__ff_bfloat16> + <<>>( + w, domain.get_volume(), 0.0f); } else if (meta->data_types[i] == DT_INT32) { int32_t *w = helperGetTensorPointerWO( regions[i], 
task->regions[i], FID_DATA, ctx, runtime); diff --git a/src/runtime/model.cc b/src/runtime/model.cc index 7bc818d12..48dcf2962 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -2105,6 +2105,9 @@ void FFModel::map_tensor_with_dim2(ParallelTensor tensor, case DT_INT8: allocator.allocate_field(sizeof(char), FID_DATA); break; + case DT_BFLOAT16: + allocator.allocate_field(sizeof(__ff_bfloat16), FID_DATA); + break; default: assert(false); } diff --git a/src/runtime/peft_weight_allocator.cc b/src/runtime/peft_weight_allocator.cc index 6d5815449..f1f67ecfe 100644 --- a/src/runtime/peft_weight_allocator.cc +++ b/src/runtime/peft_weight_allocator.cc @@ -268,6 +268,29 @@ void PEFTMemoryManager::load_peft_model(LoraLinearWeight &weight, lora_B_num_shards, shard_id, w1_filepath); + } else if (dt == DT_BFLOAT16) { + std::cout << "Loading LORA weight " << lora_layername_substr + "_A.weight" + << ", num_rows: " << lora_A_num_rows + << ", num_cols: " << lora_A_num_cols + << ", num_shards: " << lora_A_num_shards + << ", shard_id: " << shard_id << std::endl; + load_peft_from_file((__ff_bfloat16 *)weight.w0_ptr, + lora_A_num_rows, + lora_A_num_cols, + lora_A_num_shards, + shard_id, + w0_filepath); + std::cout << "Loading LORA weight " << lora_layername_substr + "_B.weight" + << ", num_rows: " << lora_B_num_rows + << ", num_cols: " << lora_B_num_cols + << ", num_shards: " << lora_B_num_shards + << ", shard_id: " << shard_id << std::endl; + load_peft_from_file((__ff_bfloat16 *)weight.w1_ptr, + lora_B_num_rows, + lora_B_num_cols, + lora_B_num_shards, + shard_id, + w1_filepath); } else { assert(false && "Data type not supported"); } diff --git a/src/runtime/peft_weight_allocator.cu b/src/runtime/peft_weight_allocator.cu index 3c4ea91db..4ab3761d0 100644 --- a/src/runtime/peft_weight_allocator.cu +++ b/src/runtime/peft_weight_allocator.cu @@ -1,5 +1,3 @@ - - #include "flexflow/ops/kernels/decompress_kernels.h" #include "flexflow/utils/cuda_helper.h" #include "flexflow/utils/peft_weight_allocator.h" @@ -28,8 +26,10 @@ void lora_init_kernel(LoraLinearWeight const &weight, std::vector
lora_a_random_init(w0_num_elements); for (auto &num : lora_a_random_init) { float num_float = dis_lora_a(gen); - if (std::is_same::value) { + if constexpr (std::is_same::value) { num = __float2half(num_float); + } else if constexpr (std::is_same::value) { + num = __float2bfloat16(num_float); } else { num = num_float; } @@ -72,6 +72,9 @@ void init_peft_weight_wrapper(LoraLinearWeight const &weight, lora_init_kernel(weight, in_dim, out_dim, rank, seed, stream); } else if (dt == DT_HALF) { lora_init_kernel(weight, in_dim, out_dim, rank, seed, stream); + } else if (dt == DT_BFLOAT16) { + lora_init_kernel<__ff_bfloat16>( + weight, in_dim, out_dim, rank, seed, stream); } else { assert(false && "Unsupported data type"); } diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index e639653bb..6c7b7b6fe 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -174,6 +174,8 @@ bool RequestManager::load_request_token_ids(Request &request) { << request.dataset.size() << std::endl; } else { std::ifstream file_handle(request.peft_finetuning_info.dataset_filepath); + // print the dataset file path + std::cout << "Dataset file path: " << request.peft_finetuning_info.dataset_filepath << std::endl; assert(file_handle.good() && "Dataset file does not exist."); nlohmann::ordered_json dataset_json = nlohmann::ordered_json::parse(file_handle, diff --git a/tests/align/peft_flash_attn/peft_flash_debug_note b/tests/align/peft_flash_attn/peft_flash_debug_note index 7279b89f1..8a2d629ec 100644 --- a/tests/align/peft_flash_attn/peft_flash_debug_note +++ b/tests/align/peft_flash_attn/peft_flash_debug_note @@ -31,4 +31,7 @@ set environment LD_LIBRARY_PATH=/root/flexflow-serve/build:/root/flexflow-serve/ run # single C++ tests: -./build/inference/peft/peft -ll:gpu 2 -ll:cpu 4 -ll:util 4 -tensor-parallelism-degree 2 -ll:fsize 8192 -ll:zsize 12000 --max-requests-per-batch 1 --max-sequence-length 128 --max-tokens-per-batch 128 -llm-model JackFram/llama-160m -finetuning-dataset ./inference/prompt/peft_dataset.json -peft-model goliaro/llama-160m-lora -enable-peft --inference-debugging \ No newline at end of file +# fp16 +./build/inference/peft/peft -ll:gpu 2 -ll:cpu 4 -ll:util 4 -tensor-parallelism-degree 2 -ll:fsize 8192 -ll:zsize 12000 --max-requests-per-batch 1 --max-sequence-length 128 --max-tokens-per-batch 128 -llm-model JackFram/llama-160m -finetuning-dataset ./inference/prompt/peft_dataset.json -peft-model goliaro/llama-160m-lora -enable-peft --inference-debugging +# bf16 +./build/inference/peft/peft -ll:gpu 2 -ll:cpu 4 -ll:util 4 -tensor-parallelism-degree 2 -ll:fsize 8192 -ll:zsize 12000 --max-requests-per-batch 1 --max-sequence-length 128 --max-tokens-per-batch 128 -llm-model JackFram/llama-160m -finetuning-dataset ./inference/prompt/peft_dataset.json -peft-model goliaro/llama-160m-lora -enable-peft --inference-debugging --use-bf16-precision diff --git a/tests/peft/hf_finetune.py b/tests/peft/hf_finetune.py index 6fe4c8831..3896072fb 100644 --- a/tests/peft/hf_finetune.py +++ b/tests/peft/hf_finetune.py @@ -47,6 +47,9 @@ def main(): parser.add_argument( "--use-full-precision", action="store_true", help="Use full precision" ) + parser.add_argument( + "--use-bfloat16-precision", action="store_true", help="Use bfloat16 precision" + ) parser.add_argument("--output-dir", type=str, default="") parser.add_argument("--publish-peft-with-id", type=str, default="") parser.add_argument( @@ -63,7 +66,10 @@ def main(): # Set default tensor type depending on argument indicating 
the float type to use if not args.use_full_precision: - torch.set_default_dtype(torch.float16) + if args.use_bfloat16_precision: + torch.set_default_dtype(torch.bfloat16) + else: + torch.set_default_dtype(torch.float16) else: torch.set_default_dtype(torch.float32) diff --git a/tests/peft/hf_utils.py b/tests/peft/hf_utils.py index 76e7fcd6e..477c4db39 100644 --- a/tests/peft/hf_utils.py +++ b/tests/peft/hf_utils.py @@ -272,14 +272,18 @@ def build_peft_config(args, finetuning=False): return peft_config -def prepare_model_for_lora_finetuning(model, use_full_precision=False, save_peft_tensors=False): +def prepare_model_for_lora_finetuning(model, use_bfloat16_precision=False, use_full_precision=False, save_peft_tensors=False): # Freeze all layers except the LORA ones. Cast small layers to full precision for stability for name, param in model.named_parameters(): if "lora" not in name: param.requires_grad = False # freeze the model - train adapters later else: param.requires_grad = True - if not use_full_precision: + if use_bfloat16_precision: + param.data = param.data.to(torch.bfloat16) + elif use_full_precision: + param.data = param.data.to(torch.float32) + else: param.data = param.data.to(torch.float16) # cast to fp16 for speed # if param.ndim == 1: # # cast the small parameters (e.g. layernorm) to fp32 for stability @@ -294,13 +298,13 @@ def build_peft_model(args, peft_config): # Load base model, and apply the PEFT layer model = AutoModelForCausalLM.from_pretrained( peft_config.base_model_name_or_path, - torch_dtype=torch.float32 if args.use_full_precision else torch.float16, + torch_dtype=torch.bfloat16 if args.use_bfloat16_precision else torch.float32 if args.use_full_precision else torch.float16, device_map="auto", attn_implementation="eager", ) model = PeftModel.from_pretrained(model, args.peft_model_id, config=peft_config, - torch_dtype=torch.float32 if args.use_full_precision else torch.float16,) - model = prepare_model_for_lora_finetuning(model, args.use_full_precision, args.save_peft_tensors) + torch_dtype=torch.bfloat16 if args.use_bfloat16_precision else torch.float32 if args.use_full_precision else torch.float16,) + model = prepare_model_for_lora_finetuning(model, args.use_bfloat16_precision, args.use_full_precision, args.save_peft_tensors) return model @@ -308,7 +312,7 @@ def get_peft_tokenizer(args, peft_config): # Get Tokenizer tokenizer = AutoTokenizer.from_pretrained( peft_config.base_model_name_or_path, - torch_dtype=torch.float32 if args.use_full_precision else torch.float16, + torch_dtype=torch.bfloat16 if args.use_bfloat16_precision else torch.float32 if args.use_full_precision else torch.float16, ) if tokenizer.pad_token is None: tokenizer.pad_token = "[PAD]" diff --git a/tests/peft_test.sh b/tests/peft_test.sh index 3c041835e..8a0d260a5 100755 --- a/tests/peft_test.sh +++ b/tests/peft_test.sh @@ -17,6 +17,7 @@ TP_DEGREE=${TP_DEGREE:-4} PP_DEGREE=${PP_DEGREE:-1} FF_CACHE_PATH=${FF_CACHE_PATH:-"~/.cache/flexflow"} FULL_PRECISION=${FULL_PRECISION:-false} +BFLOAT16_PRECISION=${BFLOAT16_PRECISION:-true} FUSION=${FUSION:-false} # false because we save the debugging tensors in lora_linear.cc LEARNING_RATE=${LEARNING_RATE:-0.001} NUM_GPUS=$((TP_DEGREE * PP_DEGREE)) @@ -46,10 +47,11 @@ export LEGION_BACKTRACE=1 python ./inference/utils/download_peft_model.py "${MODEL_NAME}" if [ "$FULL_PRECISION" = "true" ]; then full_precision_flag="--use-full-precision"; else full_precision_flag=""; fi +if [ "$BFLOAT16_PRECISION" = "true" ]; then 
bfloat16_precision_flag="--use-bfloat16-precision"; else bfloat16_precision_flag=""; fi if [ "$FUSION" = "true" ]; then fusion_flag="--fusion"; else fusion_flag=""; fi # Run PEFT in Huggingface to get ground truth tensors -eval python ./tests/peft/hf_finetune.py --peft-model-id "${MODEL_NAME}" --save-peft-tensors "${full_precision_flag}" -lr "${LEARNING_RATE}" +eval python ./tests/peft/hf_finetune.py --peft-model-id "${MODEL_NAME}" --save-peft-tensors "${full_precision_flag}" "${bfloat16_precision_flag}" -lr "${LEARNING_RATE}" # Python test echo "Python test" @@ -90,6 +92,7 @@ python ./tests/peft/peft_alignment_test.py -m "${MODEL_NAME}" -tp "${TP_DEGREE}" # C++ test echo "C++ test" +if [ "$BFLOAT16_PRECISION" = "true" ]; then bfloat16_precision_flag="--use-bf16-precision"; else bfloat16_precision_flag=""; fi ./build/inference/peft/peft \ -ll:gpu ${NUM_GPUS} -ll:cpu 4 -ll:util 4 \ -tensor-parallelism-degree "${TP_DEGREE}" \ @@ -101,10 +104,12 @@ echo "C++ test" -finetuning-dataset ./inference/prompt/peft_dataset.json \ -peft-model "$MODEL_NAME" \ -enable-peft \ - "${full_precision_flag}" "${fusion_flag}" --inference-debugging + ${full_precision_flag} \ + ${fusion_flag} \ + ${bfloat16_precision_flag} \ + --inference-debugging + -# Check alignment -python ./tests/peft/peft_alignment_test.py -m "${MODEL_NAME}" -tp "${TP_DEGREE}" -lr "${LEARNING_RATE}" # Print succeess message echo "" diff --git a/triton/src/operators/binary.cu b/triton/src/operators/binary.cu index 7410d098f..e0e031e51 100644 --- a/triton/src/operators/binary.cu +++ b/triton/src/operators/binary.cu @@ -188,6 +188,16 @@ __host__ output_ptr)); break; } + case DataType::DT_BFLOAT16: { + __ff_bfloat16 alpha0 = 1.f, + alpha1 = (args->op_type == OperatorType::OP_EW_SUB) ? -1.f : 1.f, + beta = 0.f; + CHECK_CUDNN(cudnnOpTensor( + args->cudnn, args->opDesc, &alpha0, args->input0Tensor, input0_ptr, + &alpha1, args->input1Tensor, input1_ptr, &beta, args->outputTensor, + output_ptr)); + break; + } case DataType::DT_INT8: { int8_t alpha0 = 1, alpha1 = (args->op_type == OperatorType::OP_EW_SUB) ? 
-1 : 1, @@ -233,6 +243,13 @@ __host__ (float*)output_ptr, alpha, beta, args->op_type, num_elements); break; } + case DataType::DT_BFLOAT16: { + __ff_bfloat16 alpha = 1.f, beta = 0.f; + binary_forward_bfloat16<<>>( + (const __ff_bfloat16*)input0_ptr, (const __ff_bfloat16*)input1_ptr, + (__ff_bfloat16*)output_ptr, alpha, beta, args->op_type, num_elements); + break; + } case DataType::DT_INT8: { int8_t alpha = 1, beta = 0; binary_forward_int8<<>>( diff --git a/triton/src/operators/matmul.cc b/triton/src/operators/matmul.cc index e972665ad..7bda013c2 100644 --- a/triton/src/operators/matmul.cc +++ b/triton/src/operators/matmul.cc @@ -806,8 +806,9 @@ MatMul::forward_gpu( const uint8_t* in2_local = in2_ptr + in2_offset; uint8_t* out_local = out_ptr + out_offset; switch (args->out_datatype) { - // Use 32-bit intermediate for 16-bit float + // Use 32-bit intermediate for 16-bit float and bfloat16 case DT_HALF: + case DT_BFLOAT16: case DT_FLOAT: { float alpha = 1.f, beta = 0.f; CHECK_CUBLAS(cublasGemmStridedBatchedEx( @@ -853,8 +854,9 @@ MatMul::forward_gpu( // This is the easy case where there are no broadcasts // so we can do the full batch matmul in a single call switch (args->out_datatype) { - // Use 32-bit intermediate for 16-bit float + // Use 32-bit intermediate for 16-bit float and bfloat16 case DT_HALF: + case DT_BFLOAT16: case DT_FLOAT: { float alpha = 1.f, beta = 0.f; CHECK_CUBLAS(cublasGemmStridedBatchedEx( diff --git a/triton/src/operators/unary.cc b/triton/src/operators/unary.cc index c0cdf8a37..cd905f23a 100644 --- a/triton/src/operators/unary.cc +++ b/triton/src/operators/unary.cc @@ -38,6 +38,10 @@ UnaryOperator::UnaryOperator( memcpy(&scalar.half_value, scalar_value, sizeof(__half)); break; } + case DT_BFLOAT16: { + memcpy(&scalar.bfloat16_value, scalar_value, sizeof(__ff_bfloat16)); + break; + } case DT_FLOAT: { memcpy(&scalar.float_value, scalar_value, sizeof(float)); break; @@ -111,6 +115,10 @@ UnaryOperator::Load(Realm::Processor proc) proc_args.scalar.half_value = scalar.half_value; break; } + case DT_BFLOAT16: { + proc_args.scalar.bfloat16_value = scalar.bfloat16_value; + break; + } case DT_FLOAT: { proc_args.scalar.float_value = scalar.float_value; break; diff --git a/triton/src/operators/unary.cu b/triton/src/operators/unary.cu index ef4c77f8e..d30363065 100644 --- a/triton/src/operators/unary.cu +++ b/triton/src/operators/unary.cu @@ -281,6 +281,11 @@ forward_cast( gpu_forward_cast<<>>( (const T*)input_ptr, (float*)output_ptr, num_elements); break; + } + case DT_BFLOAT16: { + gpu_forward_cast<<>>( + (const T*)input_ptr, (__ff_bfloat16*)output_ptr, num_elements); + break; } case DT_DOUBLE: { gpu_forward_cast<<>>( @@ -349,6 +354,11 @@ __host__ forward_cast<__half>( args->casttype, stream, input_ptr, output_ptr, num_elements); break; + } + case DT_BFLOAT16: { + forward_cast<__ff_bfloat16>( + args->casttype, stream, input_ptr, output_ptr, num_elements); + break; } case DT_FLOAT: { forward_cast( @@ -437,6 +447,13 @@ __host__ (const __half*)input_ptr, (__half*)output_ptr, alpha, beta, args->scalar.half_value, args->op_type, num_elements); break; + } + case DT_BFLOAT16: { + __ff_bfloat16 alpha = 1.f, beta = 0.f; + unary_forward_bfloat16<<>>( + (const __ff_bfloat16*)input_ptr, (__ff_bfloat16*)output_ptr, alpha, beta, + args->scalar.bfloat16_value, args->op_type, num_elements); + break; } case DT_FLOAT: { float alpha = 1.f, beta = 0.f; diff --git a/triton/src/operators/unary.h b/triton/src/operators/unary.h index 33723d90a..1965c0656 100644 --- a/triton/src/operators/unary.h +++ 
b/triton/src/operators/unary.h
@@ -39,6 +39,7 @@ struct UnaryArgs : public OperatorArgs {
   union {
     int8_t int8_value;
     __half half_value;
+    __ff_bfloat16 bfloat16_value;
     float float_value;
     double double_value;
   } scalar;
@@ -94,6 +95,7 @@ class UnaryOperator : public Operator {
   union {
     int8_t int8_value;
     __half half_value;
+    __ff_bfloat16 bfloat16_value;
     float float_value;
     double double_value;
   } scalar;
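Taken as a whole, the operator and runtime hunks in this patch repeat a single dispatch idiom: a runtime check on the tensor's DataType selects the matching pointer accessor (get_half_ptr, get_float_ptr, and now get_bfloat16_ptr) and forwards to a templated kernel or wrapper, with an assert(false && "Unsupported data type") fallback. The condensed sketch below illustrates that idiom; Accessor, copy_elements, and copy_wrapper are illustrative stand-ins rather than FlexFlow's GenericTensorAccessor API, and the CUDA typedef of __ff_bfloat16 is assumed.

// Sketch only: the runtime-dtype dispatch idiom this patch extends with DT_BFLOAT16.
#include <cassert>
#include <cstddef>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

typedef __nv_bfloat16 __ff_bfloat16; // CUDA path, as in accessor.h

enum DataType { DT_HALF, DT_FLOAT, DT_BFLOAT16 }; // illustrative subset; real values live in ffconst.h

// Stand-in for GenericTensorAccessorR/W: a raw pointer tagged with its runtime dtype.
struct Accessor {
  DataType data_type;
  void *ptr;
  template <typename T>
  T *get() const { return static_cast<T *>(ptr); }
};

// Templated worker, analogous to forward_kernel<T> / backward_kernel<T>.
template <typename T>
void copy_elements(T const *src, T *dst, size_t n) {
  for (size_t i = 0; i < n; ++i) {
    dst[i] = src[i];
  }
}

// The wrapper every operator repeats: one else-if branch per dtype. This patch
// adds the DT_BFLOAT16 branch that routes through the bfloat16 accessor.
void copy_wrapper(Accessor const &src, Accessor const &dst, size_t n) {
  assert(src.data_type == dst.data_type);
  if (src.data_type == DT_HALF) {
    copy_elements(src.get<half>(), dst.get<half>(), n);
  } else if (src.data_type == DT_FLOAT) {
    copy_elements(src.get<float>(), dst.get<float>(), n);
  } else if (src.data_type == DT_BFLOAT16) {
    copy_elements(src.get<__ff_bfloat16>(), dst.get<__ff_bfloat16>(), n);
  } else {
    assert(false && "Unsupported data type");
  }
}

Adding a new precision then reduces to one enum value (DT_BFLOAT16), one accessor method per accessor class, one else-if branch per wrapper, and one explicit instantiation per kernel translation unit — which is essentially the shape of every hunk in this diff.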