diff --git a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp index 62fb6bbcb29..18607399492 100644 --- a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp @@ -86,7 +86,8 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config FlatmmConfig::NumWaveGroups, true>; // Preshuffle_ - constexpr bool MXFP4_Pipeline = std::is_same_v; + // TODO(yadai): rename to W4_Pipeline + constexpr bool MXFP4_Pipeline = std::is_same_v | std::is_same_v; if constexpr(!MXFP4_Pipeline && moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up) { @@ -444,6 +445,22 @@ int run_a16w4_moe_flatmm_example(int argc, char* argv[]) FlatmmConfig, ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{}); } + else if(mixed_prec == "fp16xint4") + { + return run_a16w4_moe_gemm_example_with_layouts< + ck_tile::half_t, + ck_tile::pk_int4_t, + FlatmmConfig, + ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{}); + } + else if(mixed_prec == "bf16xint4") + { + return run_a16w4_moe_gemm_example_with_layouts< + ck_tile::bfloat16_t, + ck_tile::pk_int4_t, + FlatmmConfig, + ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{}); + } else { throw std::runtime_error("Unsupported precision type for gemm1_gate_up!"); @@ -498,6 +515,18 @@ int main(int argc, char* argv[]) { return !run_a16w4_moe_flatmm_example(argc, argv); } + else if (warp_tile == 1) { + return !run_a16w4_moe_flatmm_example(argc, argv); + } + else if (warp_tile == 2) { + return !run_a16w4_moe_flatmm_example(argc, argv); + } + else if (warp_tile == 3) { + return !run_a16w4_moe_flatmm_example(argc, argv); + } + else if (warp_tile == 4) { + return !run_a16w4_moe_flatmm_example(argc, argv); + } // else if(warp_tile == 1) // { // return !run_a16w4_moe_flatmm_example(argc, argv); diff --git a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp index 458e7ba6434..e3cce789f28 100644 --- a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp +++ b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp @@ -13,8 +13,8 @@ // GEMM config with 16x16 warp tile struct A16W4_FlatmmConfig16 { - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t M_Tile = 64; + static constexpr ck_tile::index_t N_Tile = 128; static constexpr ck_tile::index_t K_Tile = 256; static constexpr ck_tile::index_t M_Warp = 1; @@ -43,6 +43,33 @@ struct A16W4_FlatmmConfig16 static constexpr bool TiledMMAPermuteN = false; }; +struct A16W4_FlatmmConfig16_M16 : public A16W4_FlatmmConfig16 +{ + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t kBlockPerCu = 2; +}; + +struct A16W4_FlatmmConfig16_M32 : public A16W4_FlatmmConfig16 +{ + static constexpr ck_tile::index_t M_Tile = 32; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; + static constexpr ck_tile::index_t kBlockPerCu = 2; +}; + +struct A16W4_FlatmmConfig16_M64 : public A16W4_FlatmmConfig16 +{ + static constexpr ck_tile::index_t M_Tile = 64; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; + static constexpr ck_tile::index_t kBlockPerCu = 2; +}; + +struct A16W4_FlatmmConfig16_M128 : public A16W4_FlatmmConfig16 +{ + static 
constexpr ck_tile::index_t M_Tile = 128; +}; + struct A16W4_FlatmmConfig16_950 : public A16W4_FlatmmConfig16 { static constexpr ck_tile::index_t N_Tile = 128; diff --git a/example/ck_tile/18_flatmm/mixed_prec/run_a16w4_moe_flatmm_example.inc b/example/ck_tile/18_flatmm/mixed_prec/run_a16w4_moe_flatmm_example.inc index f236332d620..116c55f333a 100644 --- a/example/ck_tile/18_flatmm/mixed_prec/run_a16w4_moe_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/mixed_prec/run_a16w4_moe_flatmm_example.inc @@ -36,7 +36,7 @@ float invoke_a16w4_moe_gemm(int n_warmup, int n_repeat, const MoeHostArgs& args) std::size_t flop = std::size_t(2) * args.M * args.N * args.K; std::size_t num_byte = sizeof(ADataType) * args.M * args.K + - sizeof(BDataType) * args.N * args.K / PackedSize + + sizeof(BDataType) * args.N * args.K * std::min(args.NumExperts, args.NumTokens * args.TopK) / PackedSize + sizeof(CDataType) * args.M * args.N; float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time; @@ -72,7 +72,8 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc, using CDataType = PrecActType; using AccDataType = float; - using ScaleType = ck_tile::e8m0_t; + static constexpr bool IsInt4 = std::is_same_v; + using ScaleType = std::conditional_t; constexpr int ScaleGranularityN = 1; constexpr int ScaleGranularityK = 32; @@ -188,7 +189,7 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc, int tile_off = i % MPerBlock; if(tile_off < token_per_tile && tokenid < num_tokens * topk) { - sorted_token_ids.mData[i] = (tokenid % num_tokens) | ((tokenid / num_tokens) << 24); + sorted_token_ids.mData[i] = (tokenid / experts) | ((tokenid % experts) << 24); tokenid++; } else diff --git a/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc b/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc index 9e0cbda0c00..053c039cd53 100644 --- a/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc @@ -302,10 +302,6 @@ int run_moe_gemm_example_with_layouts(int argc, static_cast(per_token_scale_dev_buf.GetDeviceBuffer()), static_cast(per_channel_scale_dev_buf.GetDeviceBuffer())); - const float max_accumulated_value = - *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); - const auto rtol_atol = calculate_rtol_atol( - K, 1 /*kbatch*/, max_accumulated_value); c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data()); const float rtol = std::is_same_v && IsInputGemm ? 
1e-3 : 1e-2; diff --git a/include/ck_tile/core/arch/generic_memory_space_atomic.hpp b/include/ck_tile/core/arch/generic_memory_space_atomic.hpp index e56bcadcba2..0ff97bb9a79 100644 --- a/include/ck_tile/core/arch/generic_memory_space_atomic.hpp +++ b/include/ck_tile/core/arch/generic_memory_space_atomic.hpp @@ -102,6 +102,9 @@ CK_TILE_DEVICE void atomic_add(X* p_dst, const X& x); template <> CK_TILE_DEVICE void atomic_add(bf16x2_t* p_dst, const bf16x2_t& x) { +#if HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN + __builtin_amdgcn_global_atomic_fadd_v2bf16(c_style_pointer_cast(p_dst), x); +#else union U32BF162_ADDR { uint32_t* u32_a; @@ -128,6 +131,7 @@ CK_TILE_DEVICE void atomic_add(bf16x2_t* p_dst, const bf16x2_t& x) new_v = new_.u32; cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v); } while(cur_v.u32 != old_v); +#endif } template <> diff --git a/include/ck_tile/core/numeric/pk_int4.hpp b/include/ck_tile/core/numeric/pk_int4.hpp index fc1caf13ff9..088407b40c7 100644 --- a/include/ck_tile/core/numeric/pk_int4.hpp +++ b/include/ck_tile/core/numeric/pk_int4.hpp @@ -151,6 +151,16 @@ CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x) return pk_add_f16(bit_cast(lo), bit_cast(SUB)); } + + +CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x, float scale) +{ + auto float_vec2 = pk_int4_t_to_fp32x2_t(x); + float_vec2.x = float_vec2.x * scale; + float_vec2.y = float_vec2.y * scale; + return fp16x2_t{type_convert(float_vec2.x), type_convert(float_vec2.y)}; +} + CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x) { uint8_t x_u8 = ck_tile::bit_cast(x); @@ -166,6 +176,14 @@ CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x) return res; } +CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x, float scale) +{ + auto float_vec2 = pk_int4_t_to_fp32x2_t(x); + float_vec2.x = float_vec2.x * scale; + float_vec2.y = float_vec2.y * scale; + return bf16x2_t{type_convert(float_vec2.x), type_convert(float_vec2.y)}; +} + CK_TILE_HOST_DEVICE int8x2_t pk_int4_t_to_int8x2_t(const pk_int4_t& x) { uint8_t x_u8 = ck_tile::bit_cast(x); diff --git a/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/include/ck_tile/core/tensor/tile_scatter_gather.hpp index 4b04fd513db..e6adc7d40b4 100644 --- a/include/ck_tile/core/tensor/tile_scatter_gather.hpp +++ b/include/ck_tile/core/tensor/tile_scatter_gather.hpp @@ -125,7 +125,7 @@ struct tile_scatter_gather static constexpr auto get_space_filling_curve() { - constexpr auto tile_dstr = TileDstr{}; + [[maybe_unused]] constexpr auto tile_dstr = TileDstr{}; constexpr auto thread_tensor_lengths_ys = to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths()); @@ -309,7 +309,7 @@ struct tile_scatter_gather CK_TILE_DEVICE auto load(number = {}, bool_constant = {}) const { - constexpr auto tile_dstr = TileDstr{}; + [[maybe_unused]] constexpr auto tile_dstr = TileDstr{}; auto dst_tensor = make_static_distributed_tensor(tile_dstr); load(dst_tensor, number{}, bool_constant{}); return dst_tensor; @@ -326,7 +326,7 @@ struct tile_scatter_gather using vector_t = typename Traits::vector_t; using SFC_Ys = typename Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + [[maybe_unused]] constexpr auto tile_dstr = TileDstr{}; // loop over thread tensor space [y0, y1, ...] 
static_for<0, NumCoord, 1>{}([&](auto iCoord) { @@ -418,7 +418,7 @@ struct tile_scatter_gather using vector_t = typename Traits::vector_t; using SFC_Ys = typename Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + [[maybe_unused]] constexpr auto tile_dstr = TileDstr{}; // Precompute invariant values outside loops const auto window_origin = lds_tile.get_window_origin(); @@ -614,7 +614,7 @@ struct tile_scatter_gather using vector_t = typename Traits::vector_t; using SFC_Ys = typename Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + [[maybe_unused]] constexpr auto tile_dstr = TileDstr{}; static_for<0, NumCoord, 1>{}([&](auto iCoord) { auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; @@ -696,7 +696,7 @@ struct tile_scatter_gather using vector_t = typename Traits::vector_t; using SFC_Ys = typename Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + [[maybe_unused]] constexpr auto tile_dstr = TileDstr{}; // printf("off %d\n", page_idx_[I0]); // loop over thread tensor space [y0, y1, ...] static_for<0, NumCoord, 1>{}([&](auto iCoord) { diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index 411cfe81edf..62a69a16676 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -241,7 +241,7 @@ struct MoeFlatmmKernel IsGateUp ? TilePartitioner::NPerBlock / 2 : TilePartitioner::NPerBlock; // MXF4_Pipeline only has the of scale B and granularityK is 32 - static constexpr bool MXFP4_Pipeline = std::is_same_v; + static constexpr bool MXFP4_Pipeline = std::is_same_v || std::is_same_v; static constexpr int MXFP4N_Pack = 2; static constexpr int MXFP4K_Pack = 2; @@ -623,7 +623,7 @@ struct MoeFlatmmKernel { return make_naive_tensor_view( e_ptr, - make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumToken, + make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumTokens, IsGateUp ? kargs.N / 2 : kargs.N), make_tuple(1, kargs.stride_C), number<1>{}, @@ -638,7 +638,9 @@ struct MoeFlatmmKernel index_t FlatScaleK = scale_k * N_Pack * BlockGemmShape::WarpTile::at(I1); index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1); - using ScaleType = std::conditional_t; + // using ScaleType = std::conditional_t; + static constexpr bool IsInt4 = std::is_same_v; + using ScaleType = std::conditional_t, float>; const auto scale_b_flat_view = make_naive_tensor_view( reinterpret_cast(scale_n.ptr) + expert_id * kargs.N * scale_k, @@ -721,6 +723,7 @@ struct MoeFlatmmKernel constexpr bool isNonInterleaveGateUp = !IsGateUp || MXFP4_Pipeline; + /* const auto& b_flat_block_window = make_tile_window(b_flat_pad_view, make_tuple(number{}, @@ -728,6 +731,63 @@ struct MoeFlatmmKernel {static_cast(coord_n / BlockGemmShape::WarpTile::at(I1) / (isNonInterleaveGateUp ? 1 : 2)), 0}); + */ + const auto& b_flat_block_window = [&]() { + // GateUp needs to shuffle weight + if constexpr(IsGateUp) + { + // 1. Get Dimensions + const auto N = b_flat_pad_view.get_tensor_descriptor().get_length(I0); + const auto K = b_flat_pad_view.get_tensor_descriptor().get_length(I1); + + // 2. 
View Linear N as (2, N/2) -> effectively separating Gate (0) and Up (1) blocks + // Layout becomes: (BlockIdx, RowInBlock, K) + auto v_split = transform_tensor_view( + b_flat_pad_view, + make_tuple(make_unmerge_transform(make_tuple(number<2>{}, N / 2)), + make_pass_through_transform(K)), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + // 3. Permute to (N/2, 2, K) -> (RowInBlock, BlockIdx, K) + // This puts Gate(i) and Up(i) adjacent in the view + auto v_permute = transform_tensor_view( + v_split, + make_tuple(make_pass_through_transform(N / 2), + make_pass_through_transform(number<2>{}), + make_pass_through_transform(K)), + make_tuple(sequence<1>{}, sequence<0>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + // 4. Merge back to (N, K) -> effectively Interleaved View + auto b_interleaved_view = transform_tensor_view( + v_permute, + make_tuple(make_merge_transform(make_tuple(N / 2, number<2>{})), + make_pass_through_transform(K)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + // 5. Create Window on the transformed view + return make_tile_window( + b_interleaved_view, + make_tuple(number{}, + number{}), + {static_cast(coord_n / BlockGemmShape::WarpTile::at(I1) / + (isNonInterleaveGateUp ? 1 : 2)), + 0}); + } + else + { + // Default behavior for Interleaved or non-GateUp + return make_tile_window( + b_flat_pad_view, + make_tuple(number{}, + number{}), + {static_cast(coord_n / BlockGemmShape::WarpTile::at(I1) / + (isNonInterleaveGateUp ? 1 : 2)), + 0}); + } + }(); const int output_N_offset = IsGateUp ? coord_n / 2 : coord_n; @@ -1250,6 +1310,8 @@ struct MoeFlatmmKernel constexpr int MPerThread = TileEncodingPattern::Y2; statically_indexed_array, NumMEpiTile> c_scatter_offsets; + statically_indexed_array, NumMEpiTile> + c_scatter_valids; auto c_coord = dram_tile_distribution.calculate_index(); static_for<0, NumMEpiTile, 1>{}([&](auto mIter) { static_for<0, MPerThread, 1>{}([&](auto m0) { @@ -1262,6 +1324,7 @@ struct MoeFlatmmKernel scatter_token_id = scatter_token_id * kargs.TopK + (fused_token >> token_id_offset); c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C; + c_scatter_valids[mIter][m0] = (scatter_token_id < (kargs.NumTokens * (IsInputGemm? 
kargs.TopK : 1))); }); }); @@ -1302,7 +1365,8 @@ struct MoeFlatmmKernel c_block_window.get_window_lengths(), c_block_window.get_window_origin(), dram_tile_distribution, - c_scatter_offsets[mIter]); + c_scatter_offsets[mIter], + c_scatter_valids[mIter]); if constexpr(!IsInputGemm || EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add) diff --git a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 17c88e4f08f..5dbc71d6ea6 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -71,7 +71,11 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 using WG = remove_cvref_t())>; static constexpr index_t DsWritePreIssue = 3; // default 2, ds write at MIter - 2 +#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS + static constexpr index_t DsReadPreload = 16; // default 8, if using lds, register pressure is alleviated, improve preload +#else static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read +#endif static constexpr index_t BlockSize = Problem::kBlockSize; static constexpr index_t WaveSize = get_warp_size(); @@ -183,14 +187,145 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 // For the basic gemm pipelien DoubleSmemBuffer set to be false naturally. static constexpr bool DoubleSmemBuffer = false; + struct DequantizeMxFP4 { + + CK_TILE_DEVICE auto operator()([[maybe_unused]] statically_indexed_array& dequant_B_n, + [[maybe_unused]] const auto& quant_weight_tensor, + [[maybe_unused]] const auto& scale_tensor, + [[maybe_unused]] auto xdl_nIter, + [[maybe_unused]] auto xdl_kIter) { + auto quant_idx_k = xdl_kIter % number{}; + + auto scale_idx_n = xdl_nIter % number{}; + auto scale_idx_k = (xdl_kIter % number{}) / number{}; + auto scale_offset = scale_idx_n + scale_idx_k * number{}; + + auto scale = scale_tensor.get_thread_buffer()[scale_offset]; + + constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size(); + constexpr int PackedCnt = ScalarCnt / MXFP4PackedSize; + constexpr int float_mantissa = 23; + + uint32_t uscale = uint32_t(bit_cast(scale)) << float_mantissa; + + using ComputeV2Type = + std::conditional_t, fp16x2_t, bf16x2_t>; + +#if defined(__gfx950__) + auto pk_mxfp4x4_to_compute_v2 = [](auto pk_mxfp4x4, float fscale, auto byte_idx) { + if constexpr(std::is_same_v) + { + return __builtin_amdgcn_cvt_scalef32_pk_f16_fp4( + pk_mxfp4x4, fscale, int(byte_idx)); + } + else if constexpr(std::is_same_v) + { + return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4( + pk_mxfp4x4, fscale, int(byte_idx)); + } + else + { + static_assert(sizeof(pk_mxfp4x4) == 0, "unsupported compute type"); + } + }; + static_for<0, PackedCnt, 1>{}([&](auto i) { + dequant_B_n[xdl_nIter].get_thread_buffer().template set_as( + i, + pk_mxfp4x4_to_compute_v2( + quant_weight_tensor[quant_idx_k], bit_cast(uscale), i)); + }); +#else + auto pk_mxfp4_to_compute_v2 = [](auto pk_mxfp4, float fscale) { + if constexpr(std::is_same_v) + { + return pk_fp4_to_fp16x2(pk_mxfp4, fscale); + } + else if constexpr(std::is_same_v) + { + return pk_fp4_to_bf16x2(pk_mxfp4, fscale); + } + else + { + static_assert(sizeof(pk_mxfp4) == 0, "unsupported compute type"); + } + }; + static_for<0, PackedCnt, 1>{}([&](auto i) { + dequant_B_n[xdl_nIter].get_thread_buffer().template set_as( + i, + pk_mxfp4_to_compute_v2( + bit_cast>(quant_weight_tensor[quant_idx_k]) + 
.at(i), + bit_cast(uscale))); + }); +#endif + return 0; + } + }; + + struct DequantizeINT4 { + + CK_TILE_DEVICE auto operator()(statically_indexed_array& dequant_B_n, + const auto& quant_weight_tensor, + const auto& scale_tensor, + auto xdl_nIter, + auto xdl_kIter) { + + auto quant_idx_k = xdl_kIter % number{}; + + auto scale_idx_n = xdl_nIter % number{}; + auto scale_idx_k = (xdl_kIter % number{}) / number{}; + auto scale_offset = scale_idx_n + scale_idx_k * number{}; + + auto scale = scale_tensor.get_thread_buffer()[scale_offset]; + + constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size(); + constexpr int PackedCnt = ScalarCnt / MXFP4PackedSize; + /* + constexpr int float_mantissa = 23; + + uint32_t uscale = uint32_t(scale.data) << float_mantissa; + */ + + // float scale_f32 = type_convert(scale.data); + float scale_f32 = type_convert(scale); + + using ComputeV2Type = + std::conditional_t, fp16x2_t, bf16x2_t>; + + auto pk_int4_to_compute_v2 = [](auto pk_int4, float fscale) { + if constexpr(std::is_same_v) + { + return pk_int4_t_to_halfx2_t(pk_int4, fscale); + } + else if constexpr(std::is_same_v) + { + return pk_int4_t_to_bfloat16x2_t(pk_int4, fscale); + } + else + + + + { + static_assert(sizeof(pk_int4) == 0, "unsupported compute type"); + } + }; + static_for<0, PackedCnt, 1>{}([&](auto i) { + dequant_B_n[xdl_nIter].get_thread_buffer().template set_as( + i, + pk_int4_to_compute_v2( + bit_cast>(quant_weight_tensor[quant_idx_k]) + .at(i), + scale_f32)); + }); + return 0; + } + }; + + using DequantOp = typename std::conditional, DequantizeMxFP4, DequantizeINT4>::type; + CK_TILE_HOST_DEVICE static constexpr auto SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM) { -#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS - // GFX950 use BUFFER_LOAD_LDS to fill lds_buffer_A. - // There is no separate DS_WRITE instruction at all. - dswrite_perM = 0; -#endif // Init inst order index_t max_data_inst = dsread_perM > load_perM ? (dsread_perM > dswrite_perM ? dsread_perM : dswrite_perM) @@ -360,7 +495,36 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 // Calculate ds_read number per M dsread_perM = dsread_per_wg; +#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS == 0 + // Calculate ds_write number per M + if(mIter == 0) + { + dswrite_perM = + (dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep) > 0 + ? dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep + : 0; + } + else if(mIter >= MIterPerWarp - DsWritePreIssue + 1) + { + dswrite_perM = 0; + } + else + { + dswrite_perM = (dswrite_num_perK - + (MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0 + ? dswrite_rep + : 0; + } + // Add ds write when ds write data > needed + if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter)) + { + if(mIter == MIterPerWarp - 1 - dswrite_mIter) + dswrite_perM = 1; + } +#endif + // Calculate buffer_load number per M +#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS == 0 if(mIter < HalfMIter) { load_perM = @@ -375,10 +539,17 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 ? 
Aload_rep : 0; } - if((kIter % KPerScaleLoad == 0) && (mIter == 0)) - { - load_perM = load_perM + 1; +#else + if ((kIter * MIterPerWarp + mIter) >= + (KIterPerWarp * MIterPerWarp - m_preload)) { + load_perM = 1; } +#endif + // if((kIter % KPerScaleLoad == 0) && (mIter == 0)) + // { + // load_perM = load_perM + 1; + // } + // SchedulerPerM(dsread_perM, dswrite_perM, load_perM); SchedulerPerM(dsread_perM, dswrite_perM, load_perM); } } @@ -444,7 +615,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 typename BFlatBlockWindowTmp, typename DequantBFlatWindow> CK_TILE_HOST_DEVICE auto operator()(ADramBlockWindowTmp a_copy_dram_window_, - const AElementFunction& a_element_func, + [[maybe_unused]] const AElementFunction& a_element_func, const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, const DequantBFlatWindow& scale_b_flat_window, const index_t num_loop, @@ -606,7 +777,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 scale_b_warp_tensor_pong; using ABlockTile = decltype(load_tile(a_copy_dram_window)); - ABlockTile a_block_tile; + [[maybe_unused]] ABlockTile a_block_tile; enum { @@ -621,7 +792,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 if constexpr(prefill_location & PrefillAfterGemm) async_load_tile(lds_tile_a, dram_tile_a); }; - auto prefill_lds_a_stage2 = [&](auto lds_tile_a) { + auto prefill_lds_a_stage2 = [&]([[maybe_unused]] auto lds_tile_a) { // async_load_fence(); // __builtin_amdgcn_s_waitcnt(0x03fc); // data has been stored in lds, no need more operation. @@ -712,6 +883,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 statically_indexed_array dequant_B_n; + + /* auto dequant_mxfp4 = [&](const auto& quant_weight_tensor, const auto& scale_tensor, auto xdl_nIter, @@ -781,6 +954,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 }); #endif }; + */ // MAIN LOOP index_t iCounter = (num_loop - 1) / 2; @@ -821,6 +995,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 }); }); + __builtin_amdgcn_sched_barrier(0); + // Prefill A(2i+1) prefill_lds_a_stage2(a_copy_lds_window_pong); @@ -840,7 +1016,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); if constexpr(mIter == 0) - dequant_mxfp4( + DequantOp{}( + dequant_B_n, b_warp_tensor_ping(nIter)(kIter / number{}), scale_b_warp_tensor_ping(nIter / number{})( kIter / number{}), @@ -866,12 +1043,23 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 load_tile(a_warp_windows_ping(number{})(number{})); } + // yadai comments out the following + /* // barrier if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) { __builtin_amdgcn_s_waitcnt(Bload_total_num); block_sync_lds(); } + */ + + // sync shouble made as early as possible + if constexpr((kIter * MIterPerWarp + mIter) == + (KIterPerWarp * MIterPerWarp - m_preload)) + { + __builtin_amdgcn_s_waitcnt(Bload_total_num); + block_sync_lds(); + } }); }); prefill_lds_a_stage1( @@ -928,6 +1116,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 }); }); + __builtin_amdgcn_sched_barrier(0); + // Prefill A(2i+2) prefill_lds_a_stage2(a_copy_lds_window_ping); @@ -947,7 +1137,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); if constexpr(mIter == 0) - dequant_mxfp4( + DequantOp{}( + dequant_B_n, b_warp_tensor_pong(nIter)(kIter / number{}), scale_b_warp_tensor_pong(nIter / number{})( kIter / number{}), @@ -973,12 +1164,23 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 load_tile(a_warp_windows_pong(number{})(number{})); } + // yadai comments out the following + /* // barrier if 
constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) { __builtin_amdgcn_s_waitcnt(Bload_total_num); block_sync_lds(); } + */ + + // sync shouble made as early as possible + if constexpr((kIter * MIterPerWarp + mIter) == + (KIterPerWarp * MIterPerWarp - m_preload)) + { + __builtin_amdgcn_s_waitcnt(Bload_total_num); + block_sync_lds(); + } }); }); prefill_lds_a_stage1( @@ -1063,7 +1265,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); if constexpr(mIter == 0) - dequant_mxfp4( + DequantOp{}( + dequant_B_n, b_warp_tensor_ping(nIter)(kIter / number{}), scale_b_warp_tensor_ping(nIter / number{})( kIter / number{}), @@ -1124,7 +1327,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); if constexpr(mIter == 0) - dequant_mxfp4( + DequantOp{}( + dequant_B_n, b_warp_tensor_pong(nIter)(kIter / number{}), scale_b_warp_tensor_pong(nIter / number{})( kIter / number{}), @@ -1175,7 +1379,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); if constexpr(mIter == 0) - dequant_mxfp4( + DequantOp{}( + dequant_B_n, b_warp_tensor_ping(nIter)(kIter / number{}), scale_b_warp_tensor_ping(nIter / number{})( kIter / number{}), diff --git a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index f34c682b0f1..f5954c29abf 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -7,7 +7,7 @@ namespace ck_tile { -#define CKTILE_FLATMM_USE_BUFFER_LOAD_LDS_AS_POSSIBLE 0 +#define CKTILE_FLATMM_USE_BUFFER_LOAD_LDS_AS_POSSIBLE 1 #if defined(__gfx950__) #define CKTILE_FLATMM_ARCH_SUPPORT_BUFFER_LOAD_LDS_DWORDx4 1
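
Note on the new a16w4 dequantization path introduced above: the added pk_int4_t_to_halfx2_t(x, scale) / pk_int4_t_to_bfloat16x2_t(x, scale) overloads and the DequantizeINT4 functor all follow the same order of operations — unpack the packed int4 pair to fp32 via pk_int4_t_to_fp32x2_t, multiply both lanes by one per-group float scale (ScaleGranularityK = 32 in the example, with a plain float ScaleType selected when IsInt4 is true), then narrow to the fp16/bf16 compute type. Below is a minimal stand-alone host-side sketch of that dequant step, plain C++ rather than ck_tile API. The nibble order and the signed two's-complement encoding are assumptions, since pk_int4_t_to_fp32x2_t is not part of this diff, and dequant_pk_int4 is an illustrative name only.

#include <cstdint>
#include <utility>

// Reference dequant of one packed int4 byte (two weights) sharing a float scale.
// Assumption: element 0 lives in the low nibble, element 1 in the high nibble,
// both stored as two's-complement signed 4-bit values in [-8, 7].
inline std::pair<float, float> dequant_pk_int4(uint8_t packed, float scale)
{
    int lo = packed & 0x0F;
    int hi = (packed >> 4) & 0x0F;
    if(lo >= 8) lo -= 16; // sign-extend the low nibble
    if(hi >= 8) hi -= 16; // sign-extend the high nibble

    // Same sequence as the new overloads: int4 -> fp32 pair, apply the group
    // scale, then (on the device side) convert to fp16x2_t / bf16x2_t.
    return {static_cast<float>(lo) * scale, static_cast<float>(hi) * scale};
}

This is also why the two functors diverge only in how they consume the scale: the MXFP4 path keeps its scale as an e8m0_t exponent and shifts it directly into a float's exponent field (uscale << 23), while the int4 path converts the loaded scale to float with type_convert and multiplies, and the pipeline picks between DequantizeMxFP4 and DequantizeINT4 at compile time through the std::conditional alias DequantOp.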