Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config
FlatmmConfig::NumWaveGroups,
true>; // Preshuffle_

constexpr bool MXFP4_Pipeline = std::is_same_v<BDataType, ck_tile::pk_fp4_t>;
// TODO(yadai): rename to W4_Pipeline — this flag now covers both pk_fp4 and
// pk_int4 weight types, not only MXFP4.
// Use logical-or (||) on the bool traits rather than bitwise-or, matching the
// kernel-side definition of MXFP4_Pipeline in moe_flatmm_kernel.hpp.
constexpr bool MXFP4_Pipeline = std::is_same_v<BDataType, ck_tile::pk_fp4_t> ||
                                std::is_same_v<BDataType, ck_tile::pk_int4_t>;

if constexpr(!MXFP4_Pipeline && moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
{
Expand Down Expand Up @@ -444,6 +445,22 @@ int run_a16w4_moe_flatmm_example(int argc, char* argv[])
FlatmmConfig,
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
}
else if(mixed_prec == "fp16xint4")
{
return run_a16w4_moe_gemm_example_with_layouts<
ck_tile::half_t,
ck_tile::pk_int4_t,
FlatmmConfig,
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
}
else if(mixed_prec == "bf16xint4")
{
return run_a16w4_moe_gemm_example_with_layouts<
ck_tile::bfloat16_t,
ck_tile::pk_int4_t,
FlatmmConfig,
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported precision type for gemm1_gate_up!");
Expand Down Expand Up @@ -498,6 +515,18 @@ int main(int argc, char* argv[])
{
return !run_a16w4_moe_flatmm_example<A16W4_FlatmmConfig16>(argc, argv);
}
else if (warp_tile == 1) {
return !run_a16w4_moe_flatmm_example<A16W4_FlatmmConfig16_M16>(argc, argv);
}
else if (warp_tile == 2) {
return !run_a16w4_moe_flatmm_example<A16W4_FlatmmConfig16_M32>(argc, argv);
}
else if (warp_tile == 3) {
return !run_a16w4_moe_flatmm_example<A16W4_FlatmmConfig16_M64>(argc, argv);
}
else if (warp_tile == 4) {
return !run_a16w4_moe_flatmm_example<A16W4_FlatmmConfig16_M128>(argc, argv);
}
// else if(warp_tile == 1)
// {
// return !run_a16w4_moe_flatmm_example<A16W4_FlatmmConfig16_950>(argc, argv);
Expand Down
31 changes: 29 additions & 2 deletions example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
// GEMM config with 16x16 warp tile
struct A16W4_FlatmmConfig16
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t M_Tile = 64;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 256;

static constexpr ck_tile::index_t M_Warp = 1;
Expand Down Expand Up @@ -43,6 +43,33 @@ struct A16W4_FlatmmConfig16
static constexpr bool TiledMMAPermuteN = false;
};

// GEMM config variant with a 16-row block tile and 2 blocks per CU; every
// other tile/warp parameter is inherited from A16W4_FlatmmConfig16.
// NOTE(review): static members HIDE the base's values, they do not override
// them — if the base derives any constant (e.g. an M-repeat count) from its
// own M_Tile, that derived value is NOT refreshed here. TODO confirm.
struct A16W4_FlatmmConfig16_M16 : public A16W4_FlatmmConfig16
{
    static constexpr ck_tile::index_t M_Tile = 16;
    static constexpr ck_tile::index_t kBlockPerCu = 2;
};

// GEMM config variant with a 32x256 block tile and 2 blocks per CU.
// N_Repeat must be re-derived here because static members hide (not override)
// the base's: the base's computed N_Repeat would still reflect the base
// N_Tile. Assumes the base defines N_Warp_Tile / N_Warp (found via
// unqualified lookup) — TODO confirm against A16W4_FlatmmConfig16.
struct A16W4_FlatmmConfig16_M32 : public A16W4_FlatmmConfig16
{
    static constexpr ck_tile::index_t M_Tile = 32;
    static constexpr ck_tile::index_t N_Tile = 256;
    // Use ck_tile::index_t for consistency with every other tile constant in
    // this file (was plain int).
    static constexpr ck_tile::index_t N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
    static constexpr ck_tile::index_t kBlockPerCu = 2;
};

// GEMM config variant with a 64x256 block tile and 2 blocks per CU.
// N_Repeat must be re-derived here because static members hide (not override)
// the base's: the base's computed N_Repeat would still reflect the base
// N_Tile. Assumes the base defines N_Warp_Tile / N_Warp (found via
// unqualified lookup) — TODO confirm against A16W4_FlatmmConfig16.
struct A16W4_FlatmmConfig16_M64 : public A16W4_FlatmmConfig16
{
    static constexpr ck_tile::index_t M_Tile = 64;
    static constexpr ck_tile::index_t N_Tile = 256;
    // Use ck_tile::index_t for consistency with every other tile constant in
    // this file (was plain int).
    static constexpr ck_tile::index_t N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
    static constexpr ck_tile::index_t kBlockPerCu = 2;
};

// GEMM config variant that widens the block tile back to 128 rows; all other
// parameters inherit from A16W4_FlatmmConfig16.
// NOTE(review): static members hide, not override — any base constant derived
// from the base's M_Tile is not refreshed by this redefinition. TODO confirm.
struct A16W4_FlatmmConfig16_M128 : public A16W4_FlatmmConfig16
{
    static constexpr ck_tile::index_t M_Tile = 128;
};

struct A16W4_FlatmmConfig16_950 : public A16W4_FlatmmConfig16
{
static constexpr ck_tile::index_t N_Tile = 128;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ float invoke_a16w4_moe_gemm(int n_warmup, int n_repeat, const MoeHostArgs& args)

std::size_t flop = std::size_t(2) * args.M * args.N * args.K;
std::size_t num_byte = sizeof(ADataType) * args.M * args.K +
sizeof(BDataType) * args.N * args.K / PackedSize +
sizeof(BDataType) * args.N * args.K * std::min(args.NumExperts, args.NumTokens * args.TopK) / PackedSize +
sizeof(CDataType) * args.M * args.N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
Expand Down Expand Up @@ -72,7 +72,8 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
using CDataType = PrecActType;
using AccDataType = float;

using ScaleType = ck_tile::e8m0_t;
static constexpr bool IsInt4 = std::is_same_v<BDataType, ck_tile::pk_int4_t>;
using ScaleType = std::conditional_t<IsInt4, ck_tile::bfloat16_t, ck_tile::e8m0_t>;

constexpr int ScaleGranularityN = 1;
constexpr int ScaleGranularityK = 32;
Expand Down Expand Up @@ -188,7 +189,7 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
int tile_off = i % MPerBlock;
if(tile_off < token_per_tile && tokenid < num_tokens * topk)
{
sorted_token_ids.mData[i] = (tokenid % num_tokens) | ((tokenid / num_tokens) << 24);
sorted_token_ids.mData[i] = (tokenid / experts) | ((tokenid % experts) << 24);
tokenid++;
}
else
Expand Down
4 changes: 0 additions & 4 deletions example/ck_tile/18_flatmm/run_moe_flatmm_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -302,10 +302,6 @@ int run_moe_gemm_example_with_layouts(int argc,
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer()),
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer()));

const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, 1 /*kbatch*/, max_accumulated_value);
c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data());

const float rtol = std::is_same_v<ADataType, ck_tile::half_t> && IsInputGemm ? 1e-3 : 1e-2;
Expand Down
4 changes: 4 additions & 0 deletions include/ck_tile/core/arch/generic_memory_space_atomic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ CK_TILE_DEVICE void atomic_add(X* p_dst, const X& x);
template <>
CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x)
{
#if HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN
__builtin_amdgcn_global_atomic_fadd_v2bf16(c_style_pointer_cast<bf16x2_t*>(p_dst), x);
#else
union U32BF162_ADDR
{
uint32_t* u32_a;
Expand All @@ -128,6 +131,7 @@ CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x)
new_v = new_.u32;
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
} while(cur_v.u32 != old_v);
#endif
}

template <>
Expand Down
18 changes: 18 additions & 0 deletions include/ck_tile/core/numeric/pk_int4.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,16 @@ CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x)
return pk_add_f16(bit_cast<fp16x2_t>(lo), bit_cast<fp16x2_t>(SUB));
}



/// Unpack a packed int4 pair into two fp16 values with a uniform scale
/// applied: decodes through the fp32 path, scales each lane, then narrows
/// each lane to fp16.
CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x, float scale)
{
    const auto unpacked = pk_int4_t_to_fp32x2_t(x);
    return fp16x2_t{type_convert<fp16_t>(unpacked.x * scale),
                    type_convert<fp16_t>(unpacked.y * scale)};
}

CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x)
{
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
Expand All @@ -166,6 +176,14 @@ CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x)
return res;
}

/// Unpack a packed int4 pair into two bf16 values with a uniform scale
/// applied: decodes through the fp32 path, scales each lane, then narrows
/// each lane to bf16.
CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x, float scale)
{
    const auto unpacked = pk_int4_t_to_fp32x2_t(x);
    return bf16x2_t{type_convert<bf16_t>(unpacked.x * scale),
                    type_convert<bf16_t>(unpacked.y * scale)};
}

CK_TILE_HOST_DEVICE int8x2_t pk_int4_t_to_int8x2_t(const pk_int4_t& x)
{
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
Expand Down
12 changes: 6 additions & 6 deletions include/ck_tile/core/tensor/tile_scatter_gather.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ struct tile_scatter_gather

static constexpr auto get_space_filling_curve()
{
constexpr auto tile_dstr = TileDstr{};
[[maybe_unused]] constexpr auto tile_dstr = TileDstr{};

constexpr auto thread_tensor_lengths_ys =
to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths());
Expand Down Expand Up @@ -309,7 +309,7 @@ struct tile_scatter_gather
CK_TILE_DEVICE auto load(number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
constexpr auto tile_dstr = TileDstr{};
[[maybe_unused]] constexpr auto tile_dstr = TileDstr{};
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
return dst_tensor;
Expand All @@ -326,7 +326,7 @@ struct tile_scatter_gather
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;

constexpr auto tile_dstr = TileDstr{};
[[maybe_unused]] constexpr auto tile_dstr = TileDstr{};

// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
Expand Down Expand Up @@ -418,7 +418,7 @@ struct tile_scatter_gather
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;

constexpr auto tile_dstr = TileDstr{};
[[maybe_unused]] constexpr auto tile_dstr = TileDstr{};

// Precompute invariant values outside loops
const auto window_origin = lds_tile.get_window_origin();
Expand Down Expand Up @@ -614,7 +614,7 @@ struct tile_scatter_gather
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;

constexpr auto tile_dstr = TileDstr{};
[[maybe_unused]] constexpr auto tile_dstr = TileDstr{};

static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
Expand Down Expand Up @@ -696,7 +696,7 @@ struct tile_scatter_gather
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;

constexpr auto tile_dstr = TileDstr{};
[[maybe_unused]] constexpr auto tile_dstr = TileDstr{};
// printf("off %d\n", page_idx_[I0]);
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
Expand Down
72 changes: 68 additions & 4 deletions include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ struct MoeFlatmmKernel
IsGateUp ? TilePartitioner::NPerBlock / 2 : TilePartitioner::NPerBlock;

// MXF4_Pipeline only has the of scale B and granularityK is 32
static constexpr bool MXFP4_Pipeline = std::is_same_v<BDataType, pk_fp4_t>;
static constexpr bool MXFP4_Pipeline = std::is_same_v<BDataType, pk_fp4_t> || std::is_same_v<BDataType, pk_int4_t>;
static constexpr int MXFP4N_Pack = 2;
static constexpr int MXFP4K_Pack = 2;

Expand Down Expand Up @@ -623,7 +623,7 @@ struct MoeFlatmmKernel
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumToken,
make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumTokens,
IsGateUp ? kargs.N / 2 : kargs.N),
make_tuple(1, kargs.stride_C),
number<1>{},
Expand All @@ -638,7 +638,9 @@ struct MoeFlatmmKernel
index_t FlatScaleK = scale_k * N_Pack * BlockGemmShape::WarpTile::at(I1);
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);

using ScaleType = std::conditional_t<MXFP4_Pipeline, e8m0_t, float>;
// using ScaleType = std::conditional_t<MXFP4_Pipeline, e8m0_t, float>;
static constexpr bool IsInt4 = std::is_same_v<BDataType, pk_int4_t>;
using ScaleType = std::conditional_t<MXFP4_Pipeline, std::conditional_t<IsInt4, bfloat16_t, e8m0_t>, float>;

const auto scale_b_flat_view = make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const ScaleType*>(scale_n.ptr) + expert_id * kargs.N * scale_k,
Expand Down Expand Up @@ -721,13 +723,71 @@ struct MoeFlatmmKernel

constexpr bool isNonInterleaveGateUp = !IsGateUp || MXFP4_Pipeline;

/*
const auto& b_flat_block_window =
make_tile_window(b_flat_pad_view,
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp>{}),
{static_cast<int>(coord_n / BlockGemmShape::WarpTile::at(I1) /
(isNonInterleaveGateUp ? 1 : 2)),
0});
*/
const auto& b_flat_block_window = [&]() {
// GateUp needs to shuffle weight
if constexpr(IsGateUp)
{
// 1. Get Dimensions
const auto N = b_flat_pad_view.get_tensor_descriptor().get_length(I0);
const auto K = b_flat_pad_view.get_tensor_descriptor().get_length(I1);

// 2. View Linear N as (2, N/2) -> effectively separating Gate (0) and Up (1) blocks
// Layout becomes: (BlockIdx, RowInBlock, K)
auto v_split = transform_tensor_view(
b_flat_pad_view,
make_tuple(make_unmerge_transform(make_tuple(number<2>{}, N / 2)),
make_pass_through_transform(K)),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));

// 3. Permute to (N/2, 2, K) -> (RowInBlock, BlockIdx, K)
// This puts Gate(i) and Up(i) adjacent in the view
auto v_permute = transform_tensor_view(
v_split,
make_tuple(make_pass_through_transform(N / 2),
make_pass_through_transform(number<2>{}),
make_pass_through_transform(K)),
make_tuple(sequence<1>{}, sequence<0>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));

// 4. Merge back to (N, K) -> effectively Interleaved View
auto b_interleaved_view = transform_tensor_view(
v_permute,
make_tuple(make_merge_transform(make_tuple(N / 2, number<2>{})),
make_pass_through_transform(K)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));

// 5. Create Window on the transformed view
return make_tile_window(
b_interleaved_view,
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp>{}),
{static_cast<int>(coord_n / BlockGemmShape::WarpTile::at(I1) /
(isNonInterleaveGateUp ? 1 : 2)),
0});
}
else
{
// Default behavior for Interleaved or non-GateUp
return make_tile_window(
b_flat_pad_view,
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp>{}),
{static_cast<int>(coord_n / BlockGemmShape::WarpTile::at(I1) /
(isNonInterleaveGateUp ? 1 : 2)),
0});
}
}();

const int output_N_offset = IsGateUp ? coord_n / 2 : coord_n;

Expand Down Expand Up @@ -1250,6 +1310,8 @@ struct MoeFlatmmKernel
constexpr int MPerThread = TileEncodingPattern::Y2;
statically_indexed_array<statically_indexed_array<index_t, MPerThread>, NumMEpiTile>
c_scatter_offsets;
statically_indexed_array<statically_indexed_array<bool, MPerThread>, NumMEpiTile>
c_scatter_valids;
auto c_coord = dram_tile_distribution.calculate_index();
static_for<0, NumMEpiTile, 1>{}([&](auto mIter) {
static_for<0, MPerThread, 1>{}([&](auto m0) {
Expand All @@ -1262,6 +1324,7 @@ struct MoeFlatmmKernel
scatter_token_id =
scatter_token_id * kargs.TopK + (fused_token >> token_id_offset);
c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C;
c_scatter_valids[mIter][m0] = (scatter_token_id < (kargs.NumTokens * (IsInputGemm? kargs.TopK : 1)));
});
});

Expand Down Expand Up @@ -1302,7 +1365,8 @@ struct MoeFlatmmKernel
c_block_window.get_window_lengths(),
c_block_window.get_window_origin(),
dram_tile_distribution,
c_scatter_offsets[mIter]);
c_scatter_offsets[mIter],
c_scatter_valids[mIter]);

if constexpr(!IsInputGemm ||
EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add)
Expand Down
Loading