Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions kernels/csrc/utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,22 @@ struct m16_n16_k32_c_fragment {
};

/// Compile-time map from a C++ element type to the dtype mnemonic used in PTX
/// instruction strings (e.g. "f16", "bf16", "e4m3" in mma/ldmatrix asm).
/// The unspecialized template deliberately yields "unknown_dtype" so that an
/// unsupported instantiation produces a visibly broken mnemonic in the
/// generated PTX instead of silently matching a wrong type.
/// `inline` (C++17) gives each header-scope definition external linkage with a
/// single program-wide instance, avoiding per-TU copies / ODR issues when this
/// header is included from multiple translation units.
template<typename>
inline constexpr char ptx_type_name[] = "unknown_dtype";

template<>
inline constexpr char ptx_type_name<float>[4] = "f32";

template<>
inline constexpr char ptx_type_name<half>[4] = "f16";

template<>
inline constexpr char ptx_type_name<nv_bfloat16>[5] = "bf16";

template<>
inline constexpr char ptx_type_name<__nv_fp8_e4m3>[5] = "e4m3";

template<>
inline constexpr char ptx_type_name<__nv_fp8_e5m2>[5] = "e5m2";

__device__ __forceinline__ m16_n16_a_fragment<nv_bfloat16> load_fragment_a(int lane_id, const nv_bfloat16* base, int ldd) {
// see https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-16816-i8-f8
Expand Down
6 changes: 4 additions & 2 deletions kernels/python/quartet2/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,10 @@ def forward(ctx, input, weight, had, mode: NVFP4QuantMode, disable_backward_quan
if autocast_enabled:
input = input.to(torch.bfloat16)
weight = weight.to(torch.bfloat16)

assert input.dtype == torch.bfloat16
elif weight.dtype != torch.bfloat16:
raise TypeError("Weight must be bfloat16. Either set `dtype=torch.bfloat16` or enable autocast`")
elif input.dtype != torch.bfloat16:
raise TypeError("Input must be bfloat16. Either cast input to bfloat16 or enable autocast`")
Comment on lines +168 to +170

forward_scale_override = 1.0

Expand Down
Loading