Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions include/flexflow/accessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,25 @@

#if defined(FF_USE_CUDA)
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#elif defined(FF_USE_HIP_CUDA)
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#elif defined(FF_USE_HIP_ROCM)
#include <hip/hip_fp16.h>
#include <hip_bfloat16.h>
#endif

// using namespace Legion;

namespace FlexFlow {

#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
typedef __nv_bfloat16 __ff_bfloat16;
#elif defined(FF_USE_HIP_ROCM)
typedef hip_bfloat16 __ff_bfloat16;
#endif

template <typename FT, int N, typename T = Legion::coord_t>
using AccessorRO =
Legion::FieldAccessor<READ_ONLY, FT, N, T, Realm::AffineAccessor<FT, N, T>>;
Expand Down Expand Up @@ -61,6 +70,7 @@ class GenericTensorAccessorW {
float *get_float_ptr() const;
double *get_double_ptr() const;
half *get_half_ptr() const;
__ff_bfloat16 *get_bfloat16_ptr() const;
char *get_byte_ptr() const;
DataType data_type;
Legion::Domain domain;
Expand All @@ -80,6 +90,7 @@ class GenericTensorAccessorR {
float const *get_float_ptr() const;
double const *get_double_ptr() const;
half const *get_half_ptr() const;
__ff_bfloat16 const *get_bfloat16_ptr() const;
char const *get_byte_ptr() const;
DataType data_type;
Legion::Domain domain;
Expand Down
9 changes: 8 additions & 1 deletion include/flexflow/ffconst.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ enum DataType {
DT_INT4 = 46,
DT_INT8 = 47,
DT_NONE = 49,
DT_BFLOAT16 = 50, // placeholder for bfloat16
DT_BFLOAT16 = 50,
};

enum LossType {
Expand Down Expand Up @@ -287,4 +287,11 @@ enum {
PEFT_MODEL_ID_FIRST_VALID = 6000000,
PEFT_MODEL_ID_LAST_VALID = 6999999
};

enum InferencePrecision {
INFERENCE_FLOAT = 800,
INFERENCE_HALF = 801,
INFERENCE_BFLOAT16 = 802,
};

#endif // _FLEXFLOW_CONST_H_
Loading
Loading