Draft

42 commits
4985afb  adap gemm_mx_kernel.hpp from flatmm, comment changes needed to mx pip… (samremes, Dec 18, 2025)
0faed29  refactor the mx pipeline, backup the modified flatmm pipeline (samremes, Dec 18, 2025)
6a4951c  add mx gemm example (samremes, Dec 18, 2025)
86cc59e  fix settings for example, fix some things in pipeline (samremes, Dec 19, 2025)
10fb184  WIP: fixing loading logic (samremes, Dec 19, 2025)
ec1a069  Use simpler layout for scales. (samremes, Jan 12, 2026)
f944bc0  Extend comp async pipeline with scales (samremes, Jan 13, 2026)
edd11c9  Extend comp async pipeline with scales (samremes, Jan 13, 2026)
93ff8b0  use new pipeline in example (samremes, Jan 13, 2026)
5d4e07e  Merge remote-tracking branch 'origin/develop' into samremes/ck_tile_m… (samremes, Jan 14, 2026)
f6f9931  WIP (samremes, Jan 14, 2026)
16ca5cb  WIP (samremes, Jan 16, 2026)
f09e109  fixed vector load siz for fp4 (samremes, Jan 16, 2026)
d2a7c2f  compiles again using get_y_sliced_thread_data in warpgemm loop (samremes, Jan 23, 2026)
70c7fcd  WIP: debugging... (samremes, Jan 26, 2026)
f62cc54  current state of pipeline (samremes, Jan 27, 2026)
08ec1f4  update example code (samremes, Jan 27, 2026)
30d4c25  use PackedSize in slicing (samremes, Jan 27, 2026)
0033748  revert custom ldstile, should be able to use the regular ones (samremes, Jan 28, 2026)
409a7d8  Merge remote-tracking branch 'origin/develop' into samremes/ck_tile_m… (samremes, Jan 30, 2026)
2cc0e3d  override base policys vector size with static_assert 4/12/16 bytes (samremes, Jan 30, 2026)
b124a72  revert mostly back to original comp_async (samremes, Jan 30, 2026)
771c46a  add initial version for scale block_gemm, not used yet (samremes, Jan 30, 2026)
b8cdea5  enable fp8 mx gemm too (samremes, Jan 30, 2026)
407df88  enable 32 element for fp4 (samremes, Jan 30, 2026)
4d24128  use default scale (no scale) for 16x16x128 mfma scale (samremes, Jan 30, 2026)
b47853d  enable fp4 for universal gemm - without any scaling (samremes, Feb 3, 2026)
6b50755  fix alignment calculation of lds tensor views (samremes, Feb 3, 2026)
16fa73d  use proper rtol/atol (samremes, Feb 3, 2026)
329eabd  fix strides in mx gemm example (samremes, Feb 3, 2026)
6c61804  try to enable scale loading in kernel and pipeline (samremes, Feb 5, 2026)
3500228  init=1 init=2 working, some scales are still wrong as init=0 failing (samremes, Feb 5, 2026)
c4daaf2  fix packing in example (samremes, Feb 5, 2026)
a8d48f9  now offsetting with M/MPerXdl to get scales (samremes, Feb 5, 2026)
061c9f9  save packing approach (samremes, Feb 6, 2026)
c588a1f  use unpacked scales (samremes, Feb 6, 2026)
241ee59  clean up example a bit (samremes, Feb 6, 2026)
06a8998  clean up kernel and pipeline code (samremes, Feb 6, 2026)
dc4366a  add main include file (samremes, Feb 6, 2026)
1622674  use persistent (samremes, Feb 6, 2026)
457474e  use stricter tolerance (samremes, Feb 6, 2026)
c7298e5  remove some old files (samremes, Feb 6, 2026)
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -169,7 +169,7 @@ if (WIN32)
find_package(ROCmCMakeBuildTools REQUIRED PATHS C:/dist/TheRock)
set(HIP_PLATFORM "amd" CACHE STRING "HIP platform")
else()
find_package(ROCM REQUIRED PATHS /opt/rocm)
find_package(ROCM REQUIRED PATHS /opt/venv/lib/python3.12/site-packages/_rocm_sdk_devel/)
endif()

include(ROCMInstallTargets)
63 changes: 63 additions & 0 deletions example/ck_tile/03_gemm/gemm_utils.hpp
@@ -7,6 +7,7 @@
#include <variant>

#include "ck_tile/core.hpp"
#include "ck_tile/core/numeric/pk_fp4.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
@@ -137,6 +138,27 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
static constexpr int kBlockPerCu = 2;
};

template <typename PrecType>
struct GemmConfigComputeV3_3 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 64;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256;

static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;

static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;

static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;

static constexpr int kBlockPerCu = 2;
};

template <typename PrecType>
struct GemmConfigComputeV3_WMMA : public GemmConfigBase
{
@@ -241,6 +263,28 @@ struct GemmConfigComputeV6 : public GemmConfigBase
static constexpr ck_tile::index_t NumWaveGroups = 1;
};

template <typename PrecType>
struct GemmConfigComputeAsync : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 64;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256;

static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;

static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;
// static constexpr ck_tile::index_t K_Warp_Tile = 32;

static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_ASYNC;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool UseStructuredSparsity = false;
};

template <typename PrecType>
struct GemmConfigPreshuffleDecode : public GemmConfigBase
{
@@ -375,6 +419,15 @@ struct GemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, int32_t>
using CDataType = int32_t;
};

template <>
struct GemmTypeConfig<ck_tile::pk_fp4_t, ck_tile::pk_fp4_t, ck_tile::half_t>
{
using ADataType = ck_tile::pk_fp4_t;
using BDataType = ck_tile::pk_fp4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};

template <ck_tile::GemmPipeline PipelineId>
struct PipelineTypeTraits;

@@ -423,6 +476,16 @@ struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V6>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV6<PipelineProblem>;
};

template <>
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_ASYNC>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompAsync<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompAsync<PipelineProblem>;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};

template <>
struct PipelineTypeTraits<ck_tile::GemmPipeline::PRESHUFFLE_V2>
{
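Aside on the new configurations: the two GemmConfig structs added above (GemmConfigComputeV3_3 and GemmConfigComputeAsync) both use a 64x64x256 block tile on a 1x4x1 warp grid with 16x16x128 warp tiles. The standalone sketch below, which is not code from this PR, shows the per-warp repeat counts those numbers imply, assuming the usual CK Tile convention that each warp repeats its warp tile until the block tile is covered.

// Standalone sketch (not part of this diff): per-warp repeat counts implied by
// GemmConfigComputeAsync. The tile values are copied from the struct above; the
// "repeat" interpretation is an assumption about the CK Tile tiling convention.
#include <cstdio>

int main()
{
    constexpr int M_Tile = 64, N_Tile = 64, K_Tile = 256;                // block tile
    constexpr int M_Warp = 1, N_Warp = 4, K_Warp = 1;                    // warps per block
    constexpr int M_Warp_Tile = 16, N_Warp_Tile = 16, K_Warp_Tile = 128; // per-warp MFMA tile

    constexpr int m_repeat = M_Tile / (M_Warp * M_Warp_Tile); // 64 / (1 * 16)   = 4
    constexpr int n_repeat = N_Tile / (N_Warp * N_Warp_Tile); // 64 / (4 * 16)   = 1
    constexpr int k_repeat = K_Tile / (K_Warp * K_Warp_Tile); // 256 / (1 * 128) = 2

    static_assert(m_repeat == 4 && n_repeat == 1 && k_repeat == 2,
                  "warp and warp-tile sizes must evenly divide the block tile");
    std::printf("per-warp repeats: M=%d N=%d K=%d\n", m_repeat, n_repeat, k_repeat);
    return 0;
}

With 4 warps of 64 lanes each, one workgroup of 256 threads covers the whole 64x64 output tile, iterating twice over K per block-tile step.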
37 changes: 22 additions & 15 deletions example/ck_tile/03_gemm/run_gemm_example.inc
@@ -18,20 +18,22 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// using ComputeType =
// std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// // Calculate thresholds
// const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
// ck_tile::integer_divide_ceil(K, kbatch));
// const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
// max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// // Calculate error due to split_k accumulation
// const auto rtol_split_k =
// ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
// const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
// max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
// return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
ck_tile::ignore = K; ck_tile::ignore = kbatch; ck_tile::ignore = max_accumulated_value;
return ck_tile::make_tuple(0.1, 1.0);
}
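As an aside on the fixed thresholds returned above: an rtol/atol pair like (0.1, 1.0) is normally applied elementwise when the GEMM output is compared against the host reference, roughly as in this hypothetical helper (check_close is illustrative only, not a CK Tile API and not part of this diff).

#include <cmath>

// Hypothetical helper, shown only to make rtol = 0.1 / atol = 1.0 concrete.
bool check_close(float out, float ref, float rtol = 0.1f, float atol = 1.0f)
{
    // pass if |out - ref| <= atol + rtol * |ref|
    return std::fabs(out - ref) <= atol + rtol * std::fabs(ref);
}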

template <typename GemmConfig,
@@ -273,7 +275,10 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,

if(!preshuffle && GemmConfig::UseStructuredSparsity)
{
ck_tile::AdjustToStructuredSparsity<ADataType>{}(a_m_k);
if constexpr(GemmConfig::UseStructuredSparsity)
{
ck_tile::AdjustToStructuredSparsity<ADataType>{}(a_m_k);
}
}

ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
@@ -369,7 +374,9 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,

std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
sizeof(ADataType) * M * K / ck_tile::numeric_traits<ADataType>::PackedSize +
sizeof(BDataType) * N * K / ck_tile::numeric_traits<BDataType>::PackedSize +
sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;

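The num_byte change above accounts for packed sub-byte types: ck_tile::pk_fp4_t stores two 4-bit elements per byte (numeric_traits PackedSize == 2), so the A and B element counts are divided by PackedSize before being turned into bytes. A minimal standalone sketch of the same arithmetic follows; the problem sizes are made-up example values, not taken from this PR.

#include <cstddef>
#include <cstdio>

int main()
{
    // Example problem size (illustrative values only).
    constexpr std::size_t M = 3840, N = 4096, K = 2048;

    // For pk_fp4_t inputs: sizeof == 1 byte, PackedSize == 2 (two fp4 elements per byte).
    constexpr std::size_t a_elem_bytes = 1, a_packed = 2;
    constexpr std::size_t b_elem_bytes = 1, b_packed = 2;
    constexpr std::size_t c_elem_bytes = 2; // half_t output

    const std::size_t num_byte = a_elem_bytes * M * K / a_packed  // A
                               + b_elem_bytes * N * K / b_packed  // B
                               + c_elem_bytes * M * N;            // C
    std::printf("approx. bytes moved per GEMM: %zu\n", num_byte);
    return 0;
}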
149 changes: 78 additions & 71 deletions example/ck_tile/03_gemm/universal_gemm.cpp
@@ -182,86 +182,92 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");

if(data_type == "fp16")
if(data_type == "fp4")
{
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::pk_fp4_t>, ck_tile::pk_fp4_t, ck_tile::pk_fp4_t, ck_tile::half_t>(
a_layout, b_layout, arg_parser);
}
else if(data_type == "bf16")
{
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::bf16_t>, ck_tile::bf16_t>(
a_layout, b_layout, arg_parser);
}
else if(data_type == "fp8")
// if(data_type == "fp16")
// {
// return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
// a_layout, b_layout, arg_parser);
// }
// else if(data_type == "bf16")
// {
// return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::bf16_t>, ck_tile::bf16_t>(
// a_layout, b_layout, arg_parser);
// }
else
if(data_type == "fp8")
{
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::fp8_t>,
ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::half_t>(
a_layout, b_layout, arg_parser);
}
else if(data_type == "bf8")
{
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::bf8_t>,
ck_tile::bf8_t,
ck_tile::bf8_t,
ck_tile::half_t>(
a_layout, b_layout, arg_parser);
}
else if(data_type == "int8")
{
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::int8_t>,
ck_tile::int8_t,
ck_tile::int8_t,
ck_tile::int32_t>(
a_layout, b_layout, arg_parser);
}
else if(data_type == "fp16i4")
{
// TODO: Add support for bhalf_t ADataType
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
{
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::half_t>,
ck_tile::half_t,
ck_tile::pk_int4_t,
ck_tile::half_t>(
a_layout, b_layout, arg_parser);
}
else
{
throw std::runtime_error("Unsupported pipeline for this operation !!!");
}
}
else if(data_type == "fp8i4")
{
if constexpr(GemmConfig<ck_tile::fp8_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
{
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::fp8_t>,
ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t>(
a_layout, b_layout, arg_parser);
}
else
{
throw std::runtime_error("Unsupported pipeline for this operation !!!");
}
}
else if(data_type == "bf8i4")
{
if constexpr(GemmConfig<ck_tile::bf8_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
{
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::bf8_t>,
ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t>(
a_layout, b_layout, arg_parser);
}
else
{
throw std::runtime_error("Unsupported pipeline for this operation !!!");
}
}
// else if(data_type == "bf8")
// {
// return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::bf8_t>,
// ck_tile::bf8_t,
// ck_tile::bf8_t,
// ck_tile::half_t>(
// a_layout, b_layout, arg_parser);
// }
// else if(data_type == "int8")
// {
// return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::int8_t>,
// ck_tile::int8_t,
// ck_tile::int8_t,
// ck_tile::int32_t>(
// a_layout, b_layout, arg_parser);
// }
// else if(data_type == "fp16i4")
// {
// // TODO: Add support for bhalf_t ADataType
// if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
// {
// return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::half_t>,
// ck_tile::half_t,
// ck_tile::pk_int4_t,
// ck_tile::half_t>(
// a_layout, b_layout, arg_parser);
// }
// else
// {
// throw std::runtime_error("Unsupported pipeline for this operation !!!");
// }
// }
// else if(data_type == "fp8i4")
// {
// if constexpr(GemmConfig<ck_tile::fp8_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
// {
// return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::fp8_t>,
// ck_tile::fp8_t,
// ck_tile::pk_int4_t,
// ck_tile::half_t>(
// a_layout, b_layout, arg_parser);
// }
// else
// {
// throw std::runtime_error("Unsupported pipeline for this operation !!!");
// }
// }
// else if(data_type == "bf8i4")
// {
// if constexpr(GemmConfig<ck_tile::bf8_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
// {
// return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::bf8_t>,
// ck_tile::bf8_t,
// ck_tile::pk_int4_t,
// ck_tile::half_t>(
// a_layout, b_layout, arg_parser);
// }
// else
// {
// throw std::runtime_error("Unsupported pipeline for this operation !!!");
// }
// }
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
@@ -281,7 +287,8 @@ int main(int argc, char* argv[])
#if CK_TILE_USE_WMMA
return !run_gemm_example<GemmConfigComputeV3_WMMA>(arg_parser);
#else
return !run_gemm_example<GemmConfigComputeV3_2>(arg_parser);
return !run_gemm_example<GemmConfigComputeAsync>(arg_parser);
// return !run_gemm_example<GemmConfigComputeV3_3>(arg_parser);
#endif
}
catch(const std::runtime_error& e)
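Since the new "fp4" path above feeds packed data to the kernel, it may help to recall how a pk_fp4_t buffer is laid out: two 4-bit (e2m1) codes share one byte. The sketch below shows one plausible way to pack such a pair; pack_fp4_pair and the low-nibble-first ordering are assumptions made for illustration, so consult ck_tile/core/numeric/pk_fp4.hpp for the actual encoding.

#include <cstdint>

// Illustration only (not part of this diff): pack two 4-bit fp4 codes into one byte.
// The low-nibble-first ordering is an assumption, not taken from CK Tile.
inline std::uint8_t pack_fp4_pair(std::uint8_t code0, std::uint8_t code1)
{
    return static_cast<std::uint8_t>((code0 & 0x0F) | ((code1 & 0x0F) << 4));
}

int main()
{
    const std::uint8_t byte = pack_fp4_pair(0x2, 0x9); // two arbitrary 4-bit codes
    return byte == 0x92 ? 0 : 1;                       // 0x9 in the high nibble, 0x2 in the low
}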
1 change: 1 addition & 0 deletions example/ck_tile/03_gemm/universal_gemm_invoker.hpp
@@ -52,6 +52,7 @@ struct UniversalInvoker
GemmConfig::NumWaveGroups,
GemmConfig::Preshuffle>;

static_assert(GemmConfig::UseStructuredSparsity == false, "UseStructuredSparsity must be false");
constexpr auto scheduler = GemmConfig::Scheduler;

using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
17 changes: 17 additions & 0 deletions example/ck_tile/42_mx_gemm/CMakeLists.txt
@@ -0,0 +1,17 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

set(SUPPORTED_GPUS gfx950)

set(has_supported_gpu FALSE)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST SUPPORTED_GPUS)
set(has_supported_gpu TRUE)
break()
endif()
endforeach()

if(has_supported_gpu)
add_executable(tile_example_mx_gemm mx_gemm.cpp)
target_compile_options(tile_example_mx_gemm PRIVATE -Wno-undefined-func-template)
endif()