From f7955d94023a855d24d593f4e3473fa1a42b16f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Thu, 18 Dec 2025 04:36:02 -0500 Subject: [PATCH 01/81] Add placeholder test. --- .../test/conv/ck/test_ckb_conv_bwd_weight.cpp | 86 +++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp new file mode 100644 index 00000000000..7ae0ba27ea5 --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp @@ -0,0 +1,86 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/builder/testing/conv_bwd_ck.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; + +constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::FP16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::GNHWC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + .with_thread_block(cku::FwdThreadBlock_256_256x256x32) + .with_gemm_config(cku::FwdGemmParams_Xdl_4x4_per_wave) + .with_transfer(cku::FwdTransfer_4x64x1) + .with_specializations(ckb::ConvFwdSpecialization::DEFAULT, + ckb::GemmSpecialization::MNKPadding) + .with_block_gemm(cku::BlockGemmDesc_v3_intrawave); + +using Builder = ckb::ConvBuilder; +using Instance = 
Builder::Instance; + +TEST(BwdWeight_2DFp16_CShufV3_GNHWC, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + cku::run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", + expected_transfer_parameters, + "Default", + "Intrawave", + "v3", + "GNHWC,GKYXC,EmptyTuple,GNHWK", + "PassThrough,PassThrough,PassThrough", + "MNKPadding"}); +} + +TEST(BwdWeight_2DFp16_CShufV3_GNHWC, EndToEnd) +{ + if(!ck_tile::get_device_name().starts_with("gfx9")) + { + GTEST_SKIP() << "unsupported architecture"; + } + + ckt::Args args = { + .lengths = + { + .batch_size = 16, + .groups = 1, + .input_channels = 32, + .output_channels = 48, + .image = + { + .width = 56, + .height = 64, + }, + .filter = + { + .width = 3, + .height = 5, + }, + }, + .filter_strides = {.width = 1, .height = 1}, + .filter_dilation = {.width = 1, .height = 1}, + .input_left_pad = {.width = 0, .height = 0}, + .input_right_pad = {.width = 0, .height = 0}, + .a_elementwise_op = {}, + .b_elementwise_op = {}, + .cde_elementwise_op = {}, + }; + + auto inputs = alloc_inputs(args); + auto outputs = alloc_outputs(args); + + auto conv = Instance{}; + ckt::run(conv, args, inputs.get(), outputs.get()); +} \ No newline at end of file From 2460cf4579b7a5353bf22d43d4aa5fbfc868e484 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Fri, 19 Dec 2025 07:59:37 -0500 Subject: [PATCH 02/81] Initial conv bwd weight factory. 
--- .../builder/conv_algorithm_concepts.hpp | 19 +++- .../factory/conv_bwd_weight_xdl_factory.hpp | 102 ++++++++++++++++++ .../builder/factory/conv_dispatcher.hpp | 69 +++++++----- .../builder/factory/conv_fwd_dl_factory.hpp | 2 +- .../factory/conv_fwd_large_tensor_factory.hpp | 2 +- .../builder/factory/conv_fwd_v3_factory.hpp | 2 +- .../builder/factory/conv_fwd_wmma_factory.hpp | 2 +- .../builder/factory/conv_fwd_xdl_factory.hpp | 2 +- .../helpers/ck/conv_elementwise_op.hpp | 13 +++ .../factory/helpers/ck/conv_tensor_layout.hpp | 19 +++- .../factory/helpers/ck/conv_tensor_type.hpp | 21 ++++ .../factory/helpers/ck/conv_tuning_params.hpp | 27 +++-- .../builder/include/ck_tile/builder/types.hpp | 4 +- .../test/conv/ck/test_ckb_conv_bwd_weight.cpp | 4 +- .../test/impl/conv_algorithm_types.hpp | 84 ++++++++++++--- 15 files changed, 310 insertions(+), 62 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index bf7e89fcaab..fcc9a09c6d0 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -27,8 +27,6 @@ concept ThreadBlockDescriptor = requires(T t) { // Concept for parameters that describe a gridwise XDL GEMM problem. template concept GridwiseXdlGemmDescriptor = requires(T t) { - { t.ak1 } -> std::convertible_to; - { t.bk1 } -> std::convertible_to; { t.m_per_xdl } -> std::convertible_to; { t.n_per_xdl } -> std::convertible_to; { t.m_xdl_per_wave } -> std::convertible_to; @@ -159,7 +157,17 @@ concept SpecifiesTileThreadBlock = requires { // Concept to check if a struct specifies gridwise XDL GEMM info. 
template -concept SpecifiesGridwiseXdlGemm = requires { +concept SpecifiesGridwiseFwdXdlGemm = requires { + { T::gridwise_gemm.ak1 } -> std::convertible_to; + { T::gridwise_gemm.bk1 } -> std::convertible_to; + { T::gridwise_gemm } -> GridwiseXdlGemmDescriptor; +}; + +// Concept to check if a struct specifies gridwise XDL GEMM info. +template +concept SpecifiesGridwiseBwdXdlGemm = requires { + { T::gridwise_gemm.k0_per_block } -> std::convertible_to; + { T::gridwise_gemm.k1 } -> std::convertible_to; { T::gridwise_gemm } -> GridwiseXdlGemmDescriptor; }; @@ -247,6 +255,11 @@ concept SpecifiesFwdConvSpecialization = requires { { T::fwd_specialization } -> std::convertible_to; }; +template +concept SpecifiesBwdWeightConvSpecialization = requires { + { T::bwd_specialization } -> std::convertible_to; +}; + template concept SpecifiesGemmSpecialization = requires { { T::gemm_specialization } -> std::convertible_to; }; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp new file mode 100644 index 00000000000..98893560922 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp @@ -0,0 +1,102 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle instance +// of a grouped forward convolution kernel. +template + requires ConvDirectionIsBackwardWeight +struct ConvBwdWeightXdlFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::BwdConvTensorDataTypes; + using Ops = internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto A_BLOCK_TRANSFER = + internal::SetFwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetFwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. 
+ static_assert(InputVectorTransferLimits); + static_assert(InputVectorTransferLimits); + static_assert(OutputVectorTransferLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + + // The forward convolution kernel class instance. + using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k0_per_block, + GRIDWISE_GEMM.k1, + GRIDWISE_GEMM.m_per_xdl, + GRIDWISE_GEMM.n_per_xdl, + GRIDWISE_GEMM.m_xdl_per_wave, + GRIDWISE_GEMM.n_xdl_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + typename Types::InComputeType, + typename Types::WeiComputeType, + ALGORITHM.max_transpose_transfer_src_scalar_per_vector, + ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 
99e7479e362..b18d54f4891 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -60,6 +60,7 @@ #include "ck_tile/builder/factory/conv_fwd_dl_factory.hpp" #include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp" #include "ck_tile/builder/factory/conv_tile_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp" namespace ck_tile::builder::factory { @@ -88,34 +89,43 @@ concept IsTileAlgorithm = ConvAlgorithmDescriptor && SpecifiesTileThreadBlock SpecifiesTileTransfer && SpecifiesTileConvSpecialization && SpecifiesTileBlockGemm && SpecifiesTileOptimizations; +template +concept SpecifiesDataTransfer = + SpecifiesThreadBlock && SpecifiesBlockTransfer && + SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && + SpecifiesSourceAccessOrder; + // XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline) template -concept IsXdlV3Algorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && - SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesFwdConvSpecialization && - SpecifiesGemmSpecialization && SpecifiesBlockGemm; +concept IsFwdXdlV3Algorithm = ConvAlgorithmDescriptor && + SpecifiesDataTransfer && SpecifiesGridwiseFwdXdlGemm && + SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && + SpecifiesBlockGemm; + +// Standard XDL-based fwd kernel (uses XDLops hardware instructions for matrix multiply) +template +concept IsFwdXdlAlgorithm = ConvAlgorithmDescriptor && + SpecifiesDataTransfer && SpecifiesGridwiseFwdXdlGemm && + SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && + SpecifiesNumPrefetchStages && SpecifiesNumGroupsToMerge && + SpecifiesLoopScheduler; -// Standard XDL-based kernel (uses XDLops hardware instructions for matrix multiply) +// Standard XDL-based bwd weight
kernel (uses XDLops hardware instructions for matrix multiply) template -concept IsXdlAlgorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && - SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesFwdConvSpecialization && - SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && - SpecifiesNumGroupsToMerge && SpecifiesLoopScheduler; +concept IsBwdXdlAlgorithm = ConvAlgorithmDescriptor && + SpecifiesDataTransfer && SpecifiesGridwiseBwdXdlGemm && + SpecifiesBwdWeightConvSpecialization && SpecifiesTransposeTransfer; // WMMA-based kernel (uses Wavefront Matrix-Matrix Accumulate instructions) template -concept IsWmmaAlgorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseWmmaGemm && - SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesFwdConvSpecialization && - SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && SpecifiesLoopScheduler; +concept IsFwdWmmaAlgorithm = ConvAlgorithmDescriptor && + SpecifiesDataTransfer && SpecifiesGridwiseWmmaGemm && + SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && + SpecifiesNumPrefetchStages && SpecifiesLoopScheduler; // Specialized DL kernel for specific NHWC/KYXC/NHWK data layouts template -concept IsDlAlgorithm = +concept IsFwdDlAlgorithm = ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && SpecifiesDlBlockTransfer && SpecifiesDlEpilogue; @@ -139,19 +149,19 @@ constexpr auto make_conv_instance() } else if constexpr(ConvDirectionIsForward) { - if constexpr(IsXdlV3Algorithm) + if constexpr(IsFwdXdlV3Algorithm) { return typename ConvFwdXdlV3Factory::Instance{}; } - else if constexpr(IsXdlAlgorithm) + else if constexpr(IsFwdXdlAlgorithm) { return typename 
ConvFwdXdlFactory::Instance{}; } - else if constexpr(IsWmmaAlgorithm) + else if constexpr(IsFwdWmmaAlgorithm) { return typename ConvFwdWmmaFactory::Instance{}; } - else if constexpr(IsDlAlgorithm) + else if constexpr(IsFwdDlAlgorithm) { return typename ConvFwdDlFactory::Instance{}; } @@ -177,10 +187,17 @@ constexpr auto make_conv_instance() } else if constexpr(ConvDirectionIsBackwardWeight) { - static_assert( - false, - "Backward weight convolution is not yet supported. " - "Only forward convolution (ConvDirection::FORWARD) is currently implemented."); + if constexpr (IsBwdXdlAlgorithm) + { + return typename ConvBwdWeightXdlFactory::Instance{}; + } + else + { + static_assert( + false, + "Backward weight convolution is not yet supported. " + "Only forward convolution (ConvDirection::FORWARD) is currently implemented."); + } } else { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp index ca202aabfd8..42c59dfaec2 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp @@ -24,7 +24,7 @@ template ; + using Layouts = internal::ConvTensorLayouts; using Types = internal::FwdConvTensorDataTypes; using Ops = internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp index fadf41f48aa..fca36386974 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp @@ -26,7 +26,7 @@ template ; + using Layouts = internal::ConvTensorLayouts; using Types = internal::FwdConvTensorDataTypes; using Ops = internal::ElementwiseOps; 
using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp index 89787cc1b32..47891869cce 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp @@ -26,7 +26,7 @@ template ; + using Layouts = internal::ConvTensorLayouts; using Types = internal::FwdConvTensorDataTypes; using Ops = internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp index bb844790710..1fb3942df0a 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp @@ -26,7 +26,7 @@ template ; + using Layouts = internal::ConvTensorLayouts; using Types = internal::FwdConvTensorDataTypes; using Ops = internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp index 8ec5c633ce0..695f1546143 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp @@ -26,7 +26,7 @@ template ; + using Layouts = internal::ConvTensorLayouts; using Types = internal::FwdConvTensorDataTypes; using Ops = internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp index 
a39cd7410bb..b24344c90a4 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp @@ -62,6 +62,7 @@ consteval auto GetElementwiseOp() } template +requires ConvDirectionIsForward struct ElementwiseOps { static constexpr auto input_op = GetElementwiseOp(); @@ -72,4 +73,16 @@ struct ElementwiseOps using CDEElementwiseOp = typename decltype(output_op)::Op; }; +template +requires ConvDirectionIsBackwardWeight +struct ElementwiseOps +{ + static constexpr auto input_op = GetElementwiseOp(); + static constexpr auto weight_op = GetElementwiseOp(); + static constexpr auto output_op = GetElementwiseOp(); + using InElementwiseOp = typename decltype(input_op)::Op; + using WeiElementwiseOp = typename decltype(weight_op)::Op; + using OutElementwiseOp = typename decltype(output_op)::Op; +}; + } // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp index a6c0b48c54b..3df3f8f37cd 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp @@ -216,18 +216,31 @@ consteval auto GetAuxiliaryTensorLayouts() return EmptyAuxiliaryTensorLayout{}; } -template +template requires(ConvSpatialDim && ValidConvInputLayoutForSpatialDim && ValidConvWeightLayoutForSpatialDim && - ValidConvOutputLayoutForSpatialDim) + ValidConvOutputLayoutForSpatialDim && + ConvDirectionIsForward) struct ConvTensorLayouts { - static_assert(DIR == ConvDirection::FORWARD, "Only Forward convolution is supported."); using ALayout = decltype(TensorLayoutToCK()); using BLayout = decltype(TensorLayoutToCK()); using ELayout = decltype(TensorLayoutToCK()); using DsLayout = 
decltype(GetAuxiliaryTensorLayouts())::type; }; +template + requires(ConvSpatialDim && + ValidConvInputLayoutForSpatialDim && + ValidConvWeightLayoutForSpatialDim && + ValidConvOutputLayoutForSpatialDim && + ConvDirectionIsBackwardWeight) +struct ConvTensorLayouts +{ + using InLayout = decltype(TensorLayoutToCK()); + using WeiLayout = decltype(TensorLayoutToCK()); + using OutLayout = decltype(TensorLayoutToCK()); +}; + } // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp index c819e11d009..7839cb7f4a6 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp @@ -151,6 +151,7 @@ consteval auto GetAuxiliaryTensorDataTypes() } template +requires ConvDirectionIsForward struct FwdConvTensorDataTypes { static constexpr auto input_types = @@ -176,4 +177,24 @@ struct FwdConvTensorDataTypes using DsDataTypes = typename decltype(GetAuxiliaryTensorDataTypes())::type; }; +template +requires ConvDirectionIsBackwardWeight +struct FwdConvTensorDataTypes +{ + static constexpr auto input_types = + GetTensorDataAndComputeTypes(); + static constexpr auto weight_types = + GetTensorDataAndComputeTypes(); + static constexpr auto output_types = + GetTensorDataAndComputeTypes(); + + using InDataType = typename decltype(input_types.first)::type; + using InComputeType = typename decltype(input_types.second)::type; + using WeiDataType = typename decltype(weight_types.first)::type; + using WeiComputeType = typename decltype(weight_types.second)::type; + using AccDataType = + typename decltype(GetTensorAccumulationType())::type; +}; + } // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp 
b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp index db741f2112a..6f3a9e8e78d 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp @@ -149,12 +149,27 @@ consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdC using ck_conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization; switch(specialization) { - case ConvFwdSpecialization::DEFAULT: return ck_conv_spec::Default; - case ConvFwdSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0; - case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0; - case ConvFwdSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3; - case ConvFwdSpecialization::ODD_C: return ck_conv_spec::OddC; - default: throw "Unknown ConvFwdSpecialization"; + case ConvSpecialization::DEFAULT: return ck_conv_spec::Default; + case ConvSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0; + case ConvSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0; + case ConvSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3; + case ConvSpecialization::ODD_C: return ck_conv_spec::OddC; + default: throw "Unsupported ConvSpecialization"; + } +} + +template +consteval ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization SetBwdWeightConvSpecialization() +{ + constexpr auto specialization = ALGORITHM.bwd_specialization; + using ck_conv_spec = ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization; + switch(specialization) + { + case ConvSpecialization::DEFAULT: return ck_conv_spec::Default; + case ConvSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0; + case ConvSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0; + case ConvSpecialization::ODD_C: return 
ck_conv_spec::OddC; + default: throw "Unsupported ConvSpecialization"; } } diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index f7386720b3a..5f08b5ab9cd 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -192,8 +192,8 @@ enum class TileConvSpecialization FILTER_3x3 }; -// Enums for the forward convolution specialization. -enum class ConvFwdSpecialization +// Enums for the convolution specializations. +enum class ConvSpecialization { DEFAULT, FILTER_1X1_PAD0, diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp index 7ae0ba27ea5..87de6dab031 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp @@ -20,7 +20,7 @@ constexpr auto SIGNATURE = .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; -constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle{} .with_thread_block(cku::FwdThreadBlock_256_256x256x32) .with_gemm_config(cku::FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(cku::FwdTransfer_4x64x1) @@ -34,7 +34,7 @@ using Instance = Builder::Instance; TEST(BwdWeight_2DFp16_CShufV3_GNHWC, Create) { const auto expected_transfer_parameters = to_string(ALGORITHM); - cku::run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", + cku::run_test({"DeviceGroupedConvBwdWeight_Xdl_CShuffle", expected_transfer_parameters, "Default", "Intrawave", diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 29c7f3cdcc1..6b87ae77d6f 100644 --- 
a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -28,18 +28,30 @@ struct ThreadBlock }; static_assert(ckb::ThreadBlockDescriptor); -// Describe gridwise XDL GEMM parameters. -struct GridwiseXdlGemm +struct XdlParams { - // NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!! - size_t ak1 = 0; - size_t bk1 = 0; size_t m_per_xdl = 0; size_t n_per_xdl = 0; size_t m_xdl_per_wave = 0; size_t n_xdl_per_wave = 0; }; -static_assert(ckb::GridwiseXdlGemmDescriptor); +static_assert(ckb::GridwiseXdlGemmDescriptor); + +// Describe gridwise XDL GEMM parameters. +struct GridwiseFwdXdlGemm : public XdlParams +{ + // NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!! + size_t ak1 = 0; + size_t bk1 = 0; +}; +static_assert(ckb::SpecifiesGridwiseFwdXdlGemm); + +struct GridwiseBwdXdlGemm : public XdlParams +{ + size_t k0_per_block = 0; + size_t k1 = 0; +}; +static_assert(ckb::SpecifiesGridwiseBwdXdlGemm); // Describe gridwise WMMA GEMM parameters. 
struct GridwiseWmmaGemm @@ -169,9 +181,14 @@ struct ThreadBlock_ ThreadBlock thread_block; }; -struct XdlGemm_ +struct FwdXdlGemm_ { - GridwiseXdlGemm gridwise_gemm; + GridwiseFwdXdlGemm gridwise_gemm; +}; + +struct BwdXdlGemm_ +{ + GridwiseBwdXdlGemm gridwise_gemm; }; struct WmmaGemm_ @@ -184,12 +201,17 @@ struct Transfer_ TransferABC transfer; }; -struct ConvSpecialization_ +struct ConvSpecializationFwd_ { - ConvFwdSpecialization fwd_specialization; + ConvSpecialization fwd_specialization; GemmSpecialization gemm_specialization; }; +struct ConvSpecializationBwdWeight_ +{ + ConvSpecialization bwd_specialization; +}; + struct Prefetch_ { size_t num_gemm_k_prefetch_stages; @@ -197,6 +219,12 @@ struct Prefetch_ PipelineScheduler loop_scheduler; }; +struct TransposeParams_ +{ + size_t max_transpose_transfer_src_scalar_per_vector{1}; + size_t max_transpose_transfer_dst_scalar_per_vector{1}; +}; + struct BlockGemm_ { BlockGemm block_gemm; @@ -329,7 +357,11 @@ struct ConvAlgorithmTemplate : Components... constexpr auto with_gemm_config(const GemmConfig& gemm) const { auto result = *this; - if constexpr(std::is_base_of_v) + if constexpr(std::is_base_of_v) + { + result.gridwise_gemm = gemm; + } + if constexpr(std::is_base_of_v) { result.gridwise_gemm = gemm; } @@ -359,6 +391,14 @@ struct ConvAlgorithmTemplate : Components... return result; } + constexpr auto with_specializations(ConvBwdWeightSpecialization bwd_spec) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.bwd_specialization = bwd_spec; + return result; + } + constexpr auto with_prefetch_config(size_t k_prefetch_stages, size_t groups_to_merge, PipelineScheduler scheduler) const @@ -371,6 +411,16 @@ struct ConvAlgorithmTemplate : Components... 
return result; } + constexpr auto with_transpose_params(size_t max_src_scalar_per_vector, + size_t max_dst_scalar_per_vector) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.max_transpose_transfer_src_scalar_per_vector = max_src_scalar_per_vector; + result.max_transpose_transfer_dst_scalar_per_vector = max_dst_scalar_per_vector; + return result; + } + template constexpr auto with_block_gemm(const BG& bg) const { @@ -456,16 +506,17 @@ struct ConvAlgorithmTemplate : Components... // Algorithm types using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = - ConvAlgorithmTemplate; + ConvAlgorithmTemplate; using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = - ConvAlgorithmTemplate; + ConvAlgorithmTemplate; using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle = - ConvAlgorithmTemplate; + ConvAlgorithmTemplate; + using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = ConvAlgorithmTemplate; @@ -479,4 +530,7 @@ using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate; +using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle = + ConvAlgorithmTemplate; + } // namespace ck_tile::builder::test From 5a1c9c9a2248cc48116f29a23a17d8ec62ab5ea3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Fri, 19 Dec 2025 09:14:44 -0500 Subject: [PATCH 03/81] Conv builder test refactoring.
--- experimental/builder/test/CMakeLists.txt | 6 ++ .../test/conv/ck/test_ckb_conv_bwd_weight.cpp | 11 ++-- .../conv/ck/test_ckb_conv_fwd_1d_bf16.cpp | 6 +- .../conv/ck/test_ckb_conv_fwd_1d_fp16.cpp | 6 +- .../test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp | 6 +- .../conv/ck/test_ckb_conv_fwd_2d_bf16.cpp | 12 ++-- ...est_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp | 6 +- .../conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp | 8 +-- .../conv/ck/test_ckb_conv_fwd_2d_fp16.cpp | 6 +- .../conv/ck/test_ckb_conv_fwd_2d_fp32.cpp | 6 +- .../test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp | 6 +- ...test_ckb_conv_fwd_2d_large_tensor_fp16.cpp | 10 ++-- .../conv/ck/test_ckb_conv_fwd_3d_bf16.cpp | 6 +- .../conv/ck/test_ckb_conv_fwd_3d_fp16.cpp | 6 +- .../conv/ck/test_ckb_conv_fwd_3d_fp32.cpp | 6 +- .../test/impl/conv_algorithm_types.hpp | 4 +- .../test/utils/ckb_conv_test_configs.hpp | 60 +++++++++++-------- 17 files changed, 92 insertions(+), 79 deletions(-) diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index eb4ef134628..43a142a2798 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -140,6 +140,11 @@ add_ck_builder_test(test_ckb_build_fwd_instances conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp) target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility) +add_ck_builder_test(test_ckb_build_bwd_weight_instances + conv/ck/test_ckb_conv_bwd_weight.cpp + ) +target_link_libraries(test_ckb_build_bwd_weight_instances PRIVATE utility) + ################################################################################ # FACTORY TESTS - Expensive Regression Tests (Full MIOpen Kernel Set) @@ -192,6 +197,7 @@ endforeach() set(CKB_REGRESSION_TESTS test_ckb_instance_string test_ckb_build_fwd_instances + test_ckb_build_bwd_weight_instances test_ckb_testing_utils # test_ckb_factory_grouped_convolution_forward_convscale # test_ckb_factory_grouped_convolution_forward_scaleadd_ab diff --git 
a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp index 87de6dab031..f512f094439 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp @@ -21,11 +21,10 @@ constexpr auto SIGNATURE = .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle{} - .with_thread_block(cku::FwdThreadBlock_256_256x256x32) - .with_gemm_config(cku::FwdGemmParams_Xdl_4x4_per_wave) - .with_transfer(cku::FwdTransfer_4x64x1) - .with_specializations(ckb::ConvFwdSpecialization::DEFAULT, - ckb::GemmSpecialization::MNKPadding) + .with_thread_block(cku::ThreadBlock_256_256x256x32) + .with_gemm_config(cku::BwdGemmParams_Xdl_4x4_per_wave) + .with_transfer(cku::Transfer_4x64x1) + .with_bwd_specialization(ckb::ConvFwdSpecialization::DEFAULT) .with_block_gemm(cku::BlockGemmDesc_v3_intrawave); using Builder = ckb::ConvBuilder; @@ -83,4 +82,4 @@ TEST(BwdWeight_2DFp16_CShufV3_GNHWC, EndToEnd) auto conv = Instance{}; ckt::run(conv, args, inputs.get(), outputs.get()); -} \ No newline at end of file +} diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp index 284b3929ee5..8c59dd21b16 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp @@ -30,10 +30,10 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} - .with_thread_block(FwdThreadBlock_256_256x256x32) + .with_thread_block(ThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_transfer(FwdTransfer_4x64x1) - .with_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0, + .with_transfer(Transfer_4x64x1) + 
.with_fwd_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v2_intrawave); diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp index 6802e0caf8d..7ab3ac605b8 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp @@ -27,10 +27,10 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} - .with_thread_block(FwdThreadBlock_64_64x32x32) + .with_thread_block(ThreadBlock_64_64x32x32) .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) - .with_transfer(FwdTransfer_4x16x1) - .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_transfer(Transfer_4x16x1) + .with_fwd_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_prefetch_config(1, 2, PipelineScheduler::DEFAULT); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp index 14463bbc175..d8fbd3827e0 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp @@ -29,10 +29,10 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle{} - .with_thread_block(FwdThreadBlock_128_64x64x64) + .with_thread_block(ThreadBlock_128_64x64x64) .with_gemm_config(FwdGemmParams_Wmma_2x1_per_wave) - .with_transfer(FwdTransfer_4x32x1) - .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_transfer(Transfer_4x32x1) + .with_fwd_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_prefetch_config(1, 0, PipelineScheduler::DEFAULT); 
using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp index 4a5618a6b12..5f3bdfe4140 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp @@ -27,10 +27,10 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} - .with_thread_block(FwdThreadBlock_256_256x256x32) + .with_thread_block(ThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_transfer(FwdTransfer_4x64x1) - .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_transfer(Transfer_4x64x1) + .with_fwd_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v1_intrawave); using Builder = ConvBuilder; @@ -64,10 +64,10 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} - .with_thread_block(FwdThreadBlock_256_256x256x32) + .with_thread_block(ThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_transfer(FwdTransfer_4x64x1) - .with_specializations(ConvFwdSpecialization::FILTER_3x3, GemmSpecialization::MNKPadding) + .with_transfer(Transfer_4x64x1) + .with_fwd_specializations(ConvFwdSpecialization::FILTER_3x3, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v5_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp index e3dc261fe3e..f6403b312cc 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp @@ -32,10 +32,10 @@ 
TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} - .with_thread_block(FwdThreadBlock_64_64x32x32) + .with_thread_block(ThreadBlock_64_64x32x32) .with_gemm_config(FwdGemmParams_Xdl_2x2_per_wave) - .with_transfer(FwdTransfer_4x16x1) - .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_transfer(Transfer_4x16x1) + .with_fwd_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp index 9bea834ef90..a49f55e6d67 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp @@ -25,8 +25,8 @@ TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Ins constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{} - .with_thread_block(FwdThreadBlock_256_128x128x16) - .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_thread_block(ThreadBlock_256_128x128x16) + .with_fwd_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_dl_thread_config(DlThreadConfig_16x2x4x4x1) .with_dl_thread_cluster(DlThreadCluster_8x2) .with_dl_transfer(DlFwdTransfer); @@ -59,8 +59,8 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{} - .with_thread_block(FwdThreadBlock_256_128x128x16) - .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, + .with_thread_block(ThreadBlock_256_128x128x16) + .with_fwd_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, GemmSpecialization::MNKPadding) .with_dl_thread_config(DlThreadConfig_16x2x4x4x1) 
.with_dl_thread_cluster(DlThreadCluster_8x2) diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp index b7eacf5643d..7b543e0e3ee 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp @@ -21,10 +21,10 @@ constexpr auto SIGNATURE = .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} - .with_thread_block(cku::FwdThreadBlock_256_256x256x32) + .with_thread_block(cku::ThreadBlock_256_256x256x32) .with_gemm_config(cku::FwdGemmParams_Xdl_4x4_per_wave) - .with_transfer(cku::FwdTransfer_4x64x1) - .with_specializations(ckb::ConvFwdSpecialization::DEFAULT, + .with_transfer(cku::Transfer_4x64x1) + .with_fwd_specializations(ckb::ConvFwdSpecialization::DEFAULT, ckb::GemmSpecialization::MNKPadding) .with_block_gemm(cku::BlockGemmDesc_v3_intrawave); diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp index 79ee4915e82..a3f493c1d71 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp @@ -26,10 +26,10 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} - .with_thread_block(FwdThreadBlock_256_128x128x32) + .with_thread_block(ThreadBlock_256_128x128x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_transfer(FwdTransfer_4x64x1) - .with_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0, + .with_transfer(Transfer_4x64x1) + .with_fwd_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v4_intrawave); diff --git 
a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp index 3e3d7e8c2b6..279c942ba90 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp @@ -27,10 +27,10 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} - .with_thread_block(FwdThreadBlock_256_256x128x32) + .with_thread_block(ThreadBlock_256_256x128x32) .with_gemm_config(FwdGemmParams_Xdl_4x2_per_wave) - .with_transfer(FwdTransfer_4x64x1_fp8) - .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_transfer(Transfer_4x64x1_fp8) + .with_fwd_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp index 3019c57a188..cd9655f0b45 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp @@ -27,10 +27,10 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{ .base_algorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} - .with_thread_block(FwdThreadBlock_256_256x128x32) + .with_thread_block(ThreadBlock_256_256x128x32) .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) - .with_transfer(FwdTransfer_4x16x1) - .with_specializations(ConvFwdSpecialization::DEFAULT, + .with_transfer(Transfer_4x16x1) + .with_fwd_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)}; @@ -64,9 +64,9 @@ TEST( 
constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{ .base_algorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} - .with_thread_block(FwdThreadBlock_128_128x128x32) + .with_thread_block(ThreadBlock_128_128x128x32) .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) - .with_transfer(FwdTransfer_4x16x1) + .with_transfer(Transfer_4x16x1) .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, GemmSpecialization::MNKPadding) .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)}; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp index 3f9bdfb972b..b29b4471c34 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp @@ -27,10 +27,10 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} - .with_thread_block(FwdThreadBlock_256_256x256x32) + .with_thread_block(ThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_transfer(FwdTransfer_4x64x1) - .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_transfer(Transfer_4x64x1) + .with_fwd_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v3_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp index 11c81725330..9c4b9d4ec0a 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp @@ -27,10 +27,10 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} - 
.with_thread_block(FwdThreadBlock_256_128x128x32) + .with_thread_block(ThreadBlock_256_128x128x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_transfer(FwdTransfer_4x64x1) - .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, + .with_transfer(Transfer_4x64x1) + .with_fwd_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v4_intrawave); diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp index 33c01c8ac42..ed0ec0e3c1d 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp @@ -27,10 +27,10 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} - .with_thread_block(FwdThreadBlock_256_256x256x32) + .with_thread_block(ThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_transfer(FwdTransfer_4x64x1) - .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, + .with_transfer(Transfer_4x64x1) + .with_fwd_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v1_intrawave); diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 6b87ae77d6f..3e861db14c0 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -381,7 +381,7 @@ struct ConvAlgorithmTemplate : Components... return result; } - constexpr auto with_specializations(ConvFwdSpecialization fwd_spec, + constexpr auto with_fwd_specializations(ConvSpecialization fwd_spec, GemmSpecialization gemm_spec) const { static_assert(std::is_base_of_v); @@ -391,7 +391,7 @@ struct ConvAlgorithmTemplate : Components... 
return result; } - constexpr auto with_specializations(ConvBwdWeightSpecialization bwd_spec) const + constexpr auto with_bwd_specialization(ConvSpecialization bwd_spec) const { static_assert(std::is_base_of_v); auto result = *this; diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index 403c2ffd79c..f7c748b7c7a 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -39,7 +39,7 @@ constexpr DlTransferABC DlFwdTransfer{.a = .dst_scalar_per_vector = 4}, }}; -constexpr TransferABC FwdTransfer_4x64x1{ +constexpr TransferABC Transfer_4x64x1{ .a = { .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, @@ -72,7 +72,7 @@ constexpr TransferABC FwdTransfer_4x64x1{ }, }; -constexpr TransferABC FwdTransfer_4x64x1_fp8{ +constexpr TransferABC Transfer_4x64x1_fp8{ .a = { .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, @@ -105,7 +105,7 @@ constexpr TransferABC FwdTransfer_4x64x1_fp8{ }, }; -constexpr TransferABC FwdTransfer_4x16x1{ +constexpr TransferABC Transfer_4x16x1{ .a = { .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1}, @@ -139,7 +139,7 @@ constexpr TransferABC FwdTransfer_4x16x1{ }, }; -constexpr TransferABC FwdTransfer_4x32x1{ +constexpr TransferABC Transfer_4x32x1{ .a = { .block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1}, @@ -172,17 +172,25 @@ constexpr TransferABC FwdTransfer_4x32x1{ }, }; -constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x4_per_wave{ - .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}; +constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_4x4_per_wave{ + .k0_per_block = 8, .k1 = 8, + {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}}; -constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x2_per_wave{ - .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2}; +constexpr GridwiseFwdXdlGemm 
FwdGemmParams_Xdl_4x4_per_wave{ + .ak1 = 8, .bk1 = 8, + {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}}; -constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x2_per_wave{ - .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}; +constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_4x2_per_wave{ + .ak1 = 8, .bk1 = 8, + {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2}}; -constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x1_per_wave{ - .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1}; +constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_2x2_per_wave{ + .ak1 = 8, .bk1 = 8, + {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}}; + +constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_2x1_per_wave{ + .ak1 = 8, .bk1 = 8, + {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1}}; constexpr GridwiseWmmaGemm FwdGemmParams_Wmma_2x1_per_wave{.k1 = 8, .m_per_wmma = 32, @@ -191,26 +199,26 @@ constexpr GridwiseWmmaGemm FwdGemmParams_Wmma_2x1_per_wave{.k1 = 8 .n_wmma_per_wave = 1, .pipeline_version = PipelineVersion::V1}; -constexpr ThreadBlock FwdThreadBlock_256_256x256x32{.block_size = 256, - .tile_size = {.m = 256, .n = 256, .k = 32}}; +constexpr ThreadBlock ThreadBlock_256_256x256x32{.block_size = 256, + .tile_size = {.m = 256, .n = 256, .k = 32}}; -constexpr ThreadBlock FwdThreadBlock_256_256x128x32{.block_size = 256, - .tile_size = {.m = 256, .n = 128, .k = 32}}; +constexpr ThreadBlock ThreadBlock_256_256x128x32{.block_size = 256, + .tile_size = {.m = 256, .n = 128, .k = 32}}; -constexpr ThreadBlock FwdThreadBlock_256_128x128x32{.block_size = 256, - .tile_size = {.m = 128, .n = 128, .k = 32}}; +constexpr ThreadBlock ThreadBlock_256_128x128x32{.block_size = 256, + .tile_size = {.m = 128, .n = 128, .k = 32}}; -constexpr ThreadBlock FwdThreadBlock_256_128x128x16{.block_size = 256, - .tile_size = {.m = 128, .n = 128, 
.k = 16}}; +constexpr ThreadBlock ThreadBlock_256_128x128x16{.block_size = 256, + .tile_size = {.m = 128, .n = 128, .k = 16}}; -constexpr ThreadBlock FwdThreadBlock_64_64x32x32{.block_size = 64, - .tile_size = {.m = 64, .n = 32, .k = 32}}; +constexpr ThreadBlock ThreadBlock_64_64x32x32{.block_size = 64, + .tile_size = {.m = 64, .n = 32, .k = 32}}; -constexpr ThreadBlock FwdThreadBlock_128_128x128x32{.block_size = 128, - .tile_size = {.m = 128, .n = 128, .k = 32}}; +constexpr ThreadBlock ThreadBlock_128_128x128x32{.block_size = 128, + .tile_size = {.m = 128, .n = 128, .k = 32}}; -constexpr ThreadBlock FwdThreadBlock_128_64x64x64{.block_size = 128, - .tile_size = {.m = 64, .n = 64, .k = 64}}; +constexpr ThreadBlock ThreadBlock_128_64x64x64{.block_size = 128, + .tile_size = {.m = 64, .n = 64, .k = 64}}; constexpr BlockGemm BlockGemmDesc_v1_intrawave = {.pipeline_version = PipelineVersion::V1, .scheduler = PipelineScheduler::INTRAWAVE}; From 1df8077528ce9aaa6a4f7f69525fa43057a725ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Fri, 19 Dec 2025 10:38:27 -0500 Subject: [PATCH 04/81] Add missing pieces to bwd weight factory. 
--- .../builder/conv_algorithm_concepts.hpp | 22 ++++-- .../factory/conv_bwd_weight_xdl_factory.hpp | 3 +- .../builder/factory/conv_dispatcher.hpp | 11 ++- .../helpers/ck/conv_elementwise_op.hpp | 33 ++++---- .../factory/helpers/ck/conv_tensor_layout.hpp | 47 +++++------ .../factory/helpers/ck/conv_tensor_type.hpp | 4 +- .../builder/reflect/conv_description.hpp | 5 +- .../ck_tile/builder/reflect/conv_traits.hpp | 10 +-- .../builder/include/ck_tile/builder/types.hpp | 66 +--------------- .../test/conv/ck/test_ckb_conv_bwd_weight.cpp | 79 +++++++++---------- .../test/impl/conv_algorithm_types.hpp | 10 ++- .../test/utils/ckb_conv_test_configs.hpp | 10 +-- .../test/utils/conv_algorithm_type_utils.hpp | 49 ++++++++++-- 13 files changed, 160 insertions(+), 189 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index fcc9a09c6d0..ddd8a09ec77 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -158,17 +158,17 @@ concept SpecifiesTileThreadBlock = requires { // Concept to check if a struct specifies gridwise XDL GEMM info. template concept SpecifiesGridwiseFwdXdlGemm = requires { - { T::gridwise_gemm.ak1 } -> std::convertible_to; - { T::gridwise_gemm.bk1 } -> std::convertible_to; - { T::gridwise_gemm } -> GridwiseXdlGemmDescriptor; + { T::ak1 } -> std::convertible_to; + { T::bk1 } -> std::convertible_to; + { T::xdl_params } -> GridwiseXdlGemmDescriptor; }; // Concept to check if a struct specifies gridwise XDL GEMM info. 
template concept SpecifiesGridwiseBwdXdlGemm = requires { - { T::gridwise_gemm.k0_per_block } -> std::convertible_to; - { T::gridwise_gemm.k1 } -> std::convertible_to; - { T::gridwise_gemm } -> GridwiseXdlGemmDescriptor; + { T::k0_per_block } -> std::convertible_to; + { T::k1 } -> std::convertible_to; + { T::xdl_params } -> GridwiseXdlGemmDescriptor; }; // Concept to check if a struct specifies gridwise WMMA GEMM info. @@ -252,12 +252,12 @@ concept SpecifiesTileConvSpecialization = requires { template concept SpecifiesFwdConvSpecialization = requires { - { T::fwd_specialization } -> std::convertible_to; + { T::fwd_specialization } -> std::convertible_to; }; template concept SpecifiesBwdWeightConvSpecialization = requires { - { T::bwd_weight_specialization } -> std::convertible_to; + { T::bwd_weight_specialization } -> std::convertible_to; }; template @@ -286,6 +286,12 @@ concept SpecifiesLargeTensorSupport = requires { requires T::specialization == ConvAlgorithmSpecialization::LARGE_TENSOR; }; +template +concept SpecifiesTransposeTransfer = requires { + { T::max_transpose_transfer_src_scalar_per_vector } -> std::convertible_to; + { T::max_transpose_transfer_dst_scalar_per_vector } -> std::convertible_to; +}; + /******************************************** */ /* DL-specific descriptors and requirements */ /******************************************** */ diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp index 98893560922..0f726fe67d9 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp @@ -27,7 +27,7 @@ struct ConvBwdWeightXdlFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = internal::ConvTensorLayouts; - using Types = internal::BwdConvTensorDataTypes; + using 
Types = internal::BwdWeightConvTensorDataTypes; using Ops = internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); @@ -68,7 +68,6 @@ struct ConvBwdWeightXdlFactory BLOCK.block_size, BLOCK.per_block.m, BLOCK.per_block.n, - BLOCK.per_block.k, GRIDWISE_GEMM.k0_per_block, GRIDWISE_GEMM.k1, GRIDWISE_GEMM.m_per_xdl, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index b18d54f4891..db2f4cc3292 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -60,7 +60,7 @@ #include "ck_tile/builder/factory/conv_fwd_dl_factory.hpp" #include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp" #include "ck_tile/builder/factory/conv_tile_factory.hpp" -#include "ck_tile/builder/factory/conv_bwd_weigth_xdl_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp" namespace ck_tile::builder::factory { @@ -133,7 +133,7 @@ concept IsFwdDlAlgorithm = // XDL-based kernel with large tensor support template concept IsLargeTensorAlgorithm = - IsXdlAlgorithm && SpecifiesLargeTensorSupport; + IsFwdXdlAlgorithm && SpecifiesLargeTensorSupport; template ) { @@ -195,8 +194,8 @@ constexpr auto make_conv_instance() { static_assert( false, - "Backward weight convolution is not yet supported. " - "Only forward convolution (ConvDirection::FORWARD) is currently implemented."); + "No suitable forward convolution kernel factory found for the provided ALGORITHM. 
" + "The ALGORITHM must satisfy requirements for XDL variant."); } } else diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp index b24344c90a4..00205d414ec 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp @@ -62,27 +62,30 @@ consteval auto GetElementwiseOp() } template -requires ConvDirectionIsForward struct ElementwiseOps { +private: static constexpr auto input_op = GetElementwiseOp(); static constexpr auto weight_op = GetElementwiseOp(); static constexpr auto output_op = GetElementwiseOp(); - using AElementwiseOp = typename decltype(input_op)::Op; - using BElementwiseOp = typename decltype(weight_op)::Op; - using CDEElementwiseOp = typename decltype(output_op)::Op; -}; -template -requires ConvDirectionIsBackwardWeight -struct ElementwiseOps -{ - static constexpr auto input_op = GetElementwiseOp(); - static constexpr auto weight_op = GetElementwiseOp(); - static constexpr auto output_op = GetElementwiseOp(); - using InElementwiseOp = typename decltype(input_op)::Op; - using WeiElementwiseOp = typename decltype(weight_op)::Op; - using OutElementwiseOp = typename decltype(output_op)::Op; + static constexpr bool is_forward = ConvDirectionIsForward; + static constexpr bool is_bwd_weight = ConvDirectionIsBackwardWeight; + + using InputOp = typename decltype(input_op)::Op; + using WeightOp = typename decltype(weight_op)::Op; + using OutputOp = typename decltype(output_op)::Op; + +public: + // Forward convolution elementwise ops + using AElementwiseOp = std::conditional_t; + using BElementwiseOp = std::conditional_t; + using CDEElementwiseOp = std::conditional_t; + + // Backward weight convolution elementwise ops + using InElementwiseOp = std::conditional_t; + using WeiElementwiseOp = 
std::conditional_t; + using OutElementwiseOp = std::conditional_t; }; } // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp index 3df3f8f37cd..22c026c28f7 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp @@ -190,7 +190,7 @@ consteval auto GetAuxiliaryTensorLayoutTuple(std::index_sequence) decltype(TensorLayoutToCK())...>{}; } -template +template requires(ConvSpatialDim) struct AuxiliaryTensorLayouts { @@ -200,13 +200,12 @@ struct AuxiliaryTensorLayouts }; // TODO: Currently only the ouput tensor can have auxiliary tensors (e.g., bias). -template +template requires(HasElementwiseOpWithAuxiliaryOperands) consteval auto GetAuxiliaryTensorLayouts() { return AuxiliaryTensorLayouts{}; + SPATIAL_DIM>{}; } template @@ -220,27 +219,29 @@ template requires(ConvSpatialDim && ValidConvInputLayoutForSpatialDim && ValidConvWeightLayoutForSpatialDim && - ValidConvOutputLayoutForSpatialDim && - ConvDirectionIsForward) -struct ConvTensorLayouts -{ - using ALayout = decltype(TensorLayoutToCK()); - using BLayout = decltype(TensorLayoutToCK()); - using ELayout = decltype(TensorLayoutToCK()); - using DsLayout = decltype(GetAuxiliaryTensorLayouts())::type; -}; - -template - requires(ConvSpatialDim && - ValidConvInputLayoutForSpatialDim && - ValidConvWeightLayoutForSpatialDim && - ValidConvOutputLayoutForSpatialDim && - ConvDirectionIsBackwardWeight) + ValidConvOutputLayoutForSpatialDim) struct ConvTensorLayouts { - using InLayout = decltype(TensorLayoutToCK()); - using WeiLayout = decltype(TensorLayoutToCK()); - using OutLayout = decltype(TensorLayoutToCK()); +private: + static constexpr bool is_forward = ConvDirectionIsForward; + static constexpr bool 
is_bwd_weight = ConvDirectionIsBackwardWeight; + + using InputLayout = decltype(TensorLayoutToCK()); + using WeightLayout = decltype(TensorLayoutToCK()); + using OutputLayout = decltype(TensorLayoutToCK()); + using AuxLayout = decltype(GetAuxiliaryTensorLayouts())::type; + +public: + // Forward convolution layouts + using ALayout = std::conditional_t; + using BLayout = std::conditional_t; + using ELayout = std::conditional_t; + using DsLayout = std::conditional_t; + + // Backward weight convolution layouts + using InLayout = std::conditional_t; + using WeiLayout = std::conditional_t; + using OutLayout = std::conditional_t; }; } // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp index 7839cb7f4a6..d4a470dcedf 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp @@ -151,7 +151,6 @@ consteval auto GetAuxiliaryTensorDataTypes() } template -requires ConvDirectionIsForward struct FwdConvTensorDataTypes { static constexpr auto input_types = @@ -178,8 +177,7 @@ struct FwdConvTensorDataTypes }; template -requires ConvDirectionIsBackwardWeight -struct FwdConvTensorDataTypes +struct BwdWeightConvTensorDataTypes { static constexpr auto input_types = GetTensorDataAndComputeTypes(); diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp index 46c9bb488e6..a7b6c60a73e 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp @@ -63,10 +63,7 @@ struct GemmAlgorithmInfo OutputTileTransferInfo c_tile_transfer; builder::PipelineVersion pipeline_version; 
builder::PipelineScheduler pipeline_scheduler; - std::variant - conv_specialization; + builder::ConvSpecialization conv_specialization; builder::GemmPadding padding; }; diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index a91abd1a46a..8caa11618ec 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -197,18 +197,16 @@ constexpr builder::ConvDirection conv_direction() /// @brief Derives the convolution-specific specialization from a device kernel `Instance` type. /// @tparam Instance The device kernel instance type. -/// @return A `builder::ConvFwdSpecialization`, `builder::ConvBwdDataSpecialization`, or -/// `builder::ConvBwdWeightSpecialization` enum value. +/// @return A `builder::ConvSpecialization` enum value. template constexpr auto conv_spec() { using InstTraits = InstanceTraits; + using enum builder::ConvSpecialization; if constexpr(requires { InstTraits::kConvForwardSpecialization; }) { using enum ck::tensor_operation::device::ConvolutionForwardSpecialization; - using enum builder::ConvFwdSpecialization; - switch(InstTraits::kConvForwardSpecialization) { case Default: return DEFAULT; @@ -221,8 +219,6 @@ constexpr auto conv_spec() else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; }) { using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization; - using enum builder::ConvBwdDataSpecialization; - switch(InstTraits::kConvBwdDataSpecialization) { case Default: return DEFAULT; @@ -232,8 +228,6 @@ constexpr auto conv_spec() else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; }) { using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization; - using enum builder::ConvBwdWeightSpecialization; - switch(InstTraits::kConvBwdWeightSpecialization) { case Default: return DEFAULT; diff 
--git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index 5f08b5ab9cd..ade9484640f 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -202,22 +202,6 @@ enum class ConvSpecialization ODD_C }; -// Enums for the backward data convolution specialization. -enum class ConvBwdDataSpecialization -{ - DEFAULT, - FILTER_1X1_STRIDE1_PAD0, -}; - -// Enums for the backward weight convolution specialization. -enum class ConvBwdWeightSpecialization -{ - DEFAULT, - FILTER_1X1_STRIDE1_PAD0, - FILTER_1X1_PAD0, - ODD_C, -}; - // Enums for the Gemm padding. enum class GemmPadding { @@ -371,9 +355,9 @@ inline std::string_view toString(GemmSpecialization spec) } } -inline std::string_view toString(ConvFwdSpecialization spec) +inline std::string_view toString(ConvSpecialization spec) { - using enum ConvFwdSpecialization; + using enum ConvSpecialization; switch(spec) { case DEFAULT: return "DEFAULT"; @@ -385,30 +369,6 @@ inline std::string_view toString(ConvFwdSpecialization spec) } } -inline std::string_view toString(ConvBwdDataSpecialization spec) -{ - using enum ConvBwdDataSpecialization; - switch(spec) - { - case DEFAULT: return "DEFAULT"; - case FILTER_1X1_STRIDE1_PAD0: return "FILTER_1X1_STRIDE1_PAD0"; - default: return "Unknown"; - } -} - -inline std::string_view toString(ConvBwdWeightSpecialization spec) -{ - using enum ConvBwdWeightSpecialization; - switch(spec) - { - case DEFAULT: return "DEFAULT"; - case FILTER_1X1_STRIDE1_PAD0: return "FILTER_1X1_STRIDE1_PAD0"; - case FILTER_1X1_PAD0: return "FILTER_1X1_PAD0"; - case ODD_C: return "ODD_C"; - default: return "Unknown"; - } -} - inline std::string_view toString(GemmPadding padding) { using enum GemmPadding; @@ -521,17 +481,7 @@ inline std::ostream& operator<<(std::ostream& os, GemmSpecialization spec) return os << toString(spec); } -inline std::ostream& operator<<(std::ostream& os, 
ConvFwdSpecialization spec) -{ - return os << toString(spec); -} - -inline std::ostream& operator<<(std::ostream& os, ConvBwdDataSpecialization spec) -{ - return os << toString(spec); -} - -inline std::ostream& operator<<(std::ostream& os, ConvBwdWeightSpecialization spec) +inline std::ostream& operator<<(std::ostream& os, ConvSpecialization spec) { return os << toString(spec); } @@ -551,14 +501,4 @@ inline std::ostream& operator<<(std::ostream& os, TensorLayout layout) return os << toString(layout); } -// ostream operator overload for std::variant of convolution specializations -inline std::ostream& operator<<(std::ostream& os, - const std::variant& spec) -{ - std::visit([&os](const auto& s) { os << s; }, spec); - return os; -} - } // namespace ck_tile::builder diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp index f512f094439..366c2b27514 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp @@ -4,7 +4,7 @@ #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" #include "utils/conv_algorithm_type_utils.hpp" -#include "ck_tile/builder/testing/conv_bwd_ck.hpp" +//#include "ck_tile/builder/testing/conv_bwd_ck.hpp" #include "ck_tile/host/device_prop.hpp" namespace ckb = ck_tile::builder; @@ -24,8 +24,7 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CSh .with_thread_block(cku::ThreadBlock_256_256x256x32) .with_gemm_config(cku::BwdGemmParams_Xdl_4x4_per_wave) .with_transfer(cku::Transfer_4x64x1) - .with_bwd_specialization(ckb::ConvFwdSpecialization::DEFAULT) - .with_block_gemm(cku::BlockGemmDesc_v3_intrawave); + .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT); using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; @@ -43,43 +42,43 @@ TEST(BwdWeight_2DFp16_CShufV3_GNHWC, Create) "MNKPadding"}); } 
-TEST(BwdWeight_2DFp16_CShufV3_GNHWC, EndToEnd) -{ - if(!ck_tile::get_device_name().starts_with("gfx9")) - { - GTEST_SKIP() << "unsupported architecture"; - } +// TEST(BwdWeight_2DFp16_CShufV3_GNHWC, EndToEnd) +// { +// if(!ck_tile::get_device_name().starts_with("gfx9")) +// { +// GTEST_SKIP() << "unsupported architecture"; +// } - ckt::Args args = { - .lengths = - { - .batch_size = 16, - .groups = 1, - .input_channels = 32, - .output_channels = 48, - .image = - { - .width = 56, - .height = 64, - }, - .filter = - { - .width = 3, - .height = 5, - }, - }, - .filter_strides = {.width = 1, .height = 1}, - .filter_dilation = {.width = 1, .height = 1}, - .input_left_pad = {.width = 0, .height = 0}, - .input_right_pad = {.width = 0, .height = 0}, - .a_elementwise_op = {}, - .b_elementwise_op = {}, - .cde_elementwise_op = {}, - }; +// ckt::Args args = { +// .lengths = +// { +// .batch_size = 16, +// .groups = 1, +// .input_channels = 32, +// .output_channels = 48, +// .image = +// { +// .width = 56, +// .height = 64, +// }, +// .filter = +// { +// .width = 3, +// .height = 5, +// }, +// }, +// .filter_strides = {.width = 1, .height = 1}, +// .filter_dilation = {.width = 1, .height = 1}, +// .input_left_pad = {.width = 0, .height = 0}, +// .input_right_pad = {.width = 0, .height = 0}, +// .a_elementwise_op = {}, +// .b_elementwise_op = {}, +// .cde_elementwise_op = {}, +// }; - auto inputs = alloc_inputs(args); - auto outputs = alloc_outputs(args); +// auto inputs = alloc_inputs(args); +// auto outputs = alloc_outputs(args); - auto conv = Instance{}; - ckt::run(conv, args, inputs.get(), outputs.get()); -} +// auto conv = Instance{}; +// ckt::run(conv, args, inputs.get(), outputs.get()); +// } diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 3e861db14c0..46babfa79a1 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ 
b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -38,20 +38,22 @@ struct XdlParams static_assert(ckb::GridwiseXdlGemmDescriptor); // Describe gridwise XDL GEMM parameters. -struct GridwiseFwdXdlGemm : public XdlParams +struct GridwiseFwdXdlGemm { // NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!! size_t ak1 = 0; size_t bk1 = 0; + XdlParams xdl_params; }; static_assert(ckb::SpecifiesGridwiseFwdXdlGemm); -struct GridwiseBwdXdlGemm : public XdlParams +struct GridwiseBwdXdlGemm { size_t k0_per_block = 0; size_t k1 = 0; + XdlParams xdl_params; }; -static_assert(ckb::SpecifiesGridwiseBwdXdlGemm); +static_assert(ckb::SpecifiesGridwiseBwdXdlGemm); // Describe gridwise WMMA GEMM parameters. struct GridwiseWmmaGemm @@ -384,7 +386,7 @@ struct ConvAlgorithmTemplate : Components... constexpr auto with_fwd_specializations(ConvSpecialization fwd_spec, GemmSpecialization gemm_spec) const { - static_assert(std::is_base_of_v); + static_assert(std::is_base_of_v); auto result = *this; result.fwd_specialization = fwd_spec; result.gemm_specialization = gemm_spec; diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index f7c748b7c7a..8e5963b22ea 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -174,23 +174,23 @@ constexpr TransferABC Transfer_4x32x1{ constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_4x4_per_wave{ .k0_per_block = 8, .k1 = 8, - {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}}; + .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}}; constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_4x4_per_wave{ .ak1 = 8, .bk1 = 8, - {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}}; + .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}}; constexpr 
GridwiseFwdXdlGemm FwdGemmParams_Xdl_4x2_per_wave{ .ak1 = 8, .bk1 = 8, - {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2}}; + .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2}}; constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_2x2_per_wave{ .ak1 = 8, .bk1 = 8, - {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}}; + .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}}; constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_2x1_per_wave{ .ak1 = 8, .bk1 = 8, - {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1}}; + .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1}}; constexpr GridwiseWmmaGemm FwdGemmParams_Wmma_2x1_per_wave{.k1 = 8, .m_per_wmma = 32, diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index e4db149a988..6c1d9ae15f9 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -54,7 +54,7 @@ inline std::string to_string(PipelineScheduler t) } template <> -inline std::string to_string(ConvFwdSpecialization t) +inline std::string to_string(ConvSpecialization t) { std::ostringstream oss; oss << t; @@ -86,11 +86,20 @@ inline std::string to_string(ThreadBlock t) } template <> -inline std::string to_string(GridwiseXdlGemm t) +inline std::string to_string(GridwiseBwdXdlGemm t) { std::ostringstream oss; - oss << t.ak1 << "," << t.bk1 << "," << t.m_per_xdl << "," << t.n_per_xdl << "," - << t.m_xdl_per_wave << "," << t.n_xdl_per_wave; + oss << t.k0_per_block << "," << t.k1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl << "," + << t.xdl_params.m_xdl_per_wave << "," << t.xdl_params.n_xdl_per_wave; + return oss.str(); +} + +template <> +inline std::string to_string(GridwiseFwdXdlGemm t) +{ 
+ std::ostringstream oss; + oss << t.ak1 << "," << t.bk1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl << "," + << t.xdl_params.m_xdl_per_wave << "," << t.xdl_params.n_xdl_per_wave; return oss.str(); } @@ -234,7 +243,13 @@ inline std::string to_string(ThreadBlock_ t) } template <> -inline std::string to_string(XdlGemm_ t) +inline std::string to_string(FwdXdlGemm_ t) +{ + return to_string(t.gridwise_gemm); +} + +template <> +inline std::string to_string(BwdXdlGemm_ t) { return to_string(t.gridwise_gemm); } @@ -252,13 +267,21 @@ inline std::string to_string(Transfer_ t) } template <> -inline std::string to_string(ConvSpecialization_ t) +inline std::string to_string(ConvSpecializationFwd_ t) { std::ostringstream oss; oss << to_string(t.fwd_specialization) << "," << to_string(t.gemm_specialization); return oss.str(); } +template <> +inline std::string to_string(ConvSpecializationBwdWeight_ t) +{ + std::ostringstream oss; + oss << to_string(t.bwd_specialization); + return oss.str(); +} + template <> inline std::string to_string(Prefetch_ t) { @@ -299,7 +322,7 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) << "," << to_string(static_cast(t)); return oss.str(); } @@ -309,7 +332,7 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) << "," << to_string(static_cast(t)); return oss.str(); } @@ -343,4 +366,14 @@ inline std::string to_string +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast(t)); + return oss.str(); +} + } // namespace ck_tile::builder::test From 4d5b5b7ef38280264479729cb80e7323f3b833c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Mon, 22 Dec 2025 07:12:46 
-0500 Subject: [PATCH 05/81] Improve compile time erros message when no matching factory is found. --- .../builder/factory/conv_algorithms.hpp | 318 ++++++++++++++++++ .../builder/factory/conv_dispatcher.hpp | 80 +---- experimental/builder/test/CMakeLists.txt | 1 + 3 files changed, 331 insertions(+), 68 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp new file mode 100644 index 00000000000..a192a34df15 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -0,0 +1,318 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/conv_algorithm_concepts.hpp" + +namespace ck_tile::builder::factory { + +#define CHECK_MARK(cond) (cond ? "[✓]" : "[✗]") + +template +struct FwdXdlV3Algorithm { + static constexpr bool c1 = ConvAlgorithmDescriptor; + static constexpr bool c2 = SpecifiesThreadBlock; + static constexpr bool c3 = SpecifiesBlockTransfer; + static constexpr bool c4 = SpecifiesLdsTransfer; + static constexpr bool c5 = SpecifiesThreadClusterAccessOrder; + static constexpr bool c6 = SpecifiesSourceAccessOrder; + static constexpr bool c7 = SpecifiesGridwiseFwdXdlGemm; + static constexpr bool c8 = SpecifiesFwdConvSpecialization; + static constexpr bool c9 = SpecifiesGemmSpecialization; + static constexpr bool c10 = SpecifiesBlockGemm; + + static consteval bool is_valid() { + return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10; + } + + static consteval const std::string message() { + return "\n=== Forward XDL V3 Algorithm Diagnostic (closest match) ===\n" + "Concepts for FwdXdlV3 Algorithm:\n" + " ConvAlgorithmDescriptor: " + std::string(CHECK_MARK(c1)) + "\n" + " SpecifiesThreadBlock: " + std::string(CHECK_MARK(c2)) 
+ "\n" + " SpecifiesBlockTransfer: " + std::string(CHECK_MARK(c3)) + "\n" + " SpecifiesLdsTransfer: " + std::string(CHECK_MARK(c4)) + "\n" + " SpecifiesThreadClusterAccessOrder: " + std::string(CHECK_MARK(c5)) + "\n" + " SpecifiesSourceAccessOrder: " + std::string(CHECK_MARK(c6)) + "\n" + " SpecifiesGridwiseFwdXdlGemm: " + std::string(CHECK_MARK(c7)) + "\n" + " SpecifiesFwdConvSpecialization: " + std::string(CHECK_MARK(c8)) + "\n" + " SpecifiesGemmSpecialization: " + std::string(CHECK_MARK(c9)) + "\n" + " SpecifiesBlockGemm: " + std::string(CHECK_MARK(c10)) + "\n"; + } +}; + +template +struct FwdXdlAlgorithm { + static constexpr bool c1 = ConvAlgorithmDescriptor; + static constexpr bool c2 = SpecifiesThreadBlock; + static constexpr bool c3 = SpecifiesBlockTransfer; + static constexpr bool c4 = SpecifiesLdsTransfer; + static constexpr bool c5 = SpecifiesThreadClusterAccessOrder; + static constexpr bool c6 = SpecifiesSourceAccessOrder; + static constexpr bool c7 = SpecifiesGridwiseFwdXdlGemm; + static constexpr bool c8 = SpecifiesFwdConvSpecialization; + static constexpr bool c9 = SpecifiesGemmSpecialization; + static constexpr bool c10 = SpecifiesNumPrefetchStages; + static constexpr bool c11 = SpecifiesNumGroupsToMerge; + static constexpr bool c12 = SpecifiesLoopScheduler; + + static consteval bool is_valid() { + return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10 && c11 && c12; + } + + static consteval const std::string message() { + return "\n=== Forward XDL Algorithm Diagnostic (closest match) ===\n" + "Concepts for FwdXdl Algorithm:\n" + " ConvAlgorithmDescriptor: " + std::string(CHECK_MARK(c1)) + "\n" + " SpecifiesThreadBlock: " + std::string(CHECK_MARK(c2)) + "\n" + " SpecifiesBlockTransfer: " + std::string(CHECK_MARK(c3)) + "\n" + " SpecifiesLdsTransfer: " + std::string(CHECK_MARK(c4)) + "\n" + " SpecifiesThreadClusterAccessOrder: " + std::string(CHECK_MARK(c5)) + "\n" + " SpecifiesSourceAccessOrder: " + std::string(CHECK_MARK(c6)) + "\n" + " 
SpecifiesGridwiseFwdXdlGemm: " + std::string(CHECK_MARK(c7)) + "\n" + " SpecifiesFwdConvSpecialization: " + std::string(CHECK_MARK(c8)) + "\n" + " SpecifiesGemmSpecialization: " + std::string(CHECK_MARK(c9)) + "\n" + " SpecifiesNumPrefetchStages: " + std::string(CHECK_MARK(c10)) + "\n" + " SpecifiesNumGroupsToMerge: " + std::string(CHECK_MARK(c11)) + "\n" + " SpecifiesLoopScheduler: " + std::string(CHECK_MARK(c12)) + "\n"; + } +}; + +template +struct FwdWmmaAlgorithm { + static constexpr bool c1 = ConvAlgorithmDescriptor; + static constexpr bool c2 = SpecifiesThreadBlock; + static constexpr bool c3 = SpecifiesBlockTransfer; + static constexpr bool c4 = SpecifiesLdsTransfer; + static constexpr bool c5 = SpecifiesThreadClusterAccessOrder; + static constexpr bool c6 = SpecifiesSourceAccessOrder; + static constexpr bool c7 = SpecifiesGridwiseWmmaGemm; + static constexpr bool c8 = SpecifiesFwdConvSpecialization; + static constexpr bool c9 = SpecifiesGemmSpecialization; + static constexpr bool c10 = SpecifiesNumPrefetchStages; + static constexpr bool c11 = SpecifiesLoopScheduler; + + static consteval bool is_valid() { + return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10 && c11; + } + + static consteval const std::string message() { + return "\n=== Forward WMMA Algorithm Diagnostic (closest match) ===\n" + "Concepts for FwdWmma Algorithm:\n" + " ConvAlgorithmDescriptor: " + std::string(CHECK_MARK(c1)) + "\n" + " SpecifiesThreadBlock: " + std::string(CHECK_MARK(c2)) + "\n" + " SpecifiesBlockTransfer: " + std::string(CHECK_MARK(c3)) + "\n" + " SpecifiesLdsTransfer: " + std::string(CHECK_MARK(c4)) + "\n" + " SpecifiesThreadClusterAccessOrder: " + std::string(CHECK_MARK(c5)) + "\n" + " SpecifiesSourceAccessOrder: " + std::string(CHECK_MARK(c6)) + "\n" + " SpecifiesGridwiseWmmaGemm: " + std::string(CHECK_MARK(c7)) + "\n" + " SpecifiesFwdConvSpecialization: " + std::string(CHECK_MARK(c8)) + "\n" + " SpecifiesGemmSpecialization: " + std::string(CHECK_MARK(c9)) + 
"\n" + " SpecifiesNumPrefetchStages: " + std::string(CHECK_MARK(c10)) + "\n" + " SpecifiesLoopScheduler: " + std::string(CHECK_MARK(c11)) + "\n"; + } +}; + +template +struct FwdDlAlgorithm { + static constexpr bool c1 = ConvAlgorithmDescriptor; + static constexpr bool c2 = SpecifiesThreadBlock; + static constexpr bool c3 = SpecifiesFwdConvSpecialization; + static constexpr bool c4 = SpecifiesGemmSpecialization; + static constexpr bool c5 = SpecifiesDlThreadConfig; + static constexpr bool c6 = SpecifiesDlThreadCluster; + static constexpr bool c7 = SpecifiesDlBlockTransfer; + static constexpr bool c8 = SpecifiesDlEpilogue; + + static consteval bool is_valid() { + return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8; + } + + static consteval const std::string message() { + return "\n=== Forward DL Algorithm Diagnostic (closest match) ===\n" + "Concepts for FwdDl Algorithm:\n" + " ConvAlgorithmDescriptor: " + std::string(CHECK_MARK(c1)) + "\n" + " SpecifiesThreadBlock: " + std::string(CHECK_MARK(c2)) + "\n" + " SpecifiesFwdConvSpecialization: " + std::string(CHECK_MARK(c3)) + "\n" + " SpecifiesGemmSpecialization: " + std::string(CHECK_MARK(c4)) + "\n" + " SpecifiesDlThreadConfig: " + std::string(CHECK_MARK(c5)) + "\n" + " SpecifiesDlThreadCluster: " + std::string(CHECK_MARK(c6)) + "\n" + " SpecifiesDlBlockTransfer: " + std::string(CHECK_MARK(c7)) + "\n" + " SpecifiesDlEpilogue: " + std::string(CHECK_MARK(c8)) + "\n"; + } +}; + +template +struct TileAlgorithm { + static constexpr bool c1 = ConvAlgorithmDescriptor; + static constexpr bool c2 = SpecifiesTileThreadBlock; + static constexpr bool c3 = SpecifiesTileTransfer; + static constexpr bool c4 = SpecifiesTileConvSpecialization; + static constexpr bool c5 = SpecifiesTileBlockGemm; + static constexpr bool c6 = SpecifiesTileOptimizations; + + static consteval bool is_valid() { + return c1 && c2 && c3 && c4 && c5 && c6; + } + + static consteval const std::string message() { + return "\n=== CK Tile Algorithm Diagnostic 
(closest match) ===\n" + "Concepts for CK Tile Conv Algorithm:\n" + " ConvAlgorithmDescriptor: " + std::string(CHECK_MARK(c1)) + "\n" + " SpecifiesTileThreadBlock: " + std::string(CHECK_MARK(c2)) + "\n" + " SpecifiesTileTransfer: " + std::string(CHECK_MARK(c3)) + "\n" + " SpecifiesTileConvSpecialization: " + std::string(CHECK_MARK(c4)) + "\n" + " SpecifiesTileBlockGemm: " + std::string(CHECK_MARK(c5)) + "\n" + " SpecifiesTileOptimizations: " + std::string(CHECK_MARK(c6)) + "\n"; + } +}; + +template +struct LargeTensorAlgorithm : public FwdXdlAlgorithm +{ + using BaseAlgorithmType = decltype(T::base_algorithm); + static constexpr bool c13 = SpecifiesLargeTensorSupport; + + static consteval bool is_valid() { + return FwdXdlAlgorithm::is_valid() && c13; + } + + static consteval const std::string message() { + return FwdXdlAlgorithm::message() + + " SpecifiesLargeTensorSupport: " + std::string(CHECK_MARK(c13)) + "\n"; + } +}; + +template +struct BwdXdlAlgorithm { + static constexpr bool c1 = ConvAlgorithmDescriptor; + static constexpr bool c2 = SpecifiesThreadBlock; + static constexpr bool c3 = SpecifiesBlockTransfer; + static constexpr bool c4 = SpecifiesLdsTransfer; + static constexpr bool c5 = SpecifiesThreadClusterAccessOrder; + static constexpr bool c6 = SpecifiesSourceAccessOrder; + static constexpr bool c7 = SpecifiesGridwiseBwdXdlGemm; + static constexpr bool c8 = SpecifiesBwdWeightConvSpecialization; + static constexpr bool c9 = SpecifiesTransposeTransfer; + + static consteval bool is_valid() { + return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9; + } + + static consteval const std::string message() { + return "\n=== Backward XDL Algorithm Diagnostic (closest match) ===\n" + "Concepts for BwdXdl Algorithm:\n" + " ConvAlgorithmDescriptor: " + std::string(CHECK_MARK(c1)) + "\n" + " SpecifiesThreadBlock: " + std::string(CHECK_MARK(c2)) + "\n" + " SpecifiesBlockTransfer: " + std::string(CHECK_MARK(c3)) + "\n" + " SpecifiesLdsTransfer: " + 
std::string(CHECK_MARK(c4)) + "\n" + " SpecifiesThreadClusterAccessOrder: " + std::string(CHECK_MARK(c5)) + "\n" + " SpecifiesSourceAccessOrder: " + std::string(CHECK_MARK(c6)) + "\n" + " SpecifiesGridwiseBwdXdlGemm: " + std::string(CHECK_MARK(c7)) + "\n" + " SpecifiesBwdWeightConvSpecialization: " + std::string(CHECK_MARK(c8)) + "\n" + " SpecifiesTransposeTransfer: " + std::string(CHECK_MARK(c9)) + "\n"; + } +}; + +template +consteval int count_matches_fwd_xdl_v3() { + using Alg = FwdXdlV3Algorithm; + return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10; +} + +template +consteval int count_matches_fwd_xdl() { + using Alg = FwdXdlAlgorithm; + return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11 + Alg::c12; +} + +template +consteval int count_matches_fwd_wmma() { + using Alg = FwdWmmaAlgorithm; + return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11; +} + +template +consteval int count_matches_fwd_dl() { + using Alg = FwdDlAlgorithm; + return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8; +} + +template +consteval int count_matches_bwd_xdl() { + using Alg = BwdXdlAlgorithm; + return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9; +} + +template +consteval int count_matches_large_tensor() { + using Alg = LargeTensorAlgorithm; + return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11 + Alg::c12 + Alg::c13; +} + +template +consteval int count_matches_tile() { + using Alg = TileAlgorithm; + return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6; +} + +template +consteval void diagnose_fwd_algorithm_signature() +{ + // Find closest matching variant + constexpr int xdl_v3_matches = count_matches_fwd_xdl_v3(); + constexpr int xdl_matches = 
count_matches_fwd_xdl(); + constexpr int wmma_matches = count_matches_fwd_wmma(); + constexpr int dl_matches = count_matches_fwd_dl(); + constexpr int large_tensor_matches = count_matches_large_tensor(); + constexpr int tile_matches = count_matches_tile(); + + // Find maximum matches across all variants + constexpr int max_1 = xdl_v3_matches > xdl_matches ? xdl_v3_matches : xdl_matches; + constexpr int max_2 = wmma_matches > dl_matches ? wmma_matches : dl_matches; + constexpr int max_3 = max_1 > max_2 ? max_1 : max_2; + constexpr int max_4 = max_3 > large_tensor_matches ? max_3 : large_tensor_matches; + constexpr int max_matches = max_4 > tile_matches ? max_4 : tile_matches; + + // Generate detailed diagnostic for the closest match + if constexpr(max_matches == xdl_v3_matches) { + using Alg = FwdXdlV3Algorithm; + static_assert(Alg::is_valid(), Alg::message()); + } else if constexpr(max_matches == xdl_matches) { + using Alg = FwdXdlAlgorithm; + static_assert(Alg::is_valid(), Alg::message()); + } else if constexpr(max_matches == wmma_matches) { + using Alg = FwdWmmaAlgorithm; + static_assert(Alg::is_valid(), Alg::message()); + } else if constexpr(max_matches == dl_matches) { + using Alg = FwdDlAlgorithm; + static_assert(Alg::is_valid(), Alg::message()); + } else if constexpr (max_matches == large_tensor_matches) { + using Alg = LargeTensorAlgorithm; + static_assert(Alg::is_valid(), Alg::message()); + } else if constexpr (max_matches == tile_matches) { + using Alg = TileAlgorithm; + static_assert(Alg::is_valid(), Alg::message()); + } + else { + // This should never happen + static_assert(false, "Internal Error: No matching algorithm variant found for diagnostics."); + } +} + +template +consteval void diagnose_bwd_weight_algorithm_signature() +{ + constexpr int xdl_matches = count_matches_fwd_xdl(); + constexpr int max_matches = xdl_matches; + if constexpr (max_matches == xdl_matches) { + using Alg = BwdXdlAlgorithm; + static_assert(Alg::is_valid(), Alg::message()); + 
} else { + // This should never happen + static_assert(false, "Internal Error: No matching algorithm variant found for diagnostics."); + } +} + +} diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index db2f4cc3292..adbf12992e5 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -53,6 +53,9 @@ #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/types.hpp" +// Compile time diagnostics +#include "ck_tile/builder/factory/conv_algorithms.hpp" + // Include all factory implementations #include "ck_tile/builder/factory/conv_fwd_v3_factory.hpp" #include "ck_tile/builder/factory/conv_fwd_xdl_factory.hpp" @@ -83,58 +86,6 @@ namespace ck_tile::builder::factory { // // TODO: Make this dispatch logic much more robust and clear for users. -// CK Tile kernel -template -concept IsTileAlgorithm = ConvAlgorithmDescriptor && SpecifiesTileThreadBlock && - SpecifiesTileTransfer && SpecifiesTileConvSpecialization && - SpecifiesTileBlockGemm && SpecifiesTileOptimizations; - -template -concept SpecifiesDataTransfer = - SpecifiesThreadBlock && SpecifiesBlockTransfer && - SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder; - -// XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline) -template -concept IsFwdXdlV3Algorithm = ConvAlgorithmDescriptor && - SpecifiesDataTransfer && SpecifiesGridwiseFwdXdlGemm && - SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && - SpecifiesBlockGemm; - -// Standard XDL-based fwd kernel (uses XDLops hardware instructions for matrix multiply) -template -concept IsFwdXdlAlgorithm = ConvAlgorithmDescriptor && - SpecifiesDataTransfer && SpecifiesGridwiseFwdXdlGemm && - SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && - 
SpecifiesNumPrefetchStages && SpecifiesNumGroupsToMerge && - SpecifiesLoopScheduler; - -// Standard XDL-based bwd weight kernel (uses XDLops hardware instructions for matrix multiply) -template -concept IsBwdXdlAlgorithm = ConvAlgorithmDescriptor && - SpecifiesDataTransfer && SpecifiesGridwiseBwdXdlGemm && - SpecifiesBwdWeightConvSpecialization && SpecifiesTransposeTransfer; - -// WMMA-based kernel (uses Wavefront Matrix-Matrix Accumulate instructions) -template -concept IsFwdWmmaAlgorithm = ConvAlgorithmDescriptor && - SpecifiesDataTransfer && SpecifiesGridwiseWmmaGemm && - SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && - SpecifiesNumPrefetchStages && SpecifiesLoopScheduler; - -// Specialized DL kernel for specific NHWC/KYXC/NHWK data layouts -template -concept IsFwdDlAlgorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesFwdConvSpecialization && - SpecifiesGemmSpecialization && SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && - SpecifiesDlBlockTransfer && SpecifiesDlEpilogue; - -// XDL-based kernel with large tensor support -template -concept IsLargeTensorAlgorithm = - IsFwdXdlAlgorithm && SpecifiesLargeTensorSupport; - template @@ -143,39 +94,35 @@ constexpr auto make_conv_instance() using AlgoType = std::remove_const_t; // CK Tile supports common factory for each direction - if constexpr(IsTileAlgorithm) + if constexpr(TileAlgorithm::is_valid()) { return typename ConvTileFactory::Instance{}; } else if constexpr(ConvDirectionIsForward) { - if constexpr(IsFwdXdlV3Algorithm) + if constexpr(FwdXdlV3Algorithm::is_valid()) { return typename ConvFwdXdlV3Factory::Instance{}; } - else if constexpr(IsFwdXdlAlgorithm) + else if constexpr(FwdXdlAlgorithm::is_valid()) { return typename ConvFwdXdlFactory::Instance{}; } - else if constexpr(IsFwdWmmaAlgorithm) + else if constexpr(FwdWmmaAlgorithm::is_valid()) { return typename ConvFwdWmmaFactory::Instance{}; } - else if constexpr(IsFwdDlAlgorithm) + else if 
constexpr(FwdDlAlgorithm::is_valid()) { return typename ConvFwdDlFactory::Instance{}; } - else if constexpr(IsLargeTensorAlgorithm) + else if constexpr(LargeTensorAlgorithm::is_valid()) { return typename ConvFwdLargeTensorFactory::Instance{}; } else { - static_assert( - false, - "No suitable forward convolution kernel factory found for the provided ALGORITHM. " - "The ALGORITHM must satisfy requirements for one of: XDL V3, XDL, WMMA, DL (NHWC " - "layout), or Large Tensor variant."); + diagnose_fwd_algorithm_signature(); } } else if constexpr(ConvDirectionIsBackwardData) @@ -186,16 +133,13 @@ constexpr auto make_conv_instance() } else if constexpr(ConvDirectionIsBackwardWeight) { - if constexpr (IsBwdXdlAlgorithm) + if constexpr (BwdXdlAlgorithm::is_valid()) { return typename ConvBwdWeightXdlFactory::Instance{}; } else { - static_assert( - false, - "No suitable forward convolution kernel factory found for the provided ALGORITHM. " - "The ALGORITHM must satisfy requirements for XDL variant."); + diagnose_bwd_weight_algorithm_signature(); } } else diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 43a142a2798..d0f645e9059 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -54,6 +54,7 @@ function(add_ck_builder_test test_name) target_compile_options(${test_name} PRIVATE -Wno-global-constructors -Wno-c++20-compat + -Wno-c++26-extensions # Allow C++26 extensions for better compile-time diagnostics ) target_link_libraries(${test_name} PRIVATE GTest::gtest_main GTest::gmock) endfunction() From 4d20cc6b4df097af54d3faef1f7f12ceb2814ba4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Mon, 22 Dec 2025 07:36:13 -0500 Subject: [PATCH 06/81] Use amcro to ensure automatic macthing between concepts are their string representations. 
--- .../builder/factory/conv_algorithms.hpp | 327 +++++++++++------- 1 file changed, 200 insertions(+), 127 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index a192a34df15..f96fc2c86be 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -9,162 +9,223 @@ namespace ck_tile::builder::factory { #define CHECK_MARK(cond) (cond ? "[✓]" : "[✗]") +// Macro to check a concept and generate both the boolean and the string representation +#define CHECK_CONCEPT(Type, Concept) \ + static constexpr bool c_##Concept = Concept; \ + static constexpr const char* s_##Concept = #Concept; + +// Helper to create diagnostic message line +#define DIAGNOSTIC_LINE(Concept) \ + " " + std::string(s_##Concept) + ": " + std::string(CHECK_MARK(c_##Concept)) + "\n" + template struct FwdXdlV3Algorithm { - static constexpr bool c1 = ConvAlgorithmDescriptor; - static constexpr bool c2 = SpecifiesThreadBlock; - static constexpr bool c3 = SpecifiesBlockTransfer; - static constexpr bool c4 = SpecifiesLdsTransfer; - static constexpr bool c5 = SpecifiesThreadClusterAccessOrder; - static constexpr bool c6 = SpecifiesSourceAccessOrder; - static constexpr bool c7 = SpecifiesGridwiseFwdXdlGemm; - static constexpr bool c8 = SpecifiesFwdConvSpecialization; - static constexpr bool c9 = SpecifiesGemmSpecialization; - static constexpr bool c10 = SpecifiesBlockGemm; + CHECK_CONCEPT(T, ConvAlgorithmDescriptor) + CHECK_CONCEPT(T, SpecifiesThreadBlock) + CHECK_CONCEPT(T, SpecifiesBlockTransfer) + CHECK_CONCEPT(T, SpecifiesLdsTransfer) + CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder) + CHECK_CONCEPT(T, SpecifiesSourceAccessOrder) + CHECK_CONCEPT(T, SpecifiesGridwiseFwdXdlGemm) + CHECK_CONCEPT(T, SpecifiesFwdConvSpecialization) + CHECK_CONCEPT(T, SpecifiesGemmSpecialization) + 
CHECK_CONCEPT(T, SpecifiesBlockGemm) + + static constexpr bool c1 = c_ConvAlgorithmDescriptor; + static constexpr bool c2 = c_SpecifiesThreadBlock; + static constexpr bool c3 = c_SpecifiesBlockTransfer; + static constexpr bool c4 = c_SpecifiesLdsTransfer; + static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder; + static constexpr bool c6 = c_SpecifiesSourceAccessOrder; + static constexpr bool c7 = c_SpecifiesGridwiseFwdXdlGemm; + static constexpr bool c8 = c_SpecifiesFwdConvSpecialization; + static constexpr bool c9 = c_SpecifiesGemmSpecialization; + static constexpr bool c10 = c_SpecifiesBlockGemm; static consteval bool is_valid() { return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10; } static consteval const std::string message() { - return "\n=== Forward XDL V3 Algorithm Diagnostic (closest match) ===\n" - "Concepts for FwdXdlV3 Algorithm:\n" - " ConvAlgorithmDescriptor: " + std::string(CHECK_MARK(c1)) + "\n" - " SpecifiesThreadBlock: " + std::string(CHECK_MARK(c2)) + "\n" - " SpecifiesBlockTransfer: " + std::string(CHECK_MARK(c3)) + "\n" - " SpecifiesLdsTransfer: " + std::string(CHECK_MARK(c4)) + "\n" - " SpecifiesThreadClusterAccessOrder: " + std::string(CHECK_MARK(c5)) + "\n" - " SpecifiesSourceAccessOrder: " + std::string(CHECK_MARK(c6)) + "\n" - " SpecifiesGridwiseFwdXdlGemm: " + std::string(CHECK_MARK(c7)) + "\n" - " SpecifiesFwdConvSpecialization: " + std::string(CHECK_MARK(c8)) + "\n" - " SpecifiesGemmSpecialization: " + std::string(CHECK_MARK(c9)) + "\n" - " SpecifiesBlockGemm: " + std::string(CHECK_MARK(c10)) + "\n"; + return std::string("\n=== Forward XDL V3 Algorithm Diagnostic (closest match) ===\n" + "Concepts for FwdXdlV3 Algorithm:\n") + + DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + + DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + + DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + + DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + + DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + + 
DIAGNOSTIC_LINE(SpecifiesGridwiseFwdXdlGemm) + + DIAGNOSTIC_LINE(SpecifiesFwdConvSpecialization) + + DIAGNOSTIC_LINE(SpecifiesGemmSpecialization) + + DIAGNOSTIC_LINE(SpecifiesBlockGemm); } }; template struct FwdXdlAlgorithm { - static constexpr bool c1 = ConvAlgorithmDescriptor; - static constexpr bool c2 = SpecifiesThreadBlock; - static constexpr bool c3 = SpecifiesBlockTransfer; - static constexpr bool c4 = SpecifiesLdsTransfer; - static constexpr bool c5 = SpecifiesThreadClusterAccessOrder; - static constexpr bool c6 = SpecifiesSourceAccessOrder; - static constexpr bool c7 = SpecifiesGridwiseFwdXdlGemm; - static constexpr bool c8 = SpecifiesFwdConvSpecialization; - static constexpr bool c9 = SpecifiesGemmSpecialization; - static constexpr bool c10 = SpecifiesNumPrefetchStages; - static constexpr bool c11 = SpecifiesNumGroupsToMerge; - static constexpr bool c12 = SpecifiesLoopScheduler; + CHECK_CONCEPT(T, ConvAlgorithmDescriptor) + CHECK_CONCEPT(T, SpecifiesThreadBlock) + CHECK_CONCEPT(T, SpecifiesBlockTransfer) + CHECK_CONCEPT(T, SpecifiesLdsTransfer) + CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder) + CHECK_CONCEPT(T, SpecifiesSourceAccessOrder) + CHECK_CONCEPT(T, SpecifiesGridwiseFwdXdlGemm) + CHECK_CONCEPT(T, SpecifiesFwdConvSpecialization) + CHECK_CONCEPT(T, SpecifiesGemmSpecialization) + CHECK_CONCEPT(T, SpecifiesNumPrefetchStages) + CHECK_CONCEPT(T, SpecifiesNumGroupsToMerge) + CHECK_CONCEPT(T, SpecifiesLoopScheduler) + + static constexpr bool c1 = c_ConvAlgorithmDescriptor; + static constexpr bool c2 = c_SpecifiesThreadBlock; + static constexpr bool c3 = c_SpecifiesBlockTransfer; + static constexpr bool c4 = c_SpecifiesLdsTransfer; + static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder; + static constexpr bool c6 = c_SpecifiesSourceAccessOrder; + static constexpr bool c7 = c_SpecifiesGridwiseFwdXdlGemm; + static constexpr bool c8 = c_SpecifiesFwdConvSpecialization; + static constexpr bool c9 = c_SpecifiesGemmSpecialization; + static 
constexpr bool c10 = c_SpecifiesNumPrefetchStages; + static constexpr bool c11 = c_SpecifiesNumGroupsToMerge; + static constexpr bool c12 = c_SpecifiesLoopScheduler; static consteval bool is_valid() { return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10 && c11 && c12; } static consteval const std::string message() { - return "\n=== Forward XDL Algorithm Diagnostic (closest match) ===\n" - "Concepts for FwdXdl Algorithm:\n" - " ConvAlgorithmDescriptor: " + std::string(CHECK_MARK(c1)) + "\n" - " SpecifiesThreadBlock: " + std::string(CHECK_MARK(c2)) + "\n" - " SpecifiesBlockTransfer: " + std::string(CHECK_MARK(c3)) + "\n" - " SpecifiesLdsTransfer: " + std::string(CHECK_MARK(c4)) + "\n" - " SpecifiesThreadClusterAccessOrder: " + std::string(CHECK_MARK(c5)) + "\n" - " SpecifiesSourceAccessOrder: " + std::string(CHECK_MARK(c6)) + "\n" - " SpecifiesGridwiseFwdXdlGemm: " + std::string(CHECK_MARK(c7)) + "\n" - " SpecifiesFwdConvSpecialization: " + std::string(CHECK_MARK(c8)) + "\n" - " SpecifiesGemmSpecialization: " + std::string(CHECK_MARK(c9)) + "\n" - " SpecifiesNumPrefetchStages: " + std::string(CHECK_MARK(c10)) + "\n" - " SpecifiesNumGroupsToMerge: " + std::string(CHECK_MARK(c11)) + "\n" - " SpecifiesLoopScheduler: " + std::string(CHECK_MARK(c12)) + "\n"; + return std::string("\n=== Forward XDL Algorithm Diagnostic (closest match) ===\n" + "Concepts for FwdXdl Algorithm:\n") + + DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + + DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + + DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + + DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + + DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + + DIAGNOSTIC_LINE(SpecifiesGridwiseFwdXdlGemm) + + DIAGNOSTIC_LINE(SpecifiesFwdConvSpecialization) + + DIAGNOSTIC_LINE(SpecifiesGemmSpecialization) + + DIAGNOSTIC_LINE(SpecifiesNumPrefetchStages) + + DIAGNOSTIC_LINE(SpecifiesNumGroupsToMerge) + + DIAGNOSTIC_LINE(SpecifiesLoopScheduler); } }; template struct 
FwdWmmaAlgorithm { - static constexpr bool c1 = ConvAlgorithmDescriptor; - static constexpr bool c2 = SpecifiesThreadBlock; - static constexpr bool c3 = SpecifiesBlockTransfer; - static constexpr bool c4 = SpecifiesLdsTransfer; - static constexpr bool c5 = SpecifiesThreadClusterAccessOrder; - static constexpr bool c6 = SpecifiesSourceAccessOrder; - static constexpr bool c7 = SpecifiesGridwiseWmmaGemm; - static constexpr bool c8 = SpecifiesFwdConvSpecialization; - static constexpr bool c9 = SpecifiesGemmSpecialization; - static constexpr bool c10 = SpecifiesNumPrefetchStages; - static constexpr bool c11 = SpecifiesLoopScheduler; + CHECK_CONCEPT(T, ConvAlgorithmDescriptor) + CHECK_CONCEPT(T, SpecifiesThreadBlock) + CHECK_CONCEPT(T, SpecifiesBlockTransfer) + CHECK_CONCEPT(T, SpecifiesLdsTransfer) + CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder) + CHECK_CONCEPT(T, SpecifiesSourceAccessOrder) + CHECK_CONCEPT(T, SpecifiesGridwiseWmmaGemm) + CHECK_CONCEPT(T, SpecifiesFwdConvSpecialization) + CHECK_CONCEPT(T, SpecifiesGemmSpecialization) + CHECK_CONCEPT(T, SpecifiesNumPrefetchStages) + CHECK_CONCEPT(T, SpecifiesLoopScheduler) + + static constexpr bool c1 = c_ConvAlgorithmDescriptor; + static constexpr bool c2 = c_SpecifiesThreadBlock; + static constexpr bool c3 = c_SpecifiesBlockTransfer; + static constexpr bool c4 = c_SpecifiesLdsTransfer; + static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder; + static constexpr bool c6 = c_SpecifiesSourceAccessOrder; + static constexpr bool c7 = c_SpecifiesGridwiseWmmaGemm; + static constexpr bool c8 = c_SpecifiesFwdConvSpecialization; + static constexpr bool c9 = c_SpecifiesGemmSpecialization; + static constexpr bool c10 = c_SpecifiesNumPrefetchStages; + static constexpr bool c11 = c_SpecifiesLoopScheduler; static consteval bool is_valid() { return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10 && c11; } static consteval const std::string message() { - return "\n=== Forward WMMA Algorithm Diagnostic (closest 
match) ===\n" - "Concepts for FwdWmma Algorithm:\n" - " ConvAlgorithmDescriptor: " + std::string(CHECK_MARK(c1)) + "\n" - " SpecifiesThreadBlock: " + std::string(CHECK_MARK(c2)) + "\n" - " SpecifiesBlockTransfer: " + std::string(CHECK_MARK(c3)) + "\n" - " SpecifiesLdsTransfer: " + std::string(CHECK_MARK(c4)) + "\n" - " SpecifiesThreadClusterAccessOrder: " + std::string(CHECK_MARK(c5)) + "\n" - " SpecifiesSourceAccessOrder: " + std::string(CHECK_MARK(c6)) + "\n" - " SpecifiesGridwiseWmmaGemm: " + std::string(CHECK_MARK(c7)) + "\n" - " SpecifiesFwdConvSpecialization: " + std::string(CHECK_MARK(c8)) + "\n" - " SpecifiesGemmSpecialization: " + std::string(CHECK_MARK(c9)) + "\n" - " SpecifiesNumPrefetchStages: " + std::string(CHECK_MARK(c10)) + "\n" - " SpecifiesLoopScheduler: " + std::string(CHECK_MARK(c11)) + "\n"; + return std::string("\n=== Forward WMMA Algorithm Diagnostic (closest match) ===\n" + "Concepts for FwdWmma Algorithm:\n") + + DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + + DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + + DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + + DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + + DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + + DIAGNOSTIC_LINE(SpecifiesGridwiseWmmaGemm) + + DIAGNOSTIC_LINE(SpecifiesFwdConvSpecialization) + + DIAGNOSTIC_LINE(SpecifiesGemmSpecialization) + + DIAGNOSTIC_LINE(SpecifiesNumPrefetchStages) + + DIAGNOSTIC_LINE(SpecifiesLoopScheduler); } }; template struct FwdDlAlgorithm { - static constexpr bool c1 = ConvAlgorithmDescriptor; - static constexpr bool c2 = SpecifiesThreadBlock; - static constexpr bool c3 = SpecifiesFwdConvSpecialization; - static constexpr bool c4 = SpecifiesGemmSpecialization; - static constexpr bool c5 = SpecifiesDlThreadConfig; - static constexpr bool c6 = SpecifiesDlThreadCluster; - static constexpr bool c7 = SpecifiesDlBlockTransfer; - static constexpr bool c8 = SpecifiesDlEpilogue; + CHECK_CONCEPT(T, ConvAlgorithmDescriptor) + 
CHECK_CONCEPT(T, SpecifiesThreadBlock) + CHECK_CONCEPT(T, SpecifiesFwdConvSpecialization) + CHECK_CONCEPT(T, SpecifiesGemmSpecialization) + CHECK_CONCEPT(T, SpecifiesDlThreadConfig) + CHECK_CONCEPT(T, SpecifiesDlThreadCluster) + CHECK_CONCEPT(T, SpecifiesDlBlockTransfer) + CHECK_CONCEPT(T, SpecifiesDlEpilogue) + + static constexpr bool c1 = c_ConvAlgorithmDescriptor; + static constexpr bool c2 = c_SpecifiesThreadBlock; + static constexpr bool c3 = c_SpecifiesFwdConvSpecialization; + static constexpr bool c4 = c_SpecifiesGemmSpecialization; + static constexpr bool c5 = c_SpecifiesDlThreadConfig; + static constexpr bool c6 = c_SpecifiesDlThreadCluster; + static constexpr bool c7 = c_SpecifiesDlBlockTransfer; + static constexpr bool c8 = c_SpecifiesDlEpilogue; static consteval bool is_valid() { return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8; } static consteval const std::string message() { - return "\n=== Forward DL Algorithm Diagnostic (closest match) ===\n" - "Concepts for FwdDl Algorithm:\n" - " ConvAlgorithmDescriptor: " + std::string(CHECK_MARK(c1)) + "\n" - " SpecifiesThreadBlock: " + std::string(CHECK_MARK(c2)) + "\n" - " SpecifiesFwdConvSpecialization: " + std::string(CHECK_MARK(c3)) + "\n" - " SpecifiesGemmSpecialization: " + std::string(CHECK_MARK(c4)) + "\n" - " SpecifiesDlThreadConfig: " + std::string(CHECK_MARK(c5)) + "\n" - " SpecifiesDlThreadCluster: " + std::string(CHECK_MARK(c6)) + "\n" - " SpecifiesDlBlockTransfer: " + std::string(CHECK_MARK(c7)) + "\n" - " SpecifiesDlEpilogue: " + std::string(CHECK_MARK(c8)) + "\n"; + return std::string("\n=== Forward DL Algorithm Diagnostic (closest match) ===\n" + "Concepts for FwdDl Algorithm:\n") + + DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + + DIAGNOSTIC_LINE(SpecifiesFwdConvSpecialization) + + DIAGNOSTIC_LINE(SpecifiesGemmSpecialization) + + DIAGNOSTIC_LINE(SpecifiesDlThreadConfig) + + DIAGNOSTIC_LINE(SpecifiesDlThreadCluster) + + 
DIAGNOSTIC_LINE(SpecifiesDlBlockTransfer) + + DIAGNOSTIC_LINE(SpecifiesDlEpilogue); } }; template struct TileAlgorithm { - static constexpr bool c1 = ConvAlgorithmDescriptor; - static constexpr bool c2 = SpecifiesTileThreadBlock; - static constexpr bool c3 = SpecifiesTileTransfer; - static constexpr bool c4 = SpecifiesTileConvSpecialization; - static constexpr bool c5 = SpecifiesTileBlockGemm; - static constexpr bool c6 = SpecifiesTileOptimizations; + CHECK_CONCEPT(T, ConvAlgorithmDescriptor) + CHECK_CONCEPT(T, SpecifiesTileThreadBlock) + CHECK_CONCEPT(T, SpecifiesTileTransfer) + CHECK_CONCEPT(T, SpecifiesTileConvSpecialization) + CHECK_CONCEPT(T, SpecifiesTileBlockGemm) + CHECK_CONCEPT(T, SpecifiesTileOptimizations) + + static constexpr bool c1 = c_ConvAlgorithmDescriptor; + static constexpr bool c2 = c_SpecifiesTileThreadBlock; + static constexpr bool c3 = c_SpecifiesTileTransfer; + static constexpr bool c4 = c_SpecifiesTileConvSpecialization; + static constexpr bool c5 = c_SpecifiesTileBlockGemm; + static constexpr bool c6 = c_SpecifiesTileOptimizations; static consteval bool is_valid() { return c1 && c2 && c3 && c4 && c5 && c6; } static consteval const std::string message() { - return "\n=== CK Tile Algorithm Diagnostic (closest match) ===\n" - "Concepts for CK Tile Conv Algorithm:\n" - " ConvAlgorithmDescriptor: " + std::string(CHECK_MARK(c1)) + "\n" - " SpecifiesTileThreadBlock: " + std::string(CHECK_MARK(c2)) + "\n" - " SpecifiesTileTransfer: " + std::string(CHECK_MARK(c3)) + "\n" - " SpecifiesTileConvSpecialization: " + std::string(CHECK_MARK(c4)) + "\n" - " SpecifiesTileBlockGemm: " + std::string(CHECK_MARK(c5)) + "\n" - " SpecifiesTileOptimizations: " + std::string(CHECK_MARK(c6)) + "\n"; + return std::string("\n=== CK Tile Algorithm Diagnostic (closest match) ===\n" + "Concepts for CK Tile Conv Algorithm:\n") + + DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + + DIAGNOSTIC_LINE(SpecifiesTileThreadBlock) + + DIAGNOSTIC_LINE(SpecifiesTileTransfer) + + 
DIAGNOSTIC_LINE(SpecifiesTileConvSpecialization) + + DIAGNOSTIC_LINE(SpecifiesTileBlockGemm) + + DIAGNOSTIC_LINE(SpecifiesTileOptimizations); } }; @@ -172,7 +233,9 @@ template struct LargeTensorAlgorithm : public FwdXdlAlgorithm { using BaseAlgorithmType = decltype(T::base_algorithm); - static constexpr bool c13 = SpecifiesLargeTensorSupport; + CHECK_CONCEPT(T, SpecifiesLargeTensorSupport) + + static constexpr bool c13 = c_SpecifiesLargeTensorSupport; static consteval bool is_valid() { return FwdXdlAlgorithm::is_valid() && c13; @@ -180,38 +243,48 @@ struct LargeTensorAlgorithm : public FwdXdlAlgorithm::message() + - " SpecifiesLargeTensorSupport: " + std::string(CHECK_MARK(c13)) + "\n"; + DIAGNOSTIC_LINE(SpecifiesLargeTensorSupport); } }; template struct BwdXdlAlgorithm { - static constexpr bool c1 = ConvAlgorithmDescriptor; - static constexpr bool c2 = SpecifiesThreadBlock; - static constexpr bool c3 = SpecifiesBlockTransfer; - static constexpr bool c4 = SpecifiesLdsTransfer; - static constexpr bool c5 = SpecifiesThreadClusterAccessOrder; - static constexpr bool c6 = SpecifiesSourceAccessOrder; - static constexpr bool c7 = SpecifiesGridwiseBwdXdlGemm; - static constexpr bool c8 = SpecifiesBwdWeightConvSpecialization; - static constexpr bool c9 = SpecifiesTransposeTransfer; + CHECK_CONCEPT(T, ConvAlgorithmDescriptor) + CHECK_CONCEPT(T, SpecifiesThreadBlock) + CHECK_CONCEPT(T, SpecifiesBlockTransfer) + CHECK_CONCEPT(T, SpecifiesLdsTransfer) + CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder) + CHECK_CONCEPT(T, SpecifiesSourceAccessOrder) + CHECK_CONCEPT(T, SpecifiesGridwiseBwdXdlGemm) + CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization) + CHECK_CONCEPT(T, SpecifiesTransposeTransfer) + + static constexpr bool c1 = c_ConvAlgorithmDescriptor; + static constexpr bool c2 = c_SpecifiesThreadBlock; + static constexpr bool c3 = c_SpecifiesBlockTransfer; + static constexpr bool c4 = c_SpecifiesLdsTransfer; + static constexpr bool c5 = 
c_SpecifiesThreadClusterAccessOrder; + static constexpr bool c6 = c_SpecifiesSourceAccessOrder; + static constexpr bool c7 = c_SpecifiesGridwiseBwdXdlGemm; + static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; + static constexpr bool c9 = c_SpecifiesTransposeTransfer; static consteval bool is_valid() { return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9; } static consteval const std::string message() { - return "\n=== Backward XDL Algorithm Diagnostic (closest match) ===\n" - "Concepts for BwdXdl Algorithm:\n" - " ConvAlgorithmDescriptor: " + std::string(CHECK_MARK(c1)) + "\n" - " SpecifiesThreadBlock: " + std::string(CHECK_MARK(c2)) + "\n" - " SpecifiesBlockTransfer: " + std::string(CHECK_MARK(c3)) + "\n" - " SpecifiesLdsTransfer: " + std::string(CHECK_MARK(c4)) + "\n" - " SpecifiesThreadClusterAccessOrder: " + std::string(CHECK_MARK(c5)) + "\n" - " SpecifiesSourceAccessOrder: " + std::string(CHECK_MARK(c6)) + "\n" - " SpecifiesGridwiseBwdXdlGemm: " + std::string(CHECK_MARK(c7)) + "\n" - " SpecifiesBwdWeightConvSpecialization: " + std::string(CHECK_MARK(c8)) + "\n" - " SpecifiesTransposeTransfer: " + std::string(CHECK_MARK(c9)) + "\n"; + return std::string("\n=== Backward XDL Algorithm Diagnostic (closest match) ===\n" + "Concepts for BwdXdl Algorithm:\n") + + DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + + DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + + DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + + DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + + DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + + DIAGNOSTIC_LINE(SpecifiesGridwiseBwdXdlGemm) + + DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) + + DIAGNOSTIC_LINE(SpecifiesTransposeTransfer); } }; @@ -315,4 +388,4 @@ consteval void diagnose_bwd_weight_algorithm_signature() } } -} +} From c6798d367327b33a0af6ea9d387b1410549a5d7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Mon, 22 Dec 2025 08:06:41 -0500 Subject: [PATCH 07/81] 
Improve compile time diagnostics. --- .../builder/conv_algorithm_diagnostics.hpp | 708 ++++++++++++++++++ .../builder/factory/conv_algorithms.hpp | 27 +- .../test/conv/ck/test_ckb_conv_bwd_weight.cpp | 2 +- 3 files changed, 718 insertions(+), 19 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp new file mode 100644 index 00000000000..d97ee85abe4 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp @@ -0,0 +1,708 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/conv_algorithm_concepts.hpp" + +namespace ck_tile::builder::diagnostics { + +#define CHECK_MARK(cond) (cond ? "[✓]" : "[✗]") + +// Macro to check a concept and generate both the boolean and the string representation +#define CHECK_CONCEPT(Type, Concept) \ + static constexpr bool c_##Concept = Concept; \ + static constexpr const char* s_##Concept = #Concept; + +// Helper to create diagnostic message line +#define DIAGNOSTIC_LINE(Concept) \ + " " + std::string(s_##Concept) + ": " + std::string(CHECK_MARK(c_##Concept)) + "\n" + \ + (c_##Concept ? 
std::string("") : detailed_diagnostic_##Concept()) + +namespace detail { + +// ThreadBlockDescriptor diagnostics +template +consteval auto diagnose_thread_block_descriptor() -> std::string { + if constexpr (!requires { T::thread_block; }) { + return " → T::thread_block member: [✗] (not found)\n"; + } else { + using TB = decltype(T::thread_block); + std::string msg; + + constexpr bool has_block_size = requires(TB t) { { t.block_size } -> std::convertible_to; }; + constexpr bool has_tile_m = requires(TB t) { { t.tile_size.m } -> std::convertible_to; }; + constexpr bool has_tile_n = requires(TB t) { { t.tile_size.n } -> std::convertible_to; }; + constexpr bool has_tile_k = requires(TB t) { { t.tile_size.k } -> std::convertible_to; }; + + msg += " → thread_block.block_size: " + std::string(CHECK_MARK(has_block_size)) + + (has_block_size ? "\n" : " (missing or wrong type)\n"); + msg += " → thread_block.tile_size.m: " + std::string(CHECK_MARK(has_tile_m)) + + (has_tile_m ? "\n" : " (missing or wrong type)\n"); + msg += " → thread_block.tile_size.n: " + std::string(CHECK_MARK(has_tile_n)) + + (has_tile_n ? "\n" : " (missing or wrong type)\n"); + msg += " → thread_block.tile_size.k: " + std::string(CHECK_MARK(has_tile_k)) + + (has_tile_k ? "\n" : " (missing or wrong type)\n"); + + return msg; + } +} + +// GridwiseXdlGemmDescriptor diagnostics +template +consteval auto diagnose_xdl_params() -> std::string { + std::string msg; + + constexpr bool has_m_per_xdl = requires(XdlParams t) { { t.m_per_xdl } -> std::convertible_to; }; + constexpr bool has_n_per_xdl = requires(XdlParams t) { { t.n_per_xdl } -> std::convertible_to; }; + constexpr bool has_m_xdl_per_wave = requires(XdlParams t) { { t.m_xdl_per_wave } -> std::convertible_to; }; + constexpr bool has_n_xdl_per_wave = requires(XdlParams t) { { t.n_xdl_per_wave } -> std::convertible_to; }; + + msg += " → xdl_params.m_per_xdl: " + std::string(CHECK_MARK(has_m_per_xdl)) + + (has_m_per_xdl ? 
"\n" : " (missing or wrong type)\n"); + msg += " → xdl_params.n_per_xdl: " + std::string(CHECK_MARK(has_n_per_xdl)) + + (has_n_per_xdl ? "\n" : " (missing or wrong type)\n"); + msg += " → xdl_params.m_xdl_per_wave: " + std::string(CHECK_MARK(has_m_xdl_per_wave)) + + (has_m_xdl_per_wave ? "\n" : " (missing or wrong type)\n"); + msg += " → xdl_params.n_xdl_per_wave: " + std::string(CHECK_MARK(has_n_xdl_per_wave)) + + (has_n_xdl_per_wave ? "\n" : " (missing or wrong type)\n"); + + return msg; +} + +// BlockTransferDescriptor diagnostics +template +consteval auto diagnose_block_transfer(const char* prefix) -> std::string { + std::string msg; + + constexpr bool has_k0 = requires(BT t) { { t.k0 } -> std::convertible_to; }; + constexpr bool has_m_n = requires(BT t) { { t.m_n } -> std::convertible_to; }; + constexpr bool has_k1 = requires(BT t) { { t.k1 } -> std::convertible_to; }; + + msg += std::string(" → ") + prefix + ".k0: " + std::string(CHECK_MARK(has_k0)) + + (has_k0 ? "\n" : " (missing or wrong type)\n"); + msg += std::string(" → ") + prefix + ".m_n: " + std::string(CHECK_MARK(has_m_n)) + + (has_m_n ? "\n" : " (missing or wrong type)\n"); + msg += std::string(" → ") + prefix + ".k1: " + std::string(CHECK_MARK(has_k1)) + + (has_k1 ? 
"\n" : " (missing or wrong type)\n"); + + return msg; +} + +// LdsTransferDescriptor diagnostics +template +consteval auto diagnose_lds_transfer(const char* prefix) -> std::string { + std::string msg; + + constexpr bool has_src_vector_dim = requires(LT t) { { t.src_vector_dim } -> std::convertible_to; }; + constexpr bool has_src_scalar_per_vector = requires(LT t) { { t.src_scalar_per_vector } -> std::convertible_to; }; + constexpr bool has_lds_dst_scalar_per_vector = requires(LT t) { { t.lds_dst_scalar_per_vector } -> std::convertible_to; }; + constexpr bool has_is_direct_load = requires(LT t) { { t.is_direct_load } -> std::convertible_to; }; + constexpr bool has_lds_padding = requires(LT t) { { t.lds_padding } -> std::convertible_to; }; + + msg += std::string(" → ") + prefix + ".src_vector_dim: " + std::string(CHECK_MARK(has_src_vector_dim)) + + (has_src_vector_dim ? "\n" : " (missing or wrong type)\n"); + msg += std::string(" → ") + prefix + ".src_scalar_per_vector: " + std::string(CHECK_MARK(has_src_scalar_per_vector)) + + (has_src_scalar_per_vector ? "\n" : " (missing or wrong type)\n"); + msg += std::string(" → ") + prefix + ".lds_dst_scalar_per_vector: " + std::string(CHECK_MARK(has_lds_dst_scalar_per_vector)) + + (has_lds_dst_scalar_per_vector ? "\n" : " (missing or wrong type)\n"); + msg += std::string(" → ") + prefix + ".is_direct_load: " + std::string(CHECK_MARK(has_is_direct_load)) + + (has_is_direct_load ? "\n" : " (missing or wrong type)\n"); + msg += std::string(" → ") + prefix + ".lds_padding: " + std::string(CHECK_MARK(has_lds_padding)) + + (has_lds_padding ? 
"\n" : " (missing or wrong type)\n"); + + return msg; +} + +// ThreadClusterDescriptor diagnostics +template +consteval auto diagnose_thread_cluster(const char* prefix) -> std::string { + std::string msg; + + constexpr bool has_m_block = requires(TC t) { { t.m_block } -> std::convertible_to; }; + constexpr bool has_m_wave_per_xdl = requires(TC t) { { t.m_wave_per_xdl } -> std::convertible_to; }; + constexpr bool has_n_block = requires(TC t) { { t.n_block } -> std::convertible_to; }; + constexpr bool has_n_wave_per_xdl = requires(TC t) { { t.n_wave_per_xdl } -> std::convertible_to; }; + + msg += std::string(" → ") + prefix + ".m_block: " + std::string(CHECK_MARK(has_m_block)) + + (has_m_block ? "\n" : " (missing or wrong type)\n"); + msg += std::string(" → ") + prefix + ".m_wave_per_xdl: " + std::string(CHECK_MARK(has_m_wave_per_xdl)) + + (has_m_wave_per_xdl ? "\n" : " (missing or wrong type)\n"); + msg += std::string(" → ") + prefix + ".n_block: " + std::string(CHECK_MARK(has_n_block)) + + (has_n_block ? "\n" : " (missing or wrong type)\n"); + msg += std::string(" → ") + prefix + ".n_wave_per_xdl: " + std::string(CHECK_MARK(has_n_wave_per_xdl)) + + (has_n_wave_per_xdl ? "\n" : " (missing or wrong type)\n"); + + return msg; +} + +// AccessOrderDescriptor diagnostics +template +consteval auto diagnose_access_order(const char* prefix) -> std::string { + std::string msg; + + constexpr bool has_order = requires(AO t) { { t.order } -> std::convertible_to>; }; + + msg += std::string(" → ") + prefix + ".order: " + std::string(CHECK_MARK(has_order)) + + (has_order ? 
"\n" : " (missing or wrong type)\n"); + + return msg; +} + +// EpilogueDescriptor diagnostics +template +consteval auto diagnose_epilogue(const char* prefix) -> std::string { + std::string msg; + + constexpr bool has_m_xdl = requires(E t) { { t.m_xdl_per_wave_per_shuffle } -> std::convertible_to; }; + constexpr bool has_n_per_wave = requires(E t) { { t.n_per_wave_per_shuffle } -> std::convertible_to; }; + constexpr bool has_scalar_per_vector = requires(E t) { { t.scalar_per_vector } -> std::convertible_to; }; + + msg += std::string(" → ") + prefix + ".m_xdl_per_wave_per_shuffle: " + std::string(CHECK_MARK(has_m_xdl)) + + (has_m_xdl ? "\n" : " (missing or wrong type)\n"); + msg += std::string(" → ") + prefix + ".n_per_wave_per_shuffle: " + std::string(CHECK_MARK(has_n_per_wave)) + + (has_n_per_wave ? "\n" : " (missing or wrong type)\n"); + msg += std::string(" → ") + prefix + ".scalar_per_vector: " + std::string(CHECK_MARK(has_scalar_per_vector)) + + (has_scalar_per_vector ? "\n" : " (missing or wrong type)\n"); + + return msg; +} + +} // namespace detail + +// Detailed diagnostic functions for high-level concepts +template +consteval auto detailed_diagnostic_ConvAlgorithmDescriptor() -> std::string { + return ""; // Base concept, no sub-requirements to check +} + +template +consteval auto detailed_diagnostic_SpecifiesThreadBlock() -> std::string { + if constexpr (!requires { T::thread_block; }) { + return " → T::thread_block member: [✗] (not found)\n"; + } else { + return " → T::thread_block member: [✓]\n" + + detail::diagnose_thread_block_descriptor(); + } +} + +template +consteval auto detailed_diagnostic_SpecifiesGridwiseFwdXdlGemm() -> std::string { + std::string msg; + + constexpr bool has_ak1 = requires { { T::ak1 } -> std::convertible_to; }; + constexpr bool has_bk1 = requires { { T::bk1 } -> std::convertible_to; }; + constexpr bool has_xdl_params = requires { T::xdl_params; }; + + msg += " → T::ak1: " + std::string(CHECK_MARK(has_ak1)) + + (has_ak1 ? 
"\n" : " (missing or wrong type)\n"); + msg += " → T::bk1: " + std::string(CHECK_MARK(has_bk1)) + + (has_bk1 ? "\n" : " (missing or wrong type)\n"); + msg += " → T::xdl_params member: " + std::string(CHECK_MARK(has_xdl_params)) + "\n"; + + if constexpr (has_xdl_params) { + msg += detail::diagnose_xdl_params(); + } + + return msg; +} + +template +consteval auto detailed_diagnostic_SpecifiesGridwiseBwdXdlGemm() -> std::string { + std::string msg; + + constexpr bool has_k0 = requires { { T::k0_per_block } -> std::convertible_to; }; + constexpr bool has_k1 = requires { { T::k1 } -> std::convertible_to; }; + constexpr bool has_xdl_params = requires { T::xdl_params; }; + + msg += " → T::k0_per_block: " + std::string(CHECK_MARK(has_k0)) + + (has_k0 ? "\n" : " (missing or wrong type)\n"); + msg += " → T::k1: " + std::string(CHECK_MARK(has_k1)) + + (has_k1 ? "\n" : " (missing or wrong type)\n"); + msg += " → T::xdl_params member: " + std::string(CHECK_MARK(has_xdl_params)) + "\n"; + + if constexpr (has_xdl_params) { + msg += detail::diagnose_xdl_params(); + } + + return msg; +} + +template +consteval auto detailed_diagnostic_SpecifiesBlockTransfer() -> std::string { + std::string msg; + + constexpr bool has_transfer = requires { T::transfer; }; + msg += " → T::transfer member: " + std::string(CHECK_MARK(has_transfer)) + "\n"; + + if constexpr (!has_transfer) { + return msg; + } + + constexpr bool has_a = requires { T::transfer.a; }; + constexpr bool has_b = requires { T::transfer.b; }; + constexpr bool has_c = requires { T::transfer.c; }; + + msg += " → T::transfer.a: " + std::string(CHECK_MARK(has_a)) + "\n"; + msg += " → T::transfer.b: " + std::string(CHECK_MARK(has_b)) + "\n"; + msg += " → T::transfer.c: " + std::string(CHECK_MARK(has_c)) + "\n"; + + if constexpr (has_a && requires { T::transfer.a.block_transfer; }) { + msg += detail::diagnose_block_transfer("transfer.a.block_transfer"); + } else if constexpr (has_a) { + msg += " → T::transfer.a.block_transfer: [✗] 
(missing)\n"; + } + + if constexpr (has_b && requires { T::transfer.b.block_transfer; }) { + msg += detail::diagnose_block_transfer("transfer.b.block_transfer"); + } else if constexpr (has_b) { + msg += " → T::transfer.b.block_transfer: [✗] (missing)\n"; + } + + if constexpr (has_c && requires { T::transfer.c.thread_cluster_dims; }) { + msg += detail::diagnose_thread_cluster("transfer.c.thread_cluster_dims"); + } else if constexpr (has_c) { + msg += " → T::transfer.c.thread_cluster_dims: [✗] (missing)\n"; + } + + return msg; +} + +template +consteval auto detailed_diagnostic_SpecifiesLdsTransfer() -> std::string { + std::string msg; + + constexpr bool has_transfer = requires { T::transfer; }; + msg += " → T::transfer member: " + std::string(CHECK_MARK(has_transfer)) + "\n"; + + if constexpr (!has_transfer) { + return msg; + } + + constexpr bool has_a = requires { T::transfer.a; }; + constexpr bool has_b = requires { T::transfer.b; }; + constexpr bool has_c = requires { T::transfer.c; }; + + if constexpr (has_a && requires { T::transfer.a.lds_transfer; }) { + msg += detail::diagnose_lds_transfer("transfer.a.lds_transfer"); + } else if constexpr (has_a) { + msg += " → T::transfer.a.lds_transfer: [✗] (missing)\n"; + } + + if constexpr (has_b && requires { T::transfer.b.lds_transfer; }) { + msg += detail::diagnose_lds_transfer("transfer.b.lds_transfer"); + } else if constexpr (has_b) { + msg += " → T::transfer.b.lds_transfer: [✗] (missing)\n"; + } + + if constexpr (has_c && requires { T::transfer.c.epilogue; }) { + msg += detail::diagnose_epilogue("transfer.c.epilogue"); + } else if constexpr (has_c) { + msg += " → T::transfer.c.epilogue: [✗] (missing)\n"; + } + + return msg; +} + +template +consteval auto detailed_diagnostic_SpecifiesThreadClusterAccessOrder() -> std::string { + std::string msg; + + constexpr bool has_transfer = requires { T::transfer; }; + if constexpr (!has_transfer) { + return " → T::transfer member: [✗] (not found)\n"; + } + + constexpr bool has_a 
= requires { T::transfer.a; }; + constexpr bool has_b = requires { T::transfer.b; }; + + if constexpr (has_a && requires { T::transfer.a.block_transfer_access_order; }) { + msg += detail::diagnose_access_order("transfer.a.block_transfer_access_order"); + } else if constexpr (has_a) { + msg += " → T::transfer.a.block_transfer_access_order: [✗] (missing)\n"; + } + + if constexpr (has_b && requires { T::transfer.b.block_transfer_access_order; }) { + msg += detail::diagnose_access_order("transfer.b.block_transfer_access_order"); + } else if constexpr (has_b) { + msg += " → T::transfer.b.block_transfer_access_order: [✗] (missing)\n"; + } + + return msg; +} + +template +consteval auto detailed_diagnostic_SpecifiesSourceAccessOrder() -> std::string { + std::string msg; + + constexpr bool has_transfer = requires { T::transfer; }; + if constexpr (!has_transfer) { + return " → T::transfer member: [✗] (not found)\n"; + } + + constexpr bool has_a = requires { T::transfer.a; }; + constexpr bool has_b = requires { T::transfer.b; }; + + if constexpr (has_a && requires { T::transfer.a.src_access_order; }) { + msg += detail::diagnose_access_order("transfer.a.src_access_order"); + } else if constexpr (has_a) { + msg += " → T::transfer.a.src_access_order: [✗] (missing)\n"; + } + + if constexpr (has_b && requires { T::transfer.b.src_access_order; }) { + msg += detail::diagnose_access_order("transfer.b.src_access_order"); + } else if constexpr (has_b) { + msg += " → T::transfer.b.src_access_order: [✗] (missing)\n"; + } + + return msg; +} + +template +consteval auto detailed_diagnostic_SpecifiesBlockGemm() -> std::string { + std::string msg; + + constexpr bool has_block_gemm = requires { T::block_gemm; }; + msg += " → T::block_gemm member: " + std::string(CHECK_MARK(has_block_gemm)) + "\n"; + + if constexpr (!has_block_gemm) { + return msg; + } + + constexpr bool has_pipeline = requires { { T::block_gemm.pipeline_version } -> std::convertible_to; }; + constexpr bool has_scheduler = 
requires { { T::block_gemm.scheduler } -> std::convertible_to; }; + + msg += " → block_gemm.pipeline_version: " + std::string(CHECK_MARK(has_pipeline)) + + (has_pipeline ? "\n" : " (missing or wrong type)\n"); + msg += " → block_gemm.scheduler: " + std::string(CHECK_MARK(has_scheduler)) + + (has_scheduler ? "\n" : " (missing or wrong type)\n"); + + return msg; +} + +template +consteval auto detailed_diagnostic_SpecifiesFwdConvSpecialization() -> std::string { + constexpr bool has_member = requires { { T::fwd_specialization } -> std::convertible_to; }; + return " → T::fwd_specialization: " + std::string(CHECK_MARK(has_member)) + + (has_member ? "\n" : " (missing or wrong type)\n"); +} + +template +consteval auto detailed_diagnostic_SpecifiesBwdWeightConvSpecialization() -> std::string { + constexpr bool has_member = requires { { T::bwd_weight_specialization } -> std::convertible_to; }; + return " → T::bwd_weight_specialization: " + std::string(CHECK_MARK(has_member)) + + (has_member ? "\n" : " (missing or wrong type)\n"); +} + +template +consteval auto detailed_diagnostic_SpecifiesGemmSpecialization() -> std::string { + constexpr bool has_member = requires { { T::gemm_specialization } -> std::convertible_to; }; + return " → T::gemm_specialization: " + std::string(CHECK_MARK(has_member)) + + (has_member ? "\n" : " (missing or wrong type)\n"); +} + +template +consteval auto detailed_diagnostic_SpecifiesNumPrefetchStages() -> std::string { + constexpr bool has_member = requires { { T::num_gemm_k_prefetch_stages } -> std::convertible_to; }; + return " → T::num_gemm_k_prefetch_stages: " + std::string(CHECK_MARK(has_member)) + + (has_member ? "\n" : " (missing or wrong type)\n"); +} + +template +consteval auto detailed_diagnostic_SpecifiesNumGroupsToMerge() -> std::string { + constexpr bool has_member = requires { { T::num_groups_to_merge } -> std::convertible_to; }; + return " → T::num_groups_to_merge: " + std::string(CHECK_MARK(has_member)) + + (has_member ? 
"\n" : " (missing or wrong type)\n"); +} + +template +consteval auto detailed_diagnostic_SpecifiesLoopScheduler() -> std::string { + constexpr bool has_member = requires { { T::loop_scheduler } -> std::convertible_to; }; + return " → T::loop_scheduler: " + std::string(CHECK_MARK(has_member)) + + (has_member ? "\n" : " (missing or wrong type)\n"); +} + +template +consteval auto detailed_diagnostic_SpecifiesLargeTensorSupport() -> std::string { + std::string msg; + constexpr bool has_specialization = requires { { T::specialization } -> std::convertible_to; }; + msg += " → T::specialization: " + std::string(CHECK_MARK(has_specialization)) + + (has_specialization ? "\n" : " (missing or wrong type)\n"); + + if constexpr (has_specialization) { + constexpr bool is_large_tensor = (T::specialization == ConvAlgorithmSpecialization::LARGE_TENSOR); + msg += " → specialization == LARGE_TENSOR: " + std::string(CHECK_MARK(is_large_tensor)) + "\n"; + } + + return msg; +} + +template +consteval auto detailed_diagnostic_SpecifiesTransposeTransfer() -> std::string { + std::string msg; + constexpr bool has_src = requires { { T::max_transpose_transfer_src_scalar_per_vector } -> std::convertible_to; }; + constexpr bool has_dst = requires { { T::max_transpose_transfer_dst_scalar_per_vector } -> std::convertible_to; }; + + msg += " → T::max_transpose_transfer_src_scalar_per_vector: " + std::string(CHECK_MARK(has_src)) + + (has_src ? "\n" : " (missing or wrong type)\n"); + msg += " → T::max_transpose_transfer_dst_scalar_per_vector: " + std::string(CHECK_MARK(has_dst)) + + (has_dst ? 
"\n" : " (missing or wrong type)\n"); + + return msg; +} + +template +consteval auto detailed_diagnostic_SpecifiesGridwiseWmmaGemm() -> std::string { + std::string msg; + constexpr bool has_gridwise_gemm = requires { T::gridwise_gemm; }; + msg += " → T::gridwise_gemm member: " + std::string(CHECK_MARK(has_gridwise_gemm)) + "\n"; + + if constexpr (!has_gridwise_gemm) { + return msg; + } + + using GG = decltype(T::gridwise_gemm); + constexpr bool has_k1 = requires(GG t) { { t.k1 } -> std::convertible_to; }; + constexpr bool has_m_per_wmma = requires(GG t) { { t.m_per_wmma } -> std::convertible_to; }; + constexpr bool has_n_per_wmma = requires(GG t) { { t.n_per_wmma } -> std::convertible_to; }; + constexpr bool has_m_wmma_per_wave = requires(GG t) { { t.m_wmma_per_wave } -> std::convertible_to; }; + constexpr bool has_n_wmma_per_wave = requires(GG t) { { t.n_wmma_per_wave } -> std::convertible_to; }; + constexpr bool has_pipeline = requires(GG t) { { t.pipeline_version } -> std::convertible_to; }; + + msg += " → gridwise_gemm.k1: " + std::string(CHECK_MARK(has_k1)) + (has_k1 ? "\n" : " (missing or wrong type)\n"); + msg += " → gridwise_gemm.m_per_wmma: " + std::string(CHECK_MARK(has_m_per_wmma)) + (has_m_per_wmma ? "\n" : " (missing or wrong type)\n"); + msg += " → gridwise_gemm.n_per_wmma: " + std::string(CHECK_MARK(has_n_per_wmma)) + (has_n_per_wmma ? "\n" : " (missing or wrong type)\n"); + msg += " → gridwise_gemm.m_wmma_per_wave: " + std::string(CHECK_MARK(has_m_wmma_per_wave)) + (has_m_wmma_per_wave ? "\n" : " (missing or wrong type)\n"); + msg += " → gridwise_gemm.n_wmma_per_wave: " + std::string(CHECK_MARK(has_n_wmma_per_wave)) + (has_n_wmma_per_wave ? "\n" : " (missing or wrong type)\n"); + msg += " → gridwise_gemm.pipeline_version: " + std::string(CHECK_MARK(has_pipeline)) + (has_pipeline ? 
"\n" : " (missing or wrong type)\n"); + + return msg; +} + +// Tile-specific diagnostics +template +consteval auto detailed_diagnostic_SpecifiesTileThreadBlock() -> std::string { + if constexpr (!requires { T::thread_block; }) { + return " → T::thread_block member: [✗] (not found)\n"; + } else { + using TB = decltype(T::thread_block); + std::string msg = " → T::thread_block member: [✓]\n"; + + constexpr bool has_tile_m = requires(TB t) { { t.tile_size.m } -> std::convertible_to; }; + constexpr bool has_tile_n = requires(TB t) { { t.tile_size.n } -> std::convertible_to; }; + constexpr bool has_tile_k = requires(TB t) { { t.tile_size.k } -> std::convertible_to; }; + + msg += " → thread_block.tile_size.m: " + std::string(CHECK_MARK(has_tile_m)) + (has_tile_m ? "\n" : " (missing or wrong type)\n"); + msg += " → thread_block.tile_size.n: " + std::string(CHECK_MARK(has_tile_n)) + (has_tile_n ? "\n" : " (missing or wrong type)\n"); + msg += " → thread_block.tile_size.k: " + std::string(CHECK_MARK(has_tile_k)) + (has_tile_k ? "\n" : " (missing or wrong type)\n"); + + return msg; + } +} + +template +consteval auto detailed_diagnostic_SpecifiesTileTransfer() -> std::string { + std::string msg; + constexpr bool has_transfer = requires { T::transfer; }; + msg += " → T::transfer member: " + std::string(CHECK_MARK(has_transfer)) + "\n"; + + if constexpr (!has_transfer) { + return msg; + } + + constexpr bool has_a_scalar = requires { { T::transfer.a_scalar_per_vector } -> std::convertible_to; }; + constexpr bool has_b_scalar = requires { { T::transfer.b_scalar_per_vector } -> std::convertible_to; }; + constexpr bool has_c_scalar = requires { { T::transfer.c_scalar_per_vector } -> std::convertible_to; }; + + msg += " → transfer.a_scalar_per_vector: " + std::string(CHECK_MARK(has_a_scalar)) + (has_a_scalar ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.b_scalar_per_vector: " + std::string(CHECK_MARK(has_b_scalar)) + (has_b_scalar ? 
"\n" : " (missing or wrong type)\n"); + msg += " → transfer.c_scalar_per_vector: " + std::string(CHECK_MARK(has_c_scalar)) + (has_c_scalar ? "\n" : " (missing or wrong type)\n"); + + return msg; +} + +template +consteval auto detailed_diagnostic_SpecifiesTileConvSpecialization() -> std::string { + constexpr bool has_member = requires { { T::specialization } -> std::convertible_to; }; + return " → T::specialization: " + std::string(CHECK_MARK(has_member)) + (has_member ? "\n" : " (missing or wrong type)\n"); +} + +template +consteval auto detailed_diagnostic_SpecifiesTileBlockGemm() -> std::string { + std::string msg; + constexpr bool has_block_gemm = requires { T::block_gemm; }; + msg += " → T::block_gemm member: " + std::string(CHECK_MARK(has_block_gemm)) + "\n"; + + if constexpr (!has_block_gemm) { + return msg; + } + + using BG = decltype(T::block_gemm); + constexpr bool has_warps_m = requires(BG t) { { t.warps.m } -> std::convertible_to; }; + constexpr bool has_warps_n = requires(BG t) { { t.warps.n } -> std::convertible_to; }; + constexpr bool has_warps_k = requires(BG t) { { t.warps.k } -> std::convertible_to; }; + constexpr bool has_warp_tile_m = requires(BG t) { { t.warp_tile.m } -> std::convertible_to; }; + constexpr bool has_warp_tile_n = requires(BG t) { { t.warp_tile.n } -> std::convertible_to; }; + constexpr bool has_warp_tile_k = requires(BG t) { { t.warp_tile.k } -> std::convertible_to; }; + constexpr bool has_double_smem = requires(BG t) { { t.double_smem_buffer } -> std::convertible_to; }; + constexpr bool has_num_wave_groups = requires(BG t) { { t.num_wave_groups } -> std::convertible_to; }; + constexpr bool has_pipeline = requires(BG t) { { t.pipeline_version } -> std::convertible_to; }; + constexpr bool has_scheduler = requires(BG t) { { t.scheduler } -> std::convertible_to; }; + + msg += " → block_gemm.warps.m: " + std::string(CHECK_MARK(has_warps_m)) + (has_warps_m ? 
"\n" : " (missing or wrong type)\n"); + msg += " → block_gemm.warps.n: " + std::string(CHECK_MARK(has_warps_n)) + (has_warps_n ? "\n" : " (missing or wrong type)\n"); + msg += " → block_gemm.warps.k: " + std::string(CHECK_MARK(has_warps_k)) + (has_warps_k ? "\n" : " (missing or wrong type)\n"); + msg += " → block_gemm.warp_tile.m: " + std::string(CHECK_MARK(has_warp_tile_m)) + (has_warp_tile_m ? "\n" : " (missing or wrong type)\n"); + msg += " → block_gemm.warp_tile.n: " + std::string(CHECK_MARK(has_warp_tile_n)) + (has_warp_tile_n ? "\n" : " (missing or wrong type)\n"); + msg += " → block_gemm.warp_tile.k: " + std::string(CHECK_MARK(has_warp_tile_k)) + (has_warp_tile_k ? "\n" : " (missing or wrong type)\n"); + msg += " → block_gemm.double_smem_buffer: " + std::string(CHECK_MARK(has_double_smem)) + (has_double_smem ? "\n" : " (missing or wrong type)\n"); + msg += " → block_gemm.num_wave_groups: " + std::string(CHECK_MARK(has_num_wave_groups)) + (has_num_wave_groups ? "\n" : " (missing or wrong type)\n"); + msg += " → block_gemm.pipeline_version: " + std::string(CHECK_MARK(has_pipeline)) + (has_pipeline ? "\n" : " (missing or wrong type)\n"); + msg += " → block_gemm.scheduler: " + std::string(CHECK_MARK(has_scheduler)) + (has_scheduler ? 
"\n" : " (missing or wrong type)\n"); + + return msg; +} + +template +consteval auto detailed_diagnostic_SpecifiesTileOptimizations() -> std::string { + std::string msg; + constexpr bool has_optimizations = requires { T::optimizations; }; + msg += " → T::optimizations member: " + std::string(CHECK_MARK(has_optimizations)) + "\n"; + + if constexpr (!has_optimizations) { + return msg; + } + + using OPT = decltype(T::optimizations); + constexpr bool has_num_groups = requires(OPT t) { { t.num_groups_to_merge } -> std::convertible_to; }; + constexpr bool has_split_image = requires(OPT t) { { t.split_image } -> std::convertible_to; }; + constexpr bool has_explicit_gemm = requires(OPT t) { { t.explicit_gemm } -> std::convertible_to; }; + + msg += " → optimizations.num_groups_to_merge: " + std::string(CHECK_MARK(has_num_groups)) + (has_num_groups ? "\n" : " (missing or wrong type)\n"); + msg += " → optimizations.split_image: " + std::string(CHECK_MARK(has_split_image)) + (has_split_image ? "\n" : " (missing or wrong type)\n"); + msg += " → optimizations.explicit_gemm: " + std::string(CHECK_MARK(has_explicit_gemm)) + (has_explicit_gemm ? 
"\n" : " (missing or wrong type)\n"); + + return msg; +} + +// DL-specific diagnostics +template +consteval auto detailed_diagnostic_SpecifiesDlThreadConfig() -> std::string { + std::string msg; + constexpr bool has_thread_config = requires { T::thread_config; }; + msg += " → T::thread_config member: " + std::string(CHECK_MARK(has_thread_config)) + "\n"; + + if constexpr (!has_thread_config) { + return msg; + } + + using TC = decltype(T::thread_config); + constexpr bool has_k0 = requires(TC t) { { t.k0_per_block } -> std::convertible_to; }; + constexpr bool has_k1 = requires(TC t) { { t.k1 } -> std::convertible_to; }; + constexpr bool has_m1 = requires(TC t) { { t.m1_per_thread } -> std::convertible_to; }; + constexpr bool has_n1 = requires(TC t) { { t.n1_per_thread } -> std::convertible_to; }; + constexpr bool has_k = requires(TC t) { { t.k_per_thread } -> std::convertible_to; }; + + msg += " → thread_config.k0_per_block: " + std::string(CHECK_MARK(has_k0)) + (has_k0 ? "\n" : " (missing or wrong type)\n"); + msg += " → thread_config.k1: " + std::string(CHECK_MARK(has_k1)) + (has_k1 ? "\n" : " (missing or wrong type)\n"); + msg += " → thread_config.m1_per_thread: " + std::string(CHECK_MARK(has_m1)) + (has_m1 ? "\n" : " (missing or wrong type)\n"); + msg += " → thread_config.n1_per_thread: " + std::string(CHECK_MARK(has_n1)) + (has_n1 ? "\n" : " (missing or wrong type)\n"); + msg += " → thread_config.k_per_thread: " + std::string(CHECK_MARK(has_k)) + (has_k ? 
"\n" : " (missing or wrong type)\n"); + + return msg; +} + +template +consteval auto detailed_diagnostic_SpecifiesDlThreadCluster() -> std::string { + std::string msg; + constexpr bool has_thread_cluster = requires { T::thread_cluster; }; + msg += " → T::thread_cluster member: " + std::string(CHECK_MARK(has_thread_cluster)) + "\n"; + + if constexpr (!has_thread_cluster) { + return msg; + } + + using TC = decltype(T::thread_cluster); + constexpr bool has_m1_xs = requires(TC t) { { t.m1_xs } -> std::convertible_to>; }; + constexpr bool has_n1_xs = requires(TC t) { { t.n1_xs } -> std::convertible_to>; }; + + msg += " → thread_cluster.m1_xs: " + std::string(CHECK_MARK(has_m1_xs)) + (has_m1_xs ? "\n" : " (missing or wrong type)\n"); + msg += " → thread_cluster.n1_xs: " + std::string(CHECK_MARK(has_n1_xs)) + (has_n1_xs ? "\n" : " (missing or wrong type)\n"); + + return msg; +} + +template +consteval auto detailed_diagnostic_SpecifiesDlBlockTransfer() -> std::string { + std::string msg; + constexpr bool has_transfer = requires { T::transfer; }; + msg += " → T::transfer member: " + std::string(CHECK_MARK(has_transfer)) + "\n"; + + if constexpr (!has_transfer) { + return msg; + } + + constexpr bool has_a = requires { T::transfer.a; }; + constexpr bool has_b = requires { T::transfer.b; }; + msg += " → T::transfer.a: " + std::string(CHECK_MARK(has_a)) + "\n"; + msg += " → T::transfer.b: " + std::string(CHECK_MARK(has_b)) + "\n"; + + if constexpr (has_a && requires { T::transfer.a.block_transfer; }) { + using ABT = decltype(T::transfer.a.block_transfer); + constexpr bool has_thread_slice = requires(ABT t) { { t.thread_slice_lengths } -> std::convertible_to>; }; + constexpr bool has_thread_cluster = requires(ABT t) { { t.thread_cluster_lengths } -> std::convertible_to>; }; + constexpr bool has_cluster_arrange = requires(ABT t) { { t.thread_cluster_arrange_order } -> std::convertible_to>; }; + constexpr bool has_src_access = requires(ABT t) { { t.src_access_order } -> 
std::convertible_to>; }; + constexpr bool has_src_vector = requires(ABT t) { { t.src_vector_tensor_lengths } -> std::convertible_to>; }; + constexpr bool has_src_contiguous = requires(ABT t) { { t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to>; }; + constexpr bool has_dst_vector = requires(ABT t) { { t.dst_vector_tensor_lengths } -> std::convertible_to>; }; + + msg += " → transfer.a.block_transfer.thread_slice_lengths: " + std::string(CHECK_MARK(has_thread_slice)) + (has_thread_slice ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.block_transfer.thread_cluster_lengths: " + std::string(CHECK_MARK(has_thread_cluster)) + (has_thread_cluster ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.block_transfer.thread_cluster_arrange_order: " + std::string(CHECK_MARK(has_cluster_arrange)) + (has_cluster_arrange ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.block_transfer.src_access_order: " + std::string(CHECK_MARK(has_src_access)) + (has_src_access ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.block_transfer.src_vector_tensor_lengths: " + std::string(CHECK_MARK(has_src_vector)) + (has_src_vector ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.block_transfer.src_vector_tensor_contiguous_dim_order: " + std::string(CHECK_MARK(has_src_contiguous)) + (has_src_contiguous ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.block_transfer.dst_vector_tensor_lengths: " + std::string(CHECK_MARK(has_dst_vector)) + (has_dst_vector ? 
"\n" : " (missing or wrong type)\n"); + } else if constexpr (has_a) { + msg += " → T::transfer.a.block_transfer: [✗] (missing)\n"; + } + + // Similar checks for transfer.b + if constexpr (has_b && requires { T::transfer.b.block_transfer; }) { + msg += " → T::transfer.b.block_transfer: [✓] (similar fields as transfer.a)\n"; + } else if constexpr (has_b) { + msg += " → T::transfer.b.block_transfer: [✗] (missing)\n"; + } + + return msg; +} + +template +consteval auto detailed_diagnostic_SpecifiesDlEpilogue() -> std::string { + std::string msg; + constexpr bool has_transfer = requires { T::transfer; }; + if constexpr (!has_transfer) { + return " → T::transfer member: [✗] (not found)\n"; + } + + constexpr bool has_c = requires { T::transfer.c; }; + msg += " → T::transfer.c: " + std::string(CHECK_MARK(has_c)) + "\n"; + + if constexpr (has_c && requires { T::transfer.c.epilogue; }) { + using E = decltype(T::transfer.c.epilogue); + constexpr bool has_src_dst_access = requires(E t) { { t.src_dst_access_order } -> std::convertible_to>; }; + constexpr bool has_src_dst_vector_dim = requires(E t) { { t.src_dst_vector_dim } -> std::convertible_to; }; + constexpr bool has_dst_scalar = requires(E t) { { t.dst_scalar_per_vector } -> std::convertible_to; }; + + msg += " → transfer.c.epilogue.src_dst_access_order: " + std::string(CHECK_MARK(has_src_dst_access)) + (has_src_dst_access ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.c.epilogue.src_dst_vector_dim: " + std::string(CHECK_MARK(has_src_dst_vector_dim)) + (has_src_dst_vector_dim ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.c.epilogue.dst_scalar_per_vector: " + std::string(CHECK_MARK(has_dst_scalar)) + (has_dst_scalar ? 
"\n" : " (missing or wrong type)\n"); + } else if constexpr (has_c) { + msg += " → T::transfer.c.epilogue: [✗] (missing)\n"; + } + + return msg; +} + +} // namespace ck::detail::diagnostics diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index f96fc2c86be..4f268b05e92 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -3,20 +3,11 @@ #pragma once -#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_diagnostics.hpp" namespace ck_tile::builder::factory { -#define CHECK_MARK(cond) (cond ? "[✓]" : "[✗]") - -// Macro to check a concept and generate both the boolean and the string representation -#define CHECK_CONCEPT(Type, Concept) \ - static constexpr bool c_##Concept = Concept; \ - static constexpr const char* s_##Concept = #Concept; - -// Helper to create diagnostic message line -#define DIAGNOSTIC_LINE(Concept) \ - " " + std::string(s_##Concept) + ": " + std::string(CHECK_MARK(c_##Concept)) + "\n" +using namespace ck_tile::builder::diagnostics; template struct FwdXdlV3Algorithm { @@ -46,7 +37,7 @@ struct FwdXdlV3Algorithm { return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10; } - static consteval const std::string message() { + static consteval auto message() -> std::string { return std::string("\n=== Forward XDL V3 Algorithm Diagnostic (closest match) ===\n" "Concepts for FwdXdlV3 Algorithm:\n") + DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + @@ -94,7 +85,7 @@ struct FwdXdlAlgorithm { return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10 && c11 && c12; } - static consteval const std::string message() { + static consteval auto message() -> std::string { return std::string("\n=== Forward XDL Algorithm Diagnostic (closest match) ===\n" "Concepts for FwdXdl Algorithm:\n") + 
DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + @@ -142,7 +133,7 @@ struct FwdWmmaAlgorithm { return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10 && c11; } - static consteval const std::string message() { + static consteval auto message() -> std::string { return std::string("\n=== Forward WMMA Algorithm Diagnostic (closest match) ===\n" "Concepts for FwdWmma Algorithm:\n") + DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + @@ -183,7 +174,7 @@ struct FwdDlAlgorithm { return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8; } - static consteval const std::string message() { + static consteval auto message() -> std::string { return std::string("\n=== Forward DL Algorithm Diagnostic (closest match) ===\n" "Concepts for FwdDl Algorithm:\n") + DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + @@ -217,7 +208,7 @@ struct TileAlgorithm { return c1 && c2 && c3 && c4 && c5 && c6; } - static consteval const std::string message() { + static consteval auto message() -> std::string { return std::string("\n=== CK Tile Algorithm Diagnostic (closest match) ===\n" "Concepts for CK Tile Conv Algorithm:\n") + DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + @@ -241,7 +232,7 @@ struct LargeTensorAlgorithm : public FwdXdlAlgorithm::is_valid() && c13; } - static consteval const std::string message() { + static consteval auto message() -> std::string { return FwdXdlAlgorithm::message() + DIAGNOSTIC_LINE(SpecifiesLargeTensorSupport); } @@ -273,7 +264,7 @@ struct BwdXdlAlgorithm { return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9; } - static consteval const std::string message() { + static consteval auto message() -> std::string { return std::string("\n=== Backward XDL Algorithm Diagnostic (closest match) ===\n" "Concepts for BwdXdl Algorithm:\n") + DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp index 366c2b27514..975212999b0 100644 --- 
a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp @@ -25,7 +25,7 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CSh .with_gemm_config(cku::BwdGemmParams_Xdl_4x4_per_wave) .with_transfer(cku::Transfer_4x64x1) .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT); - + using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; From 8d40e6d9fe8bc5222cc9f23360a84b1e74a4851b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Mon, 22 Dec 2025 08:39:01 -0500 Subject: [PATCH 08/81] Small improvements. --- .../include/ck_tile/builder/conv_algorithm_diagnostics.hpp | 6 ++++-- experimental/builder/test/impl/conv_algorithm_types.hpp | 3 +++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp index d97ee85abe4..bd6d8778d9f 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp @@ -198,7 +198,8 @@ consteval auto detailed_diagnostic_SpecifiesGridwiseFwdXdlGemm() -> std::string (has_ak1 ? "\n" : " (missing or wrong type)\n"); msg += " → T::bk1: " + std::string(CHECK_MARK(has_bk1)) + (has_bk1 ? "\n" : " (missing or wrong type)\n"); - msg += " → T::xdl_params member: " + std::string(CHECK_MARK(has_xdl_params)) + "\n"; + msg += " → T::xdl_params member: " + std::string(CHECK_MARK(has_xdl_params)) + + (has_xdl_params ? "\n" : " (missing or wrong type)\n"); if constexpr (has_xdl_params) { msg += detail::diagnose_xdl_params(); @@ -219,7 +220,8 @@ consteval auto detailed_diagnostic_SpecifiesGridwiseBwdXdlGemm() -> std::string (has_k0 ? "\n" : " (missing or wrong type)\n"); msg += " → T::k1: " + std::string(CHECK_MARK(has_k1)) + (has_k1 ? 
"\n" : " (missing or wrong type)\n"); - msg += " → T::xdl_params member: " + std::string(CHECK_MARK(has_xdl_params)) + "\n"; + msg += " → T::xdl_params member: " + std::string(CHECK_MARK(has_xdl_params)) + + (has_xdl_params ? "\n" : " (missing or wrong type)\n"); if constexpr (has_xdl_params) { msg += detail::diagnose_xdl_params(); diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 46babfa79a1..2fbd5d3fc7f 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -371,6 +371,9 @@ struct ConvAlgorithmTemplate : Components... { result.gridwise_gemm = gemm; } + else { + static_assert(false, "Unrecognized GemmConfig type"); + } return result; } From 9679d9b141148cc72280e29dc64c536a5a34628c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Mon, 22 Dec 2025 08:39:47 -0500 Subject: [PATCH 09/81] Improve missing member/wrong type compile-time errors. 
--- .../builder/conv_algorithm_diagnostics.hpp | 618 ++++++++++++------ 1 file changed, 432 insertions(+), 186 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp index bd6d8778d9f..64d181feff3 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp @@ -21,28 +21,85 @@ namespace ck_tile::builder::diagnostics { namespace detail { +// Helper to get type information +template +consteval auto get_type_info() -> const char* { + // Returns a descriptive string about the type + if constexpr (std::is_same_v) { + return " (type: size_t)"; + } else if constexpr (std::is_same_v) { + return " (type: int)"; + } else if constexpr (std::is_same_v) { + return " (type: bool)"; + } else if constexpr (std::is_same_v) { + return " (type: PipelineVersion)"; + } else if constexpr (std::is_same_v) { + return " (type: PipelineScheduler)"; + } else if constexpr (std::is_same_v) { + return " (type: ConvSpecialization)"; + } else if constexpr (std::is_same_v) { + return " (type: GemmSpecialization)"; + } else if constexpr (std::is_same_v) { + return " (type: TileConvSpecialization)"; + } else if constexpr (std::is_same_v) { + return " (type: ConvAlgorithmSpecialization)"; + } else if constexpr (std::is_same_v>) { + return " (type: std::array)"; + } else if constexpr (std::is_same_v>) { + return " (type: std::array)"; + } else if constexpr (std::is_same_v>) { + return " (type: std::array)"; + } else if constexpr (std::is_same_v>) { + return " (type: std::array)"; + } else { + return " (type: found but unknown)"; + } +} + // ThreadBlockDescriptor diagnostics template consteval auto diagnose_thread_block_descriptor() -> std::string { if constexpr (!requires { T::thread_block; }) { - return " → T::thread_block member: [✗] (not found)\n"; + return " → 
T::thread_block member: [✗] (missing member)\n"; } else { using TB = decltype(T::thread_block); std::string msg; - constexpr bool has_block_size = requires(TB t) { { t.block_size } -> std::convertible_to; }; - constexpr bool has_tile_m = requires(TB t) { { t.tile_size.m } -> std::convertible_to; }; - constexpr bool has_tile_n = requires(TB t) { { t.tile_size.n } -> std::convertible_to; }; - constexpr bool has_tile_k = requires(TB t) { { t.tile_size.k } -> std::convertible_to; }; + if constexpr (requires(TB t) { t.block_size; }) { + using BlockSizeType = decltype(std::declval().block_size); + constexpr bool convertible = std::convertible_to; + msg += " → thread_block.block_size: " + std::string(CHECK_MARK(convertible)) + + std::string(get_type_info()) + "\n"; + } else { + msg += " → thread_block.block_size: [✗] (missing member)\n"; + } + + if constexpr (requires(TB t) { t.tile_size.m; }) { + using TileMType = decltype(std::declval().tile_size.m); + constexpr bool convertible = std::convertible_to; + msg += " → thread_block.tile_size.m: " + std::string(CHECK_MARK(convertible)) + + std::string(get_type_info()) + "\n"; + } else { + msg += " → thread_block.tile_size.m: [✗] (missing member)\n"; + } - msg += " → thread_block.block_size: " + std::string(CHECK_MARK(has_block_size)) + - (has_block_size ? "\n" : " (missing or wrong type)\n"); - msg += " → thread_block.tile_size.m: " + std::string(CHECK_MARK(has_tile_m)) + - (has_tile_m ? "\n" : " (missing or wrong type)\n"); - msg += " → thread_block.tile_size.n: " + std::string(CHECK_MARK(has_tile_n)) + - (has_tile_n ? "\n" : " (missing or wrong type)\n"); - msg += " → thread_block.tile_size.k: " + std::string(CHECK_MARK(has_tile_k)) + - (has_tile_k ? 
"\n" : " (missing or wrong type)\n"); + if constexpr (requires(TB t) { t.tile_size.n; }) { + using TileNType = decltype(std::declval().tile_size.n); + constexpr bool convertible = std::convertible_to; + msg += " → thread_block.tile_size.n: " + std::string(CHECK_MARK(convertible)) + + std::string(get_type_info()) + "\n"; + } else { + msg += " → thread_block.tile_size.n: [✗] (missing member)\n"; + } + + if constexpr (requires(TB t) { t.tile_size.k; }) { + using TileKType = decltype(std::declval().tile_size.k); + constexpr bool convertible = std::convertible_to; + msg += " → thread_block.tile_size.k: " + std::string(CHECK_MARK(convertible)) + + std::string(get_type_info()) + "\n"; + } else { + msg += " → thread_block.tile_size.k: [✗] (missing member)\n"; + } return msg; } @@ -53,19 +110,41 @@ template consteval auto diagnose_xdl_params() -> std::string { std::string msg; - constexpr bool has_m_per_xdl = requires(XdlParams t) { { t.m_per_xdl } -> std::convertible_to; }; - constexpr bool has_n_per_xdl = requires(XdlParams t) { { t.n_per_xdl } -> std::convertible_to; }; - constexpr bool has_m_xdl_per_wave = requires(XdlParams t) { { t.m_xdl_per_wave } -> std::convertible_to; }; - constexpr bool has_n_xdl_per_wave = requires(XdlParams t) { { t.n_xdl_per_wave } -> std::convertible_to; }; + if constexpr (requires(XdlParams t) { t.m_per_xdl; }) { + using MPerXdlType = decltype(std::declval().m_per_xdl); + constexpr bool convertible = std::convertible_to; + msg += " → xdl_params.m_per_xdl: " + std::string(CHECK_MARK(convertible)) + + std::string(get_type_info()) + "\n"; + } else { + msg += " → xdl_params.m_per_xdl: [✗] (missing member)\n"; + } - msg += " → xdl_params.m_per_xdl: " + std::string(CHECK_MARK(has_m_per_xdl)) + - (has_m_per_xdl ? "\n" : " (missing or wrong type)\n"); - msg += " → xdl_params.n_per_xdl: " + std::string(CHECK_MARK(has_n_per_xdl)) + - (has_n_per_xdl ? 
"\n" : " (missing or wrong type)\n"); - msg += " → xdl_params.m_xdl_per_wave: " + std::string(CHECK_MARK(has_m_xdl_per_wave)) + - (has_m_xdl_per_wave ? "\n" : " (missing or wrong type)\n"); - msg += " → xdl_params.n_xdl_per_wave: " + std::string(CHECK_MARK(has_n_xdl_per_wave)) + - (has_n_xdl_per_wave ? "\n" : " (missing or wrong type)\n"); + if constexpr (requires(XdlParams t) { t.n_per_xdl; }) { + using NPerXdlType = decltype(std::declval().n_per_xdl); + constexpr bool convertible = std::convertible_to; + msg += " → xdl_params.n_per_xdl: " + std::string(CHECK_MARK(convertible)) + + std::string(get_type_info()) + "\n"; + } else { + msg += " → xdl_params.n_per_xdl: [✗] (missing member)\n"; + } + + if constexpr (requires(XdlParams t) { t.m_xdl_per_wave; }) { + using MXdlPerWaveType = decltype(std::declval().m_xdl_per_wave); + constexpr bool convertible = std::convertible_to; + msg += " → xdl_params.m_xdl_per_wave: " + std::string(CHECK_MARK(convertible)) + + std::string(get_type_info()) + "\n"; + } else { + msg += " → xdl_params.m_xdl_per_wave: [✗] (missing member)\n"; + } + + if constexpr (requires(XdlParams t) { t.n_xdl_per_wave; }) { + using NXdlPerWaveType = decltype(std::declval().n_xdl_per_wave); + constexpr bool convertible = std::convertible_to; + msg += " → xdl_params.n_xdl_per_wave: " + std::string(CHECK_MARK(convertible)) + + std::string(get_type_info()) + "\n"; + } else { + msg += " → xdl_params.n_xdl_per_wave: [✗] (missing member)\n"; + } return msg; } @@ -75,16 +154,32 @@ template consteval auto diagnose_block_transfer(const char* prefix) -> std::string { std::string msg; - constexpr bool has_k0 = requires(BT t) { { t.k0 } -> std::convertible_to; }; - constexpr bool has_m_n = requires(BT t) { { t.m_n } -> std::convertible_to; }; - constexpr bool has_k1 = requires(BT t) { { t.k1 } -> std::convertible_to; }; + if constexpr (requires(BT t) { t.k0; }) { + using K0Type = decltype(std::declval().k0); + constexpr bool convertible = std::convertible_to; + msg 
+= std::string(" → ") + prefix + ".k0: " + std::string(CHECK_MARK(convertible)) + + std::string(get_type_info()) + "\n"; + } else { + msg += std::string(" → ") + prefix + ".k0: [✗] (missing member)\n"; + } - msg += std::string(" → ") + prefix + ".k0: " + std::string(CHECK_MARK(has_k0)) + - (has_k0 ? "\n" : " (missing or wrong type)\n"); - msg += std::string(" → ") + prefix + ".m_n: " + std::string(CHECK_MARK(has_m_n)) + - (has_m_n ? "\n" : " (missing or wrong type)\n"); - msg += std::string(" → ") + prefix + ".k1: " + std::string(CHECK_MARK(has_k1)) + - (has_k1 ? "\n" : " (missing or wrong type)\n"); + if constexpr (requires(BT t) { t.m_n; }) { + using MNType = decltype(std::declval().m_n); + constexpr bool convertible = std::convertible_to; + msg += std::string(" → ") + prefix + ".m_n: " + std::string(CHECK_MARK(convertible)) + + std::string(get_type_info()) + "\n"; + } else { + msg += std::string(" → ") + prefix + ".m_n: [✗] (missing member)\n"; + } + + if constexpr (requires(BT t) { t.k1; }) { + using K1Type = decltype(std::declval().k1); + constexpr bool convertible = std::convertible_to; + msg += std::string(" → ") + prefix + ".k1: " + std::string(CHECK_MARK(convertible)) + + std::string(get_type_info()) + "\n"; + } else { + msg += std::string(" → ") + prefix + ".k1: [✗] (missing member)\n"; + } return msg; } @@ -94,22 +189,50 @@ template consteval auto diagnose_lds_transfer(const char* prefix) -> std::string { std::string msg; - constexpr bool has_src_vector_dim = requires(LT t) { { t.src_vector_dim } -> std::convertible_to; }; - constexpr bool has_src_scalar_per_vector = requires(LT t) { { t.src_scalar_per_vector } -> std::convertible_to; }; - constexpr bool has_lds_dst_scalar_per_vector = requires(LT t) { { t.lds_dst_scalar_per_vector } -> std::convertible_to; }; - constexpr bool has_is_direct_load = requires(LT t) { { t.is_direct_load } -> std::convertible_to; }; - constexpr bool has_lds_padding = requires(LT t) { { t.lds_padding } -> std::convertible_to; 
}; - - msg += std::string(" → ") + prefix + ".src_vector_dim: " + std::string(CHECK_MARK(has_src_vector_dim)) + - (has_src_vector_dim ? "\n" : " (missing or wrong type)\n"); - msg += std::string(" → ") + prefix + ".src_scalar_per_vector: " + std::string(CHECK_MARK(has_src_scalar_per_vector)) + - (has_src_scalar_per_vector ? "\n" : " (missing or wrong type)\n"); - msg += std::string(" → ") + prefix + ".lds_dst_scalar_per_vector: " + std::string(CHECK_MARK(has_lds_dst_scalar_per_vector)) + - (has_lds_dst_scalar_per_vector ? "\n" : " (missing or wrong type)\n"); - msg += std::string(" → ") + prefix + ".is_direct_load: " + std::string(CHECK_MARK(has_is_direct_load)) + - (has_is_direct_load ? "\n" : " (missing or wrong type)\n"); - msg += std::string(" → ") + prefix + ".lds_padding: " + std::string(CHECK_MARK(has_lds_padding)) + - (has_lds_padding ? "\n" : " (missing or wrong type)\n"); + if constexpr (requires(LT t) { t.src_vector_dim; }) { + using SrcVectorDimType = decltype(std::declval().src_vector_dim); + constexpr bool convertible = std::convertible_to; + msg += std::string(" → ") + prefix + ".src_vector_dim: " + std::string(CHECK_MARK(convertible)) + + std::string(get_type_info()) + "\n"; + } else { + msg += std::string(" → ") + prefix + ".src_vector_dim: [✗] (missing member)\n"; + } + + if constexpr (requires(LT t) { t.src_scalar_per_vector; }) { + using SrcScalarType = decltype(std::declval().src_scalar_per_vector); + constexpr bool convertible = std::convertible_to; + msg += std::string(" → ") + prefix + ".src_scalar_per_vector: " + std::string(CHECK_MARK(convertible)) + + std::string(get_type_info()) + "\n"; + } else { + msg += std::string(" → ") + prefix + ".src_scalar_per_vector: [✗] (missing member)\n"; + } + + if constexpr (requires(LT t) { t.lds_dst_scalar_per_vector; }) { + using LdsDstScalarType = decltype(std::declval().lds_dst_scalar_per_vector); + constexpr bool convertible = std::convertible_to; + msg += std::string(" → ") + prefix + 
".lds_dst_scalar_per_vector: " + std::string(CHECK_MARK(convertible)) + + std::string(get_type_info()) + "\n"; + } else { + msg += std::string(" → ") + prefix + ".lds_dst_scalar_per_vector: [✗] (missing member)\n"; + } + + if constexpr (requires(LT t) { t.is_direct_load; }) { + using IsDirectLoadType = decltype(std::declval().is_direct_load); + constexpr bool convertible = std::convertible_to; + msg += std::string(" → ") + prefix + ".is_direct_load: " + std::string(CHECK_MARK(convertible)) + + std::string(get_type_info()) + "\n"; + } else { + msg += std::string(" → ") + prefix + ".is_direct_load: [✗] (missing member)\n"; + } + + if constexpr (requires(LT t) { t.lds_padding; }) { + using LdsPaddingType = decltype(std::declval().lds_padding); + constexpr bool convertible = std::convertible_to; + msg += std::string(" → ") + prefix + ".lds_padding: " + std::string(CHECK_MARK(convertible)) + + std::string(get_type_info()) + "\n"; + } else { + msg += std::string(" → ") + prefix + ".lds_padding: [✗] (missing member)\n"; + } return msg; } @@ -119,19 +242,41 @@ template consteval auto diagnose_thread_cluster(const char* prefix) -> std::string { std::string msg; - constexpr bool has_m_block = requires(TC t) { { t.m_block } -> std::convertible_to; }; - constexpr bool has_m_wave_per_xdl = requires(TC t) { { t.m_wave_per_xdl } -> std::convertible_to; }; - constexpr bool has_n_block = requires(TC t) { { t.n_block } -> std::convertible_to; }; - constexpr bool has_n_wave_per_xdl = requires(TC t) { { t.n_wave_per_xdl } -> std::convertible_to; }; + if constexpr (requires(TC t) { t.m_block; }) { + using MBlockType = decltype(std::declval().m_block); + constexpr bool convertible = std::convertible_to; + msg += std::string(" → ") + prefix + ".m_block: " + std::string(CHECK_MARK(convertible)) + + std::string(get_type_info()) + "\n"; + } else { + msg += std::string(" → ") + prefix + ".m_block: [✗] (missing member)\n"; + } + + if constexpr (requires(TC t) { t.m_wave_per_xdl; }) { + using 
MWaveType = decltype(std::declval().m_wave_per_xdl); + constexpr bool convertible = std::convertible_to; + msg += std::string(" → ") + prefix + ".m_wave_per_xdl: " + std::string(CHECK_MARK(convertible)) + + std::string(get_type_info()) + "\n"; + } else { + msg += std::string(" → ") + prefix + ".m_wave_per_xdl: [✗] (missing member)\n"; + } + + if constexpr (requires(TC t) { t.n_block; }) { + using NBlockType = decltype(std::declval().n_block); + constexpr bool convertible = std::convertible_to; + msg += std::string(" → ") + prefix + ".n_block: " + std::string(CHECK_MARK(convertible)) + + std::string(get_type_info()) + "\n"; + } else { + msg += std::string(" → ") + prefix + ".n_block: [✗] (missing member)\n"; + } - msg += std::string(" → ") + prefix + ".m_block: " + std::string(CHECK_MARK(has_m_block)) + - (has_m_block ? "\n" : " (missing or wrong type)\n"); - msg += std::string(" → ") + prefix + ".m_wave_per_xdl: " + std::string(CHECK_MARK(has_m_wave_per_xdl)) + - (has_m_wave_per_xdl ? "\n" : " (missing or wrong type)\n"); - msg += std::string(" → ") + prefix + ".n_block: " + std::string(CHECK_MARK(has_n_block)) + - (has_n_block ? "\n" : " (missing or wrong type)\n"); - msg += std::string(" → ") + prefix + ".n_wave_per_xdl: " + std::string(CHECK_MARK(has_n_wave_per_xdl)) + - (has_n_wave_per_xdl ? 
"\n" : " (missing or wrong type)\n"); + if constexpr (requires(TC t) { t.n_wave_per_xdl; }) { + using NWaveType = decltype(std::declval().n_wave_per_xdl); + constexpr bool convertible = std::convertible_to; + msg += std::string(" → ") + prefix + ".n_wave_per_xdl: " + std::string(CHECK_MARK(convertible)) + + std::string(get_type_info()) + "\n"; + } else { + msg += std::string(" → ") + prefix + ".n_wave_per_xdl: [✗] (missing member)\n"; + } return msg; } @@ -141,10 +286,14 @@ template consteval auto diagnose_access_order(const char* prefix) -> std::string { std::string msg; - constexpr bool has_order = requires(AO t) { { t.order } -> std::convertible_to>; }; - - msg += std::string(" → ") + prefix + ".order: " + std::string(CHECK_MARK(has_order)) + - (has_order ? "\n" : " (missing or wrong type)\n"); + if constexpr (requires(AO t) { t.order; }) { + using OrderType = decltype(std::declval().order); + constexpr bool convertible = std::convertible_to>; + msg += std::string(" → ") + prefix + ".order: " + std::string(CHECK_MARK(convertible)) + + std::string(get_type_info()) + "\n"; + } else { + msg += std::string(" → ") + prefix + ".order: [✗] (missing member)\n"; + } return msg; } @@ -154,16 +303,32 @@ template consteval auto diagnose_epilogue(const char* prefix) -> std::string { std::string msg; - constexpr bool has_m_xdl = requires(E t) { { t.m_xdl_per_wave_per_shuffle } -> std::convertible_to; }; - constexpr bool has_n_per_wave = requires(E t) { { t.n_per_wave_per_shuffle } -> std::convertible_to; }; - constexpr bool has_scalar_per_vector = requires(E t) { { t.scalar_per_vector } -> std::convertible_to; }; + if constexpr (requires(E t) { t.m_xdl_per_wave_per_shuffle; }) { + using MXdlType = decltype(std::declval().m_xdl_per_wave_per_shuffle); + constexpr bool convertible = std::convertible_to; + msg += std::string(" → ") + prefix + ".m_xdl_per_wave_per_shuffle: " + std::string(CHECK_MARK(convertible)) + + std::string(get_type_info()) + "\n"; + } else { + msg += 
std::string(" → ") + prefix + ".m_xdl_per_wave_per_shuffle: [✗] (missing member)\n"; + } - msg += std::string(" → ") + prefix + ".m_xdl_per_wave_per_shuffle: " + std::string(CHECK_MARK(has_m_xdl)) + - (has_m_xdl ? "\n" : " (missing or wrong type)\n"); - msg += std::string(" → ") + prefix + ".n_per_wave_per_shuffle: " + std::string(CHECK_MARK(has_n_per_wave)) + - (has_n_per_wave ? "\n" : " (missing or wrong type)\n"); - msg += std::string(" → ") + prefix + ".scalar_per_vector: " + std::string(CHECK_MARK(has_scalar_per_vector)) + - (has_scalar_per_vector ? "\n" : " (missing or wrong type)\n"); + if constexpr (requires(E t) { t.n_per_wave_per_shuffle; }) { + using NPerWaveType = decltype(std::declval().n_per_wave_per_shuffle); + constexpr bool convertible = std::convertible_to; + msg += std::string(" → ") + prefix + ".n_per_wave_per_shuffle: " + std::string(CHECK_MARK(convertible)) + + std::string(get_type_info()) + "\n"; + } else { + msg += std::string(" → ") + prefix + ".n_per_wave_per_shuffle: [✗] (missing member)\n"; + } + + if constexpr (requires(E t) { t.scalar_per_vector; }) { + using ScalarType = decltype(std::declval().scalar_per_vector); + constexpr bool convertible = std::convertible_to; + msg += std::string(" → ") + prefix + ".scalar_per_vector: " + std::string(CHECK_MARK(convertible)) + + std::string(get_type_info()) + "\n"; + } else { + msg += std::string(" → ") + prefix + ".scalar_per_vector: [✗] (missing member)\n"; + } return msg; } @@ -190,19 +355,29 @@ template consteval auto detailed_diagnostic_SpecifiesGridwiseFwdXdlGemm() -> std::string { std::string msg; - constexpr bool has_ak1 = requires { { T::ak1 } -> std::convertible_to; }; - constexpr bool has_bk1 = requires { { T::bk1 } -> std::convertible_to; }; - constexpr bool has_xdl_params = requires { T::xdl_params; }; + if constexpr (requires { T::ak1; }) { + using AK1Type = decltype(T::ak1); + constexpr bool convertible = std::convertible_to; + msg += " → T::ak1: " + 
std::string(CHECK_MARK(convertible)) + + std::string(detail::get_type_info()) + "\n"; + } else { + msg += " → T::ak1: [✗] (missing member)\n"; + } - msg += " → T::ak1: " + std::string(CHECK_MARK(has_ak1)) + - (has_ak1 ? "\n" : " (missing or wrong type)\n"); - msg += " → T::bk1: " + std::string(CHECK_MARK(has_bk1)) + - (has_bk1 ? "\n" : " (missing or wrong type)\n"); - msg += " → T::xdl_params member: " + std::string(CHECK_MARK(has_xdl_params)) + - (has_xdl_params ? "\n" : " (missing or wrong type)\n"); + if constexpr (requires { T::bk1; }) { + using BK1Type = decltype(T::bk1); + constexpr bool convertible = std::convertible_to; + msg += " → T::bk1: " + std::string(CHECK_MARK(convertible)) + + std::string(detail::get_type_info()) + "\n"; + } else { + msg += " → T::bk1: [✗] (missing member)\n"; + } - if constexpr (has_xdl_params) { + if constexpr (requires { T::xdl_params; }) { + msg += " → T::xdl_params member: [✓]\n"; msg += detail::diagnose_xdl_params(); + } else { + msg += " → T::xdl_params: [✗] (missing member)\n"; } return msg; @@ -212,19 +387,29 @@ template consteval auto detailed_diagnostic_SpecifiesGridwiseBwdXdlGemm() -> std::string { std::string msg; - constexpr bool has_k0 = requires { { T::k0_per_block } -> std::convertible_to; }; - constexpr bool has_k1 = requires { { T::k1 } -> std::convertible_to; }; - constexpr bool has_xdl_params = requires { T::xdl_params; }; + if constexpr (requires { T::k0_per_block; }) { + using K0Type = decltype(T::k0_per_block); + constexpr bool convertible = std::convertible_to; + msg += " → T::k0_per_block: " + std::string(CHECK_MARK(convertible)) + + std::string(detail::get_type_info()) + "\n"; + } else { + msg += " → T::k0_per_block: [✗] (missing member)\n"; + } - msg += " → T::k0_per_block: " + std::string(CHECK_MARK(has_k0)) + - (has_k0 ? "\n" : " (missing or wrong type)\n"); - msg += " → T::k1: " + std::string(CHECK_MARK(has_k1)) + - (has_k1 ? 
"\n" : " (missing or wrong type)\n"); - msg += " → T::xdl_params member: " + std::string(CHECK_MARK(has_xdl_params)) + - (has_xdl_params ? "\n" : " (missing or wrong type)\n"); + if constexpr (requires { T::k1; }) { + using K1Type = decltype(T::k1); + constexpr bool convertible = std::convertible_to; + msg += " → T::k1: " + std::string(CHECK_MARK(convertible)) + + std::string(detail::get_type_info()) + "\n"; + } else { + msg += " → T::k1: [✗] (missing member)\n"; + } - if constexpr (has_xdl_params) { + if constexpr (requires { T::xdl_params; }) { + msg += " → T::xdl_params member: [✓]\n"; msg += detail::diagnose_xdl_params(); + } else { + msg += " → T::xdl_params: [✗] (missing member)\n"; } return msg; @@ -270,49 +455,13 @@ consteval auto detailed_diagnostic_SpecifiesBlockTransfer() -> std::string { return msg; } -template -consteval auto detailed_diagnostic_SpecifiesLdsTransfer() -> std::string { - std::string msg; - - constexpr bool has_transfer = requires { T::transfer; }; - msg += " → T::transfer member: " + std::string(CHECK_MARK(has_transfer)) + "\n"; - - if constexpr (!has_transfer) { - return msg; - } - - constexpr bool has_a = requires { T::transfer.a; }; - constexpr bool has_b = requires { T::transfer.b; }; - constexpr bool has_c = requires { T::transfer.c; }; - - if constexpr (has_a && requires { T::transfer.a.lds_transfer; }) { - msg += detail::diagnose_lds_transfer("transfer.a.lds_transfer"); - } else if constexpr (has_a) { - msg += " → T::transfer.a.lds_transfer: [✗] (missing)\n"; - } - - if constexpr (has_b && requires { T::transfer.b.lds_transfer; }) { - msg += detail::diagnose_lds_transfer("transfer.b.lds_transfer"); - } else if constexpr (has_b) { - msg += " → T::transfer.b.lds_transfer: [✗] (missing)\n"; - } - - if constexpr (has_c && requires { T::transfer.c.epilogue; }) { - msg += detail::diagnose_epilogue("transfer.c.epilogue"); - } else if constexpr (has_c) { - msg += " → T::transfer.c.epilogue: [✗] (missing)\n"; - } - - return msg; -} - 
template consteval auto detailed_diagnostic_SpecifiesThreadClusterAccessOrder() -> std::string { std::string msg; constexpr bool has_transfer = requires { T::transfer; }; if constexpr (!has_transfer) { - return " → T::transfer member: [✗] (not found)\n"; + return " → T::transfer member: [✗] (missing member)\n"; } constexpr bool has_a = requires { T::transfer.a; }; @@ -321,13 +470,13 @@ consteval auto detailed_diagnostic_SpecifiesThreadClusterAccessOrder() -> std::s if constexpr (has_a && requires { T::transfer.a.block_transfer_access_order; }) { msg += detail::diagnose_access_order("transfer.a.block_transfer_access_order"); } else if constexpr (has_a) { - msg += " → T::transfer.a.block_transfer_access_order: [✗] (missing)\n"; + msg += " → T::transfer.a.block_transfer_access_order: [✗] (missing member)\n"; } if constexpr (has_b && requires { T::transfer.b.block_transfer_access_order; }) { msg += detail::diagnose_access_order("transfer.b.block_transfer_access_order"); } else if constexpr (has_b) { - msg += " → T::transfer.b.block_transfer_access_order: [✗] (missing)\n"; + msg += " → T::transfer.b.block_transfer_access_order: [✗] (missing member)\n"; } return msg; @@ -339,7 +488,7 @@ consteval auto detailed_diagnostic_SpecifiesSourceAccessOrder() -> std::string { constexpr bool has_transfer = requires { T::transfer; }; if constexpr (!has_transfer) { - return " → T::transfer member: [✗] (not found)\n"; + return " → T::transfer member: [✗] (missing member)\n"; } constexpr bool has_a = requires { T::transfer.a; }; @@ -348,13 +497,13 @@ consteval auto detailed_diagnostic_SpecifiesSourceAccessOrder() -> std::string { if constexpr (has_a && requires { T::transfer.a.src_access_order; }) { msg += detail::diagnose_access_order("transfer.a.src_access_order"); } else if constexpr (has_a) { - msg += " → T::transfer.a.src_access_order: [✗] (missing)\n"; + msg += " → T::transfer.a.src_access_order: [✗] (missing member)\n"; } if constexpr (has_b && requires { 
T::transfer.b.src_access_order; }) { msg += detail::diagnose_access_order("transfer.b.src_access_order"); } else if constexpr (has_b) { - msg += " → T::transfer.b.src_access_order: [✗] (missing)\n"; + msg += " → T::transfer.b.src_access_order: [✗] (missing member)\n"; } return msg; @@ -364,76 +513,120 @@ template consteval auto detailed_diagnostic_SpecifiesBlockGemm() -> std::string { std::string msg; - constexpr bool has_block_gemm = requires { T::block_gemm; }; - msg += " → T::block_gemm member: " + std::string(CHECK_MARK(has_block_gemm)) + "\n"; - - if constexpr (!has_block_gemm) { - return msg; + if constexpr (!requires { T::block_gemm; }) { + return " → T::block_gemm: [✗] (missing member)\n"; } - constexpr bool has_pipeline = requires { { T::block_gemm.pipeline_version } -> std::convertible_to; }; - constexpr bool has_scheduler = requires { { T::block_gemm.scheduler } -> std::convertible_to; }; + msg += " → T::block_gemm member: [✓]\n"; - msg += " → block_gemm.pipeline_version: " + std::string(CHECK_MARK(has_pipeline)) + - (has_pipeline ? "\n" : " (missing or wrong type)\n"); - msg += " → block_gemm.scheduler: " + std::string(CHECK_MARK(has_scheduler)) + - (has_scheduler ? 
"\n" : " (missing or wrong type)\n"); + if constexpr (requires { T::block_gemm.pipeline_version; }) { + using PipelineType = decltype(T::block_gemm.pipeline_version); + constexpr bool convertible = std::convertible_to; + msg += " → block_gemm.pipeline_version: " + std::string(CHECK_MARK(convertible)) + + std::string(detail::get_type_info()) + "\n"; + } else { + msg += " → block_gemm.pipeline_version: [✗] (missing member)\n"; + } + + if constexpr (requires { T::block_gemm.scheduler; }) { + using SchedulerType = decltype(T::block_gemm.scheduler); + constexpr bool convertible = std::convertible_to; + msg += " → block_gemm.scheduler: " + std::string(CHECK_MARK(convertible)) + + std::string(detail::get_type_info()) + "\n"; + } else { + msg += " → block_gemm.scheduler: [✗] (missing member)\n"; + } return msg; } template consteval auto detailed_diagnostic_SpecifiesFwdConvSpecialization() -> std::string { - constexpr bool has_member = requires { { T::fwd_specialization } -> std::convertible_to; }; - return " → T::fwd_specialization: " + std::string(CHECK_MARK(has_member)) + - (has_member ? "\n" : " (missing or wrong type)\n"); + if constexpr (requires { T::fwd_specialization; }) { + using FwdSpecType = decltype(T::fwd_specialization); + constexpr bool convertible = std::convertible_to; + return " → T::fwd_specialization: " + std::string(CHECK_MARK(convertible)) + + std::string(detail::get_type_info()) + "\n"; + } else { + return " → T::fwd_specialization: [✗] (missing member)\n"; + } } template consteval auto detailed_diagnostic_SpecifiesBwdWeightConvSpecialization() -> std::string { - constexpr bool has_member = requires { { T::bwd_weight_specialization } -> std::convertible_to; }; - return " → T::bwd_weight_specialization: " + std::string(CHECK_MARK(has_member)) + - (has_member ? 
"\n" : " (missing or wrong type)\n"); + if constexpr (requires { T::bwd_weight_specialization; }) { + using BwdSpecType = decltype(T::bwd_weight_specialization); + constexpr bool convertible = std::convertible_to; + return " → T::bwd_weight_specialization: " + std::string(CHECK_MARK(convertible)) + + std::string(detail::get_type_info()) + "\n"; + } else { + return " → T::bwd_weight_specialization: [✗] (missing member)\n"; + } } template consteval auto detailed_diagnostic_SpecifiesGemmSpecialization() -> std::string { - constexpr bool has_member = requires { { T::gemm_specialization } -> std::convertible_to; }; - return " → T::gemm_specialization: " + std::string(CHECK_MARK(has_member)) + - (has_member ? "\n" : " (missing or wrong type)\n"); + if constexpr (requires { T::gemm_specialization; }) { + using GemmSpecType = decltype(T::gemm_specialization); + constexpr bool convertible = std::convertible_to; + return " → T::gemm_specialization: " + std::string(CHECK_MARK(convertible)) + + std::string(detail::get_type_info()) + "\n"; + } else { + return " → T::gemm_specialization: [✗] (missing member)\n"; + } } template consteval auto detailed_diagnostic_SpecifiesNumPrefetchStages() -> std::string { - constexpr bool has_member = requires { { T::num_gemm_k_prefetch_stages } -> std::convertible_to; }; - return " → T::num_gemm_k_prefetch_stages: " + std::string(CHECK_MARK(has_member)) + - (has_member ? 
"\n" : " (missing or wrong type)\n"); + if constexpr (requires { T::num_gemm_k_prefetch_stages; }) { + using NumPrefetchType = decltype(T::num_gemm_k_prefetch_stages); + constexpr bool convertible = std::convertible_to; + return " → T::num_gemm_k_prefetch_stages: " + std::string(CHECK_MARK(convertible)) + + std::string(detail::get_type_info()) + "\n"; + } else { + return " → T::num_gemm_k_prefetch_stages: [✗] (missing member)\n"; + } } template consteval auto detailed_diagnostic_SpecifiesNumGroupsToMerge() -> std::string { - constexpr bool has_member = requires { { T::num_groups_to_merge } -> std::convertible_to; }; - return " → T::num_groups_to_merge: " + std::string(CHECK_MARK(has_member)) + - (has_member ? "\n" : " (missing or wrong type)\n"); + if constexpr (requires { T::num_groups_to_merge; }) { + using NumGroupsType = decltype(T::num_groups_to_merge); + constexpr bool convertible = std::convertible_to; + return " → T::num_groups_to_merge: " + std::string(CHECK_MARK(convertible)) + + std::string(detail::get_type_info()) + "\n"; + } else { + return " → T::num_groups_to_merge: [✗] (missing member)\n"; + } } template consteval auto detailed_diagnostic_SpecifiesLoopScheduler() -> std::string { - constexpr bool has_member = requires { { T::loop_scheduler } -> std::convertible_to; }; - return " → T::loop_scheduler: " + std::string(CHECK_MARK(has_member)) + - (has_member ? 
"\n" : " (missing or wrong type)\n"); + if constexpr (requires { T::loop_scheduler; }) { + using LoopSchedulerType = decltype(T::loop_scheduler); + constexpr bool convertible = std::convertible_to; + return " → T::loop_scheduler: " + std::string(CHECK_MARK(convertible)) + + std::string(detail::get_type_info()) + "\n"; + } else { + return " → T::loop_scheduler: [✗] (missing member)\n"; + } } template consteval auto detailed_diagnostic_SpecifiesLargeTensorSupport() -> std::string { std::string msg; - constexpr bool has_specialization = requires { { T::specialization } -> std::convertible_to; }; - msg += " → T::specialization: " + std::string(CHECK_MARK(has_specialization)) + - (has_specialization ? "\n" : " (missing or wrong type)\n"); - - if constexpr (has_specialization) { - constexpr bool is_large_tensor = (T::specialization == ConvAlgorithmSpecialization::LARGE_TENSOR); - msg += " → specialization == LARGE_TENSOR: " + std::string(CHECK_MARK(is_large_tensor)) + "\n"; + if constexpr (requires { T::specialization; }) { + using SpecType = decltype(T::specialization); + constexpr bool convertible = std::convertible_to; + msg += " → T::specialization: " + std::string(CHECK_MARK(convertible)) + + std::string(detail::get_type_info()) + "\n"; + + if constexpr (convertible) { + constexpr bool is_large_tensor = (T::specialization == ConvAlgorithmSpecialization::LARGE_TENSOR); + msg += " → specialization == LARGE_TENSOR: " + std::string(CHECK_MARK(is_large_tensor)) + "\n"; + } + } else { + msg += " → T::specialization: [✗] (missing member)\n"; } return msg; @@ -442,13 +635,24 @@ consteval auto detailed_diagnostic_SpecifiesLargeTensorSupport() -> std::string template consteval auto detailed_diagnostic_SpecifiesTransposeTransfer() -> std::string { std::string msg; - constexpr bool has_src = requires { { T::max_transpose_transfer_src_scalar_per_vector } -> std::convertible_to; }; - constexpr bool has_dst = requires { { T::max_transpose_transfer_dst_scalar_per_vector } -> 
std::convertible_to; }; - msg += " → T::max_transpose_transfer_src_scalar_per_vector: " + std::string(CHECK_MARK(has_src)) + - (has_src ? "\n" : " (missing or wrong type)\n"); - msg += " → T::max_transpose_transfer_dst_scalar_per_vector: " + std::string(CHECK_MARK(has_dst)) + - (has_dst ? "\n" : " (missing or wrong type)\n"); + if constexpr (requires { T::max_transpose_transfer_src_scalar_per_vector; }) { + using SrcType = decltype(T::max_transpose_transfer_src_scalar_per_vector); + constexpr bool convertible = std::convertible_to; + msg += " → T::max_transpose_transfer_src_scalar_per_vector: " + std::string(CHECK_MARK(convertible)) + + std::string(detail::get_type_info()) + "\n"; + } else { + msg += " → T::max_transpose_transfer_src_scalar_per_vector: [✗] (missing member)\n"; + } + + if constexpr (requires { T::max_transpose_transfer_dst_scalar_per_vector; }) { + using DstType = decltype(T::max_transpose_transfer_dst_scalar_per_vector); + constexpr bool convertible = std::convertible_to; + msg += " → T::max_transpose_transfer_dst_scalar_per_vector: " + std::string(CHECK_MARK(convertible)) + + std::string(detail::get_type_info()) + "\n"; + } else { + msg += " → T::max_transpose_transfer_dst_scalar_per_vector: [✗] (missing member)\n"; + } return msg; } @@ -523,12 +727,6 @@ consteval auto detailed_diagnostic_SpecifiesTileTransfer() -> std::string { return msg; } -template -consteval auto detailed_diagnostic_SpecifiesTileConvSpecialization() -> std::string { - constexpr bool has_member = requires { { T::specialization } -> std::convertible_to; }; - return " → T::specialization: " + std::string(CHECK_MARK(has_member)) + (has_member ? 
"\n" : " (missing or wrong type)\n"); -} - template consteval auto detailed_diagnostic_SpecifiesTileBlockGemm() -> std::string { std::string msg; @@ -707,4 +905,52 @@ consteval auto detailed_diagnostic_SpecifiesDlEpilogue() -> std::string { return msg; } -} // namespace ck::detail::diagnostics +template +consteval auto detailed_diagnostic_SpecifiesTileConvSpecialization() -> std::string { + if constexpr (requires { T::specialization; }) { + using SpecType = decltype(T::specialization); + constexpr bool convertible = std::convertible_to; + return " → T::specialization: " + std::string(CHECK_MARK(convertible)) + + std::string(detail::get_type_info()) + "\n"; + } else { + return " → T::specialization: [✗] (missing member)\n"; + } +} + +template +consteval auto detailed_diagnostic_SpecifiesLdsTransfer() -> std::string { + std::string msg; + + constexpr bool has_transfer = requires { T::transfer; }; + msg += " → T::transfer member: " + std::string(CHECK_MARK(has_transfer)) + "\n"; + + if constexpr (!has_transfer) { + return msg; + } + + constexpr bool has_a = requires { T::transfer.a; }; + constexpr bool has_b = requires { T::transfer.b; }; + constexpr bool has_c = requires { T::transfer.c; }; + + if constexpr (has_a && requires { T::transfer.a.lds_transfer; }) { + msg += detail::diagnose_lds_transfer("transfer.a.lds_transfer"); + } else if constexpr (has_a) { + msg += " → T::transfer.a.lds_transfer: [✗] (missing member)\n"; + } + + if constexpr (has_b && requires { T::transfer.b.lds_transfer; }) { + msg += detail::diagnose_lds_transfer("transfer.b.lds_transfer"); + } else if constexpr (has_b) { + msg += " → T::transfer.b.lds_transfer: [✗] (missing member)\n"; + } + + if constexpr (has_c && requires { T::transfer.c.epilogue; }) { + msg += detail::diagnose_epilogue("transfer.c.epilogue"); + } else if constexpr (has_c) { + msg += " → T::transfer.c.epilogue: [✗] (missing member)\n"; + } + + return msg; +} + +} // namespace ck_tile::builder::diagnostics From 
5ee99d83d5e5157ab389c44c8e6f1bf75861dfa1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Mon, 22 Dec 2025 08:59:46 -0500 Subject: [PATCH 10/81] Improve compile time diagnostics. --- .../builder/conv_algorithm_concepts.hpp | 18 +- .../builder/conv_algorithm_diagnostics.hpp | 156 ++++++++++-------- .../test/impl/conv_algorithm_types.hpp | 4 +- 3 files changed, 102 insertions(+), 76 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index ddd8a09ec77..2a3d3cd75b8 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -157,7 +157,7 @@ concept SpecifiesTileThreadBlock = requires { // Concept to check if a struct specifies gridwise XDL GEMM info. template -concept SpecifiesGridwiseFwdXdlGemm = requires { +concept GridwiseFwdXdlGemmDescriptor = requires { { T::ak1 } -> std::convertible_to; { T::bk1 } -> std::convertible_to; { T::xdl_params } -> GridwiseXdlGemmDescriptor; @@ -165,16 +165,28 @@ concept SpecifiesGridwiseFwdXdlGemm = requires { // Concept to check if a struct specifies gridwise XDL GEMM info. template -concept SpecifiesGridwiseBwdXdlGemm = requires { +concept GridwiseBwdXdlGemmDescriptor = requires { { T::k0_per_block } -> std::convertible_to; { T::k1 } -> std::convertible_to; { T::xdl_params } -> GridwiseXdlGemmDescriptor; }; +// Concept to check if a struct specifies gridwise XDL GEMM info. +template +concept SpecifiesGridwiseFwdXdlGemm = requires { + { T::gridwise_gemm } -> GridwiseFwdXdlGemmDescriptor; +}; + +// Concept to check if a struct specifies gridwise XDL GEMM info. +template +concept SpecifiesGridwiseBwdXdlGemm = requires { + { T::gridwise_gemm } -> GridwiseFwdXdlGemmDescriptor; +}; + // Concept to check if a struct specifies gridwise WMMA GEMM info. 
template concept SpecifiesGridwiseWmmaGemm = requires { - { T::gridwise_gemm } -> GridwiseWmmaGemmDescriptor; + { T::gridwise_gemm } -> GridwiseBwdXdlGemmDescriptor; }; // Concept to check if a struct specifies convolution input and output block transfer info. diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp index 64d181feff3..1c4cafab8cb 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp @@ -113,37 +113,37 @@ consteval auto diagnose_xdl_params() -> std::string { if constexpr (requires(XdlParams t) { t.m_per_xdl; }) { using MPerXdlType = decltype(std::declval().m_per_xdl); constexpr bool convertible = std::convertible_to; - msg += " → xdl_params.m_per_xdl: " + std::string(CHECK_MARK(convertible)) + + msg += " → xdl_params.m_per_xdl: " + std::string(CHECK_MARK(convertible)) + std::string(get_type_info()) + "\n"; } else { - msg += " → xdl_params.m_per_xdl: [✗] (missing member)\n"; + msg += " → xdl_params.m_per_xdl: [✗] (missing member)\n"; } if constexpr (requires(XdlParams t) { t.n_per_xdl; }) { using NPerXdlType = decltype(std::declval().n_per_xdl); constexpr bool convertible = std::convertible_to; - msg += " → xdl_params.n_per_xdl: " + std::string(CHECK_MARK(convertible)) + + msg += " → xdl_params.n_per_xdl: " + std::string(CHECK_MARK(convertible)) + std::string(get_type_info()) + "\n"; } else { - msg += " → xdl_params.n_per_xdl: [✗] (missing member)\n"; + msg += " → xdl_params.n_per_xdl: [✗] (missing member)\n"; } if constexpr (requires(XdlParams t) { t.m_xdl_per_wave; }) { using MXdlPerWaveType = decltype(std::declval().m_xdl_per_wave); constexpr bool convertible = std::convertible_to; - msg += " → xdl_params.m_xdl_per_wave: " + std::string(CHECK_MARK(convertible)) + + msg += " → xdl_params.m_xdl_per_wave: " + 
std::string(CHECK_MARK(convertible)) + std::string(get_type_info()) + "\n"; } else { - msg += " → xdl_params.m_xdl_per_wave: [✗] (missing member)\n"; + msg += " → xdl_params.m_xdl_per_wave: [✗] (missing member)\n"; } if constexpr (requires(XdlParams t) { t.n_xdl_per_wave; }) { using NXdlPerWaveType = decltype(std::declval().n_xdl_per_wave); constexpr bool convertible = std::convertible_to; - msg += " → xdl_params.n_xdl_per_wave: " + std::string(CHECK_MARK(convertible)) + + msg += " → xdl_params.n_xdl_per_wave: " + std::string(CHECK_MARK(convertible)) + std::string(get_type_info()) + "\n"; } else { - msg += " → xdl_params.n_xdl_per_wave: [✗] (missing member)\n"; + msg += " → xdl_params.n_xdl_per_wave: [✗] (missing member)\n"; } return msg; @@ -157,28 +157,28 @@ consteval auto diagnose_block_transfer(const char* prefix) -> std::string { if constexpr (requires(BT t) { t.k0; }) { using K0Type = decltype(std::declval().k0); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".k0: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + ".k0: " + std::string(CHECK_MARK(convertible)) + std::string(get_type_info()) + "\n"; } else { - msg += std::string(" → ") + prefix + ".k0: [✗] (missing member)\n"; + msg += std::string(" → ") + prefix + ".k0: [✗] (missing member)\n"; } if constexpr (requires(BT t) { t.m_n; }) { using MNType = decltype(std::declval().m_n); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".m_n: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + ".m_n: " + std::string(CHECK_MARK(convertible)) + std::string(get_type_info()) + "\n"; } else { - msg += std::string(" → ") + prefix + ".m_n: [✗] (missing member)\n"; + msg += std::string(" → ") + prefix + ".m_n: [✗] (missing member)\n"; } if constexpr (requires(BT t) { t.k1; }) { using K1Type = decltype(std::declval().k1); constexpr bool convertible = std::convertible_to; - msg += 
std::string(" → ") + prefix + ".k1: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + ".k1: " + std::string(CHECK_MARK(convertible)) + std::string(get_type_info()) + "\n"; } else { - msg += std::string(" → ") + prefix + ".k1: [✗] (missing member)\n"; + msg += std::string(" → ") + prefix + ".k1: [✗] (missing member)\n"; } return msg; @@ -192,46 +192,46 @@ consteval auto diagnose_lds_transfer(const char* prefix) -> std::string { if constexpr (requires(LT t) { t.src_vector_dim; }) { using SrcVectorDimType = decltype(std::declval().src_vector_dim); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".src_vector_dim: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + ".src_vector_dim: " + std::string(CHECK_MARK(convertible)) + std::string(get_type_info()) + "\n"; } else { - msg += std::string(" → ") + prefix + ".src_vector_dim: [✗] (missing member)\n"; + msg += std::string(" → ") + prefix + ".src_vector_dim: [✗] (missing member)\n"; } if constexpr (requires(LT t) { t.src_scalar_per_vector; }) { using SrcScalarType = decltype(std::declval().src_scalar_per_vector); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".src_scalar_per_vector: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + ".src_scalar_per_vector: " + std::string(CHECK_MARK(convertible)) + std::string(get_type_info()) + "\n"; } else { - msg += std::string(" → ") + prefix + ".src_scalar_per_vector: [✗] (missing member)\n"; + msg += std::string(" → ") + prefix + ".src_scalar_per_vector: [✗] (missing member)\n"; } if constexpr (requires(LT t) { t.lds_dst_scalar_per_vector; }) { using LdsDstScalarType = decltype(std::declval().lds_dst_scalar_per_vector); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".lds_dst_scalar_per_vector: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + 
prefix + ".lds_dst_scalar_per_vector: " + std::string(CHECK_MARK(convertible)) + std::string(get_type_info()) + "\n"; } else { - msg += std::string(" → ") + prefix + ".lds_dst_scalar_per_vector: [✗] (missing member)\n"; + msg += std::string(" → ") + prefix + ".lds_dst_scalar_per_vector: [✗] (missing member)\n"; } if constexpr (requires(LT t) { t.is_direct_load; }) { using IsDirectLoadType = decltype(std::declval().is_direct_load); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".is_direct_load: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + ".is_direct_load: " + std::string(CHECK_MARK(convertible)) + std::string(get_type_info()) + "\n"; } else { - msg += std::string(" → ") + prefix + ".is_direct_load: [✗] (missing member)\n"; + msg += std::string(" → ") + prefix + ".is_direct_load: [✗] (missing member)\n"; } if constexpr (requires(LT t) { t.lds_padding; }) { using LdsPaddingType = decltype(std::declval().lds_padding); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".lds_padding: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + ".lds_padding: " + std::string(CHECK_MARK(convertible)) + std::string(get_type_info()) + "\n"; } else { - msg += std::string(" → ") + prefix + ".lds_padding: [✗] (missing member)\n"; + msg += std::string(" → ") + prefix + ".lds_padding: [✗] (missing member)\n"; } return msg; @@ -245,37 +245,37 @@ consteval auto diagnose_thread_cluster(const char* prefix) -> std::string { if constexpr (requires(TC t) { t.m_block; }) { using MBlockType = decltype(std::declval().m_block); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".m_block: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + ".m_block: " + std::string(CHECK_MARK(convertible)) + std::string(get_type_info()) + "\n"; } else { - msg += std::string(" → ") + prefix + 
".m_block: [✗] (missing member)\n"; + msg += std::string(" → ") + prefix + ".m_block: [✗] (missing member)\n"; } if constexpr (requires(TC t) { t.m_wave_per_xdl; }) { using MWaveType = decltype(std::declval().m_wave_per_xdl); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".m_wave_per_xdl: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + ".m_wave_per_xdl: " + std::string(CHECK_MARK(convertible)) + std::string(get_type_info()) + "\n"; } else { - msg += std::string(" → ") + prefix + ".m_wave_per_xdl: [✗] (missing member)\n"; + msg += std::string(" → ") + prefix + ".m_wave_per_xdl: [✗] (missing member)\n"; } if constexpr (requires(TC t) { t.n_block; }) { using NBlockType = decltype(std::declval().n_block); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".n_block: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + ".n_block: " + std::string(CHECK_MARK(convertible)) + std::string(get_type_info()) + "\n"; } else { - msg += std::string(" → ") + prefix + ".n_block: [✗] (missing member)\n"; + msg += std::string(" → ") + prefix + ".n_block: [✗] (missing member)\n"; } if constexpr (requires(TC t) { t.n_wave_per_xdl; }) { using NWaveType = decltype(std::declval().n_wave_per_xdl); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".n_wave_per_xdl: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + ".n_wave_per_xdl: " + std::string(CHECK_MARK(convertible)) + std::string(get_type_info()) + "\n"; } else { - msg += std::string(" → ") + prefix + ".n_wave_per_xdl: [✗] (missing member)\n"; + msg += std::string(" → ") + prefix + ".n_wave_per_xdl: [✗] (missing member)\n"; } return msg; @@ -289,10 +289,10 @@ consteval auto diagnose_access_order(const char* prefix) -> std::string { if constexpr (requires(AO t) { t.order; }) { using OrderType = decltype(std::declval().order); 
constexpr bool convertible = std::convertible_to>; - msg += std::string(" → ") + prefix + ".order: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + ".order: " + std::string(CHECK_MARK(convertible)) + std::string(get_type_info()) + "\n"; } else { - msg += std::string(" → ") + prefix + ".order: [✗] (missing member)\n"; + msg += std::string(" → ") + prefix + ".order: [✗] (missing member)\n"; } return msg; @@ -306,28 +306,28 @@ consteval auto diagnose_epilogue(const char* prefix) -> std::string { if constexpr (requires(E t) { t.m_xdl_per_wave_per_shuffle; }) { using MXdlType = decltype(std::declval().m_xdl_per_wave_per_shuffle); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".m_xdl_per_wave_per_shuffle: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + ".m_xdl_per_wave_per_shuffle: " + std::string(CHECK_MARK(convertible)) + std::string(get_type_info()) + "\n"; } else { - msg += std::string(" → ") + prefix + ".m_xdl_per_wave_per_shuffle: [✗] (missing member)\n"; + msg += std::string(" → ") + prefix + ".m_xdl_per_wave_per_shuffle: [✗] (missing member)\n"; } if constexpr (requires(E t) { t.n_per_wave_per_shuffle; }) { using NPerWaveType = decltype(std::declval().n_per_wave_per_shuffle); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".n_per_wave_per_shuffle: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + ".n_per_wave_per_shuffle: " + std::string(CHECK_MARK(convertible)) + std::string(get_type_info()) + "\n"; } else { - msg += std::string(" → ") + prefix + ".n_per_wave_per_shuffle: [✗] (missing member)\n"; + msg += std::string(" → ") + prefix + ".n_per_wave_per_shuffle: [✗] (missing member)\n"; } if constexpr (requires(E t) { t.scalar_per_vector; }) { using ScalarType = decltype(std::declval().scalar_per_vector); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + 
prefix + ".scalar_per_vector: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + ".scalar_per_vector: " + std::string(CHECK_MARK(convertible)) + std::string(get_type_info()) + "\n"; } else { - msg += std::string(" → ") + prefix + ".scalar_per_vector: [✗] (missing member)\n"; + msg += std::string(" → ") + prefix + ".scalar_per_vector: [✗] (missing member)\n"; } return msg; @@ -355,29 +355,36 @@ template consteval auto detailed_diagnostic_SpecifiesGridwiseFwdXdlGemm() -> std::string { std::string msg; - if constexpr (requires { T::ak1; }) { - using AK1Type = decltype(T::ak1); + if constexpr (!requires { T::gridwise_gemm; }) { + return " → T::gridwise_gemm member: [✗] (missing member)\n"; + } + + msg += " → T::gridwise_gemm member: [✓]\n"; + using GG = decltype(T::gridwise_gemm); + + if constexpr (requires { GG::ak1; }) { + using AK1Type = decltype(GG::ak1); constexpr bool convertible = std::convertible_to; - msg += " → T::ak1: " + std::string(CHECK_MARK(convertible)) + + msg += " → gridwise_gemm.ak1: " + std::string(CHECK_MARK(convertible)) + std::string(detail::get_type_info()) + "\n"; } else { - msg += " → T::ak1: [✗] (missing member)\n"; + msg += " → gridwise_gemm.ak1: [✗] (missing member)\n"; } - if constexpr (requires { T::bk1; }) { - using BK1Type = decltype(T::bk1); + if constexpr (requires { GG::bk1; }) { + using BK1Type = decltype(GG::bk1); constexpr bool convertible = std::convertible_to; - msg += " → T::bk1: " + std::string(CHECK_MARK(convertible)) + + msg += " → gridwise_gemm.bk1: " + std::string(CHECK_MARK(convertible)) + std::string(detail::get_type_info()) + "\n"; } else { - msg += " → T::bk1: [✗] (missing member)\n"; + msg += " → gridwise_gemm.bk1: [✗] (missing member)\n"; } - if constexpr (requires { T::xdl_params; }) { - msg += " → T::xdl_params member: [✓]\n"; - msg += detail::diagnose_xdl_params(); + if constexpr (requires { GG::xdl_params; }) { + msg += " → gridwise_gemm.xdl_params member: [✓]\n"; + msg += 
detail::diagnose_xdl_params(); } else { - msg += " → T::xdl_params: [✗] (missing member)\n"; + msg += " → gridwise_gemm.xdl_params: [✗] (missing member)\n"; } return msg; @@ -387,29 +394,36 @@ template consteval auto detailed_diagnostic_SpecifiesGridwiseBwdXdlGemm() -> std::string { std::string msg; - if constexpr (requires { T::k0_per_block; }) { - using K0Type = decltype(T::k0_per_block); + if constexpr (!requires { T::gridwise_gemm; }) { + return " → T::gridwise_gemm member: [✗] (missing member)\n"; + } + + msg += " → T::gridwise_gemm member: [✓]\n"; + using GG = decltype(T::gridwise_gemm); + + if constexpr (requires { GG::k0_per_block; }) { + using K0Type = decltype(GG::k0_per_block); constexpr bool convertible = std::convertible_to; - msg += " → T::k0_per_block: " + std::string(CHECK_MARK(convertible)) + + msg += " → gridwise_gemm.k0_per_block: " + std::string(CHECK_MARK(convertible)) + std::string(detail::get_type_info()) + "\n"; } else { - msg += " → T::k0_per_block: [✗] (missing member)\n"; + msg += " → gridwise_gemm.k0_per_block: [✗] (missing member)\n"; } - if constexpr (requires { T::k1; }) { - using K1Type = decltype(T::k1); + if constexpr (requires { GG::k1; }) { + using K1Type = decltype(GG::k1); constexpr bool convertible = std::convertible_to; - msg += " → T::k1: " + std::string(CHECK_MARK(convertible)) + + msg += " → gridwise_gemm.k1: " + std::string(CHECK_MARK(convertible)) + std::string(detail::get_type_info()) + "\n"; } else { - msg += " → T::k1: [✗] (missing member)\n"; + msg += " → gridwise_gemm.k1: [✗] (missing member)\n"; } - if constexpr (requires { T::xdl_params; }) { - msg += " → T::xdl_params member: [✓]\n"; - msg += detail::diagnose_xdl_params(); + if constexpr (requires { GG::xdl_params; }) { + msg += " → gridwise_gemm.xdl_params member: [✓]\n"; + msg += detail::diagnose_xdl_params(); } else { - msg += " → T::xdl_params: [✗] (missing member)\n"; + msg += " → gridwise_gemm.xdl_params: [✗] (missing member)\n"; } return msg; @@ -857,13 
+871,13 @@ consteval auto detailed_diagnostic_SpecifiesDlBlockTransfer() -> std::string { constexpr bool has_src_contiguous = requires(ABT t) { { t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to>; }; constexpr bool has_dst_vector = requires(ABT t) { { t.dst_vector_tensor_lengths } -> std::convertible_to>; }; - msg += " → transfer.a.block_transfer.thread_slice_lengths: " + std::string(CHECK_MARK(has_thread_slice)) + (has_thread_slice ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.block_transfer.thread_cluster_lengths: " + std::string(CHECK_MARK(has_thread_cluster)) + (has_thread_cluster ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.block_transfer.thread_cluster_arrange_order: " + std::string(CHECK_MARK(has_cluster_arrange)) + (has_cluster_arrange ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.block_transfer.src_access_order: " + std::string(CHECK_MARK(has_src_access)) + (has_src_access ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.block_transfer.src_vector_tensor_lengths: " + std::string(CHECK_MARK(has_src_vector)) + (has_src_vector ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.block_transfer.src_vector_tensor_contiguous_dim_order: " + std::string(CHECK_MARK(has_src_contiguous)) + (has_src_contiguous ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.block_transfer.dst_vector_tensor_lengths: " + std::string(CHECK_MARK(has_dst_vector)) + (has_dst_vector ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.block_transfer.thread_slice_lengths: " + std::string(CHECK_MARK(has_thread_slice)) + (has_thread_slice ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.block_transfer.thread_cluster_lengths: " + std::string(CHECK_MARK(has_thread_cluster)) + (has_thread_cluster ? 
"\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.block_transfer.thread_cluster_arrange_order: " + std::string(CHECK_MARK(has_cluster_arrange)) + (has_cluster_arrange ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.block_transfer.src_access_order: " + std::string(CHECK_MARK(has_src_access)) + (has_src_access ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.block_transfer.src_vector_tensor_lengths: " + std::string(CHECK_MARK(has_src_vector)) + (has_src_vector ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.block_transfer.src_vector_tensor_contiguous_dim_order: " + std::string(CHECK_MARK(has_src_contiguous)) + (has_src_contiguous ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.block_transfer.dst_vector_tensor_lengths: " + std::string(CHECK_MARK(has_dst_vector)) + (has_dst_vector ? "\n" : " (missing or wrong type)\n"); } else if constexpr (has_a) { msg += " → T::transfer.a.block_transfer: [✗] (missing)\n"; } diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 2fbd5d3fc7f..2767814e115 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -45,7 +45,7 @@ struct GridwiseFwdXdlGemm size_t bk1 = 0; XdlParams xdl_params; }; -static_assert(ckb::SpecifiesGridwiseFwdXdlGemm); +static_assert(ckb::GridwiseFwdXdlGemmDescriptor); struct GridwiseBwdXdlGemm { @@ -53,7 +53,7 @@ struct GridwiseBwdXdlGemm size_t k1 = 0; XdlParams xdl_params; }; -static_assert(ckb::SpecifiesGridwiseBwdXdlGemm); +static_assert(ckb::GridwiseBwdXdlGemmDescriptor); // Describe gridwise WMMA GEMM parameters. struct GridwiseWmmaGemm From dacf82d6528767d8ff07365d487c43ce9ea96311 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Mon, 22 Dec 2025 09:23:47 -0500 Subject: [PATCH 11/81] Concept bug fixes. 
--- .../builder/conv_algorithm_concepts.hpp | 28 +++++++++---------- .../factory/helpers/ck/conv_tuning_params.hpp | 2 +- .../test/conv/ck/test_ckb_conv_bwd_weight.cpp | 3 +- .../test/impl/conv_algorithm_types.hpp | 4 +-- .../test/utils/conv_algorithm_type_utils.hpp | 2 +- 5 files changed, 20 insertions(+), 19 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 2a3d3cd75b8..108ccc04253 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -157,36 +157,36 @@ concept SpecifiesTileThreadBlock = requires { // Concept to check if a struct specifies gridwise XDL GEMM info. template -concept GridwiseFwdXdlGemmDescriptor = requires { - { T::ak1 } -> std::convertible_to; - { T::bk1 } -> std::convertible_to; - { T::xdl_params } -> GridwiseXdlGemmDescriptor; +concept GridwiseFwdXdlGemmDescriptor = requires (T t){ + { t.ak1 } -> std::convertible_to; + { t.bk1 } -> std::convertible_to; + { t.xdl_params } -> GridwiseXdlGemmDescriptor; }; // Concept to check if a struct specifies gridwise XDL GEMM info. template -concept GridwiseBwdXdlGemmDescriptor = requires { - { T::k0_per_block } -> std::convertible_to; - { T::k1 } -> std::convertible_to; - { T::xdl_params } -> GridwiseXdlGemmDescriptor; +concept GridwiseBwdXdlGemmDescriptor = requires (T t){ + { t.k0_per_block } -> std::convertible_to; + { t.k1 } -> std::convertible_to; + { t.xdl_params } -> GridwiseXdlGemmDescriptor; }; // Concept to check if a struct specifies gridwise XDL GEMM info. template -concept SpecifiesGridwiseFwdXdlGemm = requires { - { T::gridwise_gemm } -> GridwiseFwdXdlGemmDescriptor; +concept SpecifiesGridwiseFwdXdlGemm = requires (T t) { + { t.gridwise_gemm } -> GridwiseFwdXdlGemmDescriptor; }; // Concept to check if a struct specifies gridwise XDL GEMM info. 
template -concept SpecifiesGridwiseBwdXdlGemm = requires { - { T::gridwise_gemm } -> GridwiseFwdXdlGemmDescriptor; +concept SpecifiesGridwiseBwdXdlGemm = requires (T t) { + { t.gridwise_gemm } -> GridwiseBwdXdlGemmDescriptor; }; // Concept to check if a struct specifies gridwise WMMA GEMM info. template -concept SpecifiesGridwiseWmmaGemm = requires { - { T::gridwise_gemm } -> GridwiseBwdXdlGemmDescriptor; +concept SpecifiesGridwiseWmmaGemm = requires (T t){ + { t.gridwise_gemm } -> GridwiseWmmaGemmDescriptor; }; // Concept to check if a struct specifies convolution input and output block transfer info. diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp index 6f3a9e8e78d..d7f3b17197f 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp @@ -161,7 +161,7 @@ consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdC template consteval ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization SetBwdWeightConvSpecialization() { - constexpr auto specialization = ALGORITHM.bwd_specialization; + constexpr auto specialization = ALGORITHM.bwd_weight_specialization; using ck_conv_spec = ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization; switch(specialization) { diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp index 975212999b0..045efbc385e 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp @@ -25,7 +25,8 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CSh .with_gemm_config(cku::BwdGemmParams_Xdl_4x4_per_wave) 
.with_transfer(cku::Transfer_4x64x1) .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT); - +static_assert(cku::SpecifiesGridwiseBwdXdlGemm, "Error"); + using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 2767814e115..4d5ac2cd9e2 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -211,7 +211,7 @@ struct ConvSpecializationFwd_ struct ConvSpecializationBwdWeight_ { - ConvSpecialization bwd_specialization; + ConvSpecialization bwd_weight_specialization; }; struct Prefetch_ @@ -400,7 +400,7 @@ struct ConvAlgorithmTemplate : Components... { static_assert(std::is_base_of_v); auto result = *this; - result.bwd_specialization = bwd_spec; + result.bwd_weight_specialization = bwd_spec; return result; } diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index 6c1d9ae15f9..c3afe2bd4e0 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -278,7 +278,7 @@ template <> inline std::string to_string(ConvSpecializationBwdWeight_ t) { std::ostringstream oss; - oss << to_string(t.bwd_specialization); + oss << to_string(t.bwd_weight_specialization); return oss.str(); } From 8eb62241fb4eae37a22c1ee300cf87176eab8388 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Mon, 22 Dec 2025 09:30:43 -0500 Subject: [PATCH 12/81] Remove debug assert. 
--- experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp index 045efbc385e..366c2b27514 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp @@ -25,7 +25,6 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CSh .with_gemm_config(cku::BwdGemmParams_Xdl_4x4_per_wave) .with_transfer(cku::Transfer_4x64x1) .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT); -static_assert(cku::SpecifiesGridwiseBwdXdlGemm, "Error"); using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; From a8e7edd814657e8ee023c5ee2caec32b099ee08b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Mon, 22 Dec 2025 10:56:47 -0500 Subject: [PATCH 13/81] Update algorithm signature diagnostics. 
--- .../builder/conv_algorithm_diagnostics.hpp | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp index 1c4cafab8cb..f97ef7c275a 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp @@ -362,8 +362,8 @@ consteval auto detailed_diagnostic_SpecifiesGridwiseFwdXdlGemm() -> std::string msg += " → T::gridwise_gemm member: [✓]\n"; using GG = decltype(T::gridwise_gemm); - if constexpr (requires { GG::ak1; }) { - using AK1Type = decltype(GG::ak1); + if constexpr (requires(GG t) { t.ak1; }) { + using AK1Type = decltype(std::declval().ak1); constexpr bool convertible = std::convertible_to; msg += " → gridwise_gemm.ak1: " + std::string(CHECK_MARK(convertible)) + std::string(detail::get_type_info()) + "\n"; @@ -371,8 +371,8 @@ consteval auto detailed_diagnostic_SpecifiesGridwiseFwdXdlGemm() -> std::string msg += " → gridwise_gemm.ak1: [✗] (missing member)\n"; } - if constexpr (requires { GG::bk1; }) { - using BK1Type = decltype(GG::bk1); + if constexpr (requires(GG t) { t.bk1; }) { + using BK1Type = decltype(std::declval().bk1); constexpr bool convertible = std::convertible_to; msg += " → gridwise_gemm.bk1: " + std::string(CHECK_MARK(convertible)) + std::string(detail::get_type_info()) + "\n"; @@ -380,9 +380,9 @@ consteval auto detailed_diagnostic_SpecifiesGridwiseFwdXdlGemm() -> std::string msg += " → gridwise_gemm.bk1: [✗] (missing member)\n"; } - if constexpr (requires { GG::xdl_params; }) { + if constexpr (requires(GG t) { t.xdl_params; }) { msg += " → gridwise_gemm.xdl_params member: [✓]\n"; - msg += detail::diagnose_xdl_params(); + msg += detail::diagnose_xdl_params().xdl_params)>(); } else { msg += " → gridwise_gemm.xdl_params: [✗] (missing member)\n"; } @@ -401,8 
+401,8 @@ consteval auto detailed_diagnostic_SpecifiesGridwiseBwdXdlGemm() -> std::string msg += " → T::gridwise_gemm member: [✓]\n"; using GG = decltype(T::gridwise_gemm); - if constexpr (requires { GG::k0_per_block; }) { - using K0Type = decltype(GG::k0_per_block); + if constexpr (requires(GG t) { t.k0_per_block; }) { + using K0Type = decltype(std::declval().k0_per_block); constexpr bool convertible = std::convertible_to; msg += " → gridwise_gemm.k0_per_block: " + std::string(CHECK_MARK(convertible)) + std::string(detail::get_type_info()) + "\n"; @@ -410,8 +410,8 @@ consteval auto detailed_diagnostic_SpecifiesGridwiseBwdXdlGemm() -> std::string msg += " → gridwise_gemm.k0_per_block: [✗] (missing member)\n"; } - if constexpr (requires { GG::k1; }) { - using K1Type = decltype(GG::k1); + if constexpr (requires(GG t) { t.k1; }) { + using K1Type = decltype(std::declval().k1); constexpr bool convertible = std::convertible_to; msg += " → gridwise_gemm.k1: " + std::string(CHECK_MARK(convertible)) + std::string(detail::get_type_info()) + "\n"; @@ -419,9 +419,9 @@ consteval auto detailed_diagnostic_SpecifiesGridwiseBwdXdlGemm() -> std::string msg += " → gridwise_gemm.k1: [✗] (missing member)\n"; } - if constexpr (requires { GG::xdl_params; }) { + if constexpr (requires(GG t) { t.xdl_params; }) { msg += " → gridwise_gemm.xdl_params member: [✓]\n"; - msg += detail::diagnose_xdl_params(); + msg += detail::diagnose_xdl_params().xdl_params)>(); } else { msg += " → gridwise_gemm.xdl_params: [✗] (missing member)\n"; } From 96a4a5de376ebed481ffccde0a549a7abba55ed8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Mon, 22 Dec 2025 11:05:00 -0500 Subject: [PATCH 14/81] Factory bug fixes. 
--- .../builder/factory/conv_bwd_weight_xdl_factory.hpp | 9 +++++---- .../builder/factory/helpers/ck/conv_tensor_type.hpp | 2 ++ .../builder/factory/helpers/ck/conv_tuning_params.hpp | 1 + 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp index 0f726fe67d9..db361149975 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp @@ -35,6 +35,7 @@ struct ConvBwdWeightXdlFactory static constexpr auto BLOCK = internal::SetThreadBlockInfo(); static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; static constexpr auto A_BLOCK_TRANSFER = internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -70,10 +71,10 @@ struct ConvBwdWeightXdlFactory BLOCK.per_block.n, GRIDWISE_GEMM.k0_per_block, GRIDWISE_GEMM.k1, - GRIDWISE_GEMM.m_per_xdl, - GRIDWISE_GEMM.n_per_xdl, - GRIDWISE_GEMM.m_xdl_per_wave, - GRIDWISE_GEMM.n_xdl_per_wave, + XDL_PARAMS.m_per_xdl, + XDL_PARAMS.n_per_xdl, + XDL_PARAMS.m_xdl_per_wave, + XDL_PARAMS.n_xdl_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp index d4a470dcedf..d6b0e067009 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp @@ -190,6 +190,8 @@ struct BwdWeightConvTensorDataTypes using InComputeType = typename decltype(input_types.second)::type; using WeiDataType = typename decltype(weight_types.first)::type; using WeiComputeType = 
typename decltype(weight_types.second)::type; + using OutDataType = typename decltype(output_types.first)::type; + using OutComputeType = typename decltype(output_types.second)::type; using AccDataType = typename decltype(GetTensorAccumulationType())::type; diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp index d7f3b17197f..92a7b48ddd9 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp @@ -169,6 +169,7 @@ consteval ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization case ConvSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0; case ConvSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0; case ConvSpecialization::ODD_C: return ck_conv_spec::OddC; + case ConvSpecialization::FILTER_3x3: throw "FILTER_3x3 is not supported for backward weight convolution."; default: throw "Unsupported ConvSpecialization"; } } From 608266a4ef0e1630d575cc4699fb5400cc8c70f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Mon, 22 Dec 2025 11:50:00 -0500 Subject: [PATCH 15/81] First functional version of bwd weight conv factory. 
--- .../ck_tile/builder/conv_algorithm_limits.hpp | 8 + .../factory/conv_bwd_weight_xdl_factory.hpp | 12 +- .../helpers/ck/conv_block_transfer.hpp | 32 ++ .../test/conv/ck/test_ckb_conv_bwd_weight.cpp | 4 +- .../test/test_concept_diagnostics_sync.cpp | 409 ++++++++++++++++++ .../test/utils/ckb_conv_test_configs.hpp | 33 ++ 6 files changed, 490 insertions(+), 8 deletions(-) create mode 100644 experimental/builder/test/test_concept_diagnostics_sync.cpp diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp index 10a619024a9..f60e7703a3e 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp @@ -35,4 +35,12 @@ concept AccessOrderLimits = requires { (Value[2] >= 0 && Value[2] < 3)); }; +// Limits for access order. Must be a permutation of {1, 2, 3} for the last three elements. +template +concept BwdAccessOrderLimits = requires { + requires((Value[1] != Value[2]) && (Value[1] != Value[3]) && (Value[2] != Value[3]) && + (Value[1] >= 1 && Value[1] < 4) && (Value[2] >= 1 && Value[2] < 4) && + (Value[3] >= 1 && Value[3] < 4)) && (Value[0] == 0); +}; + } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp index db361149975..cc3262c07c1 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp @@ -37,9 +37,9 @@ struct ConvBwdWeightXdlFactory static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; static constexpr auto A_BLOCK_TRANSFER = - internal::SetFwdConvBlockTransfer(); + internal::SetBwdConvBlockTransfer(); 
static constexpr auto B_BLOCK_TRANSFER = - internal::SetFwdConvBlockTransfer(); + internal::SetBwdConvBlockTransfer(); static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); // Check limits for the algorithm parameters. @@ -47,10 +47,10 @@ struct ConvBwdWeightXdlFactory static_assert(InputVectorTransferLimits); static_assert(InputVectorTransferLimits); static_assert(OutputVectorTransferLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); + static_assert(BwdAccessOrderLimits); + static_assert(BwdAccessOrderLimits); + static_assert(BwdAccessOrderLimits); + static_assert(BwdAccessOrderLimits); // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle< diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp index 5da1e4eadbe..9729a72ce71 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp @@ -22,6 +22,18 @@ struct BlockTransfer bool lds_padding = false; }; +struct BwdBlockTransfer +{ + ck::Array thread_cluster_dims = {0, 0, 0, 0}; + ck::Array thread_cluster_order = {0, 0, 0, 0}; + ck::Array src_access_order = {0, 0, 0, 0}; + size_t src_vector_dim = 0; + size_t src_scalar_per_vector = 0; + size_t lds_dst_scalar_per_vector = 0; + bool is_direct_load = false; + bool lds_padding = false; +}; + template constexpr BlockTransfer SetFwdConvBlockTransfer() { @@ -42,6 +54,26 @@ constexpr BlockTransfer SetFwdConvBlockTransfer() }; } +template +constexpr BwdBlockTransfer SetBwdConvBlockTransfer() +{ + auto& block_xfer = TRANSFER.block_transfer; + auto& block_order = TRANSFER.block_transfer_access_order; + auto& 
src_order = TRANSFER.src_access_order; + auto& lds_cfg = TRANSFER.lds_transfer; + + return BwdBlockTransfer{ + .thread_cluster_dims = {1, block_xfer.k0, block_xfer.m_n, block_xfer.k1}, + .thread_cluster_order = {0, block_order.order[0], block_order.order[1], block_order.order[2]}, + .src_access_order = {0, src_order.order[0], src_order.order[1], src_order.order[2]}, + .src_vector_dim = lds_cfg.src_vector_dim, + .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, + .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, + .is_direct_load = lds_cfg.is_direct_load, + .lds_padding = lds_cfg.lds_padding, + }; +} + // Block transfer parameters for C tensor. struct CBlockTransfer { diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp index 366c2b27514..f626bbb288b 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp @@ -21,9 +21,9 @@ constexpr auto SIGNATURE = .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle{} - .with_thread_block(cku::ThreadBlock_256_256x256x32) + .with_thread_block(cku::ThreadBlock_256_128x128x32) .with_gemm_config(cku::BwdGemmParams_Xdl_4x4_per_wave) - .with_transfer(cku::Transfer_4x64x1) + .with_transfer(cku::BwdTransfer_4x64x1) .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT); using Builder = ckb::ConvBuilder; diff --git a/experimental/builder/test/test_concept_diagnostics_sync.cpp b/experimental/builder/test/test_concept_diagnostics_sync.cpp new file mode 100644 index 00000000000..0bc786dbdf5 --- /dev/null +++ b/experimental/builder/test/test_concept_diagnostics_sync.cpp @@ -0,0 +1,409 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +/** + * @file test_concept_diagnostics_sync.cpp + * @brief Unit tests to ensure concepts and their diagnostics remain in sync + * + * This test suite verifies that: + * 1. Valid types satisfy their corresponding concepts + * 2. Invalid types (missing members) do not satisfy concepts + * 3. Diagnostic messages correctly identify missing requirements + * 4. Existing test types from conv_algorithm_types.hpp satisfy their concepts + */ + +#include +#include + +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_diagnostics.hpp" +#include "ck_tile/builder/types.hpp" +#include "experimental/builder/test/impl/conv_algorithm_types.hpp" + +namespace ck_tile::builder::test { + +using ck_tile::builder::ThreadBlockDescriptor; +using ck_tile::builder::GridwiseXdlGemmDescriptor; +using ck_tile::builder::BlockTransferDescriptor; +using ck_tile::builder::ThreadClusterDescriptor; +using ck_tile::builder::LdsTransferDescriptor; +using ck_tile::builder::EpilogueDescriptor; +using ck_tile::builder::AccessOrderDescriptor; +using ck_tile::builder::BlockGemmDescriptor; +using ck_tile::builder::GridwiseWmmaGemmDescriptor; +using ck_tile::builder::TileThreadBlockDescriptor; +using ck_tile::builder::TileTransferDescriptor; +using ck_tile::builder::TileBlockGemmDescriptor; +using ck_tile::builder::TileOptimizationsDescriptor; +using ck_tile::builder::DlThreadConfigDescriptor; +using ck_tile::builder::DlThreadClusterDescriptor; +using ck_tile::builder::DlBlockTransferDescriptor; +using ck_tile::builder::DlEpilogueDescriptor; +using ck_tile::builder::ConvAlgorithmDescriptor; +using ck_tile::builder::SpecifiesThreadBlock; +using ck_tile::builder::SpecifiesGridwiseFwdXdlGemm; +using ck_tile::builder::SpecifiesGridwiseBwdXdlGemm; +using ck_tile::builder::SpecifiesBlockGemm; +using ck_tile::builder::SpecifiesFwdConvSpecialization; +using ck_tile::builder::SpecifiesBwdWeightConvSpecialization; +using 
ck_tile::builder::SpecifiesGemmSpecialization; +using ck_tile::builder::SpecifiesNumPrefetchStages; +using ck_tile::builder::SpecifiesLoopScheduler; +using ck_tile::builder::SpecifiesTileThreadBlock; +using ck_tile::builder::SpecifiesTileTransfer; +using ck_tile::builder::SpecifiesTileBlockGemm; +using ck_tile::builder::SpecifiesTileOptimizations; +using ck_tile::builder::SpecifiesTileConvSpecialization; +using ck_tile::builder::SpecifiesDlThreadConfig; +using ck_tile::builder::SpecifiesDlThreadCluster; + +// Helper to check if a string contains a substring +bool contains(const std::string& str, const std::string& substr) +{ + return str.find(substr) != std::string::npos; +} + +// ============================================================================= +// BASIC DESCRIPTOR CONCEPTS TESTS +// ============================================================================= + +TEST(ConceptDiagnosticsSync, ThreadBlockDescriptor_Valid) +{ + // The ThreadBlock type from conv_algorithm_types.hpp should satisfy the concept + static_assert(ThreadBlockDescriptor); +} + +TEST(ConceptDiagnosticsSync, GridwiseXdlGemmDescriptor_Valid) +{ + // The XdlParams type should satisfy the concept + static_assert(GridwiseXdlGemmDescriptor); +} + +TEST(ConceptDiagnosticsSync, BlockTransferDescriptor_Valid) +{ + // The BlockTransfer type should satisfy the concept + static_assert(BlockTransferDescriptor); +} + +TEST(ConceptDiagnosticsSync, ThreadClusterDescriptor_Valid) +{ + // The ThreadCluster type should satisfy the concept + static_assert(ThreadClusterDescriptor); +} + +TEST(ConceptDiagnosticsSync, LdsTransferDescriptor_Valid) +{ + // The LdsTransfer type should satisfy the concept + static_assert(LdsTransferDescriptor); +} + +TEST(ConceptDiagnosticsSync, EpilogueDescriptor_Valid) +{ + // The Epilogue type should satisfy the concept + static_assert(EpilogueDescriptor); +} + +TEST(ConceptDiagnosticsSync, AccessOrderDescriptor_Valid) +{ + // The AccessOrder type should satisfy the 
concept + static_assert(AccessOrderDescriptor); +} + +TEST(ConceptDiagnosticsSync, BlockGemmDescriptor_Valid) +{ + // The BlockGemm type should satisfy the concept + static_assert(BlockGemmDescriptor); +} + +TEST(ConceptDiagnosticsSync, GridwiseWmmaGemmDescriptor_Valid) +{ + // The GridwiseWmmaGemm type should satisfy the concept + static_assert(GridwiseWmmaGemmDescriptor); +} + +// ============================================================================= +// HIGH-LEVEL "SPECIFIES" CONCEPTS TESTS +// ============================================================================= + +TEST(ConceptDiagnosticsSync, SpecifiesThreadBlock_Valid) +{ + static_assert(SpecifiesThreadBlock); +} + +TEST(ConceptDiagnosticsSync, SpecifiesGridwiseFwdXdlGemm_Valid) +{ + static_assert(SpecifiesGridwiseFwdXdlGemm); +} + +TEST(ConceptDiagnosticsSync, SpecifiesGridwiseBwdXdlGemm_Valid) +{ + static_assert(SpecifiesGridwiseBwdXdlGemm); +} + +TEST(ConceptDiagnosticsSync, SpecifiesBlockGemm_Valid) +{ + static_assert(SpecifiesBlockGemm); +} + +TEST(ConceptDiagnosticsSync, SpecifiesFwdConvSpecialization_Valid) +{ + static_assert(SpecifiesFwdConvSpecialization); +} + +TEST(ConceptDiagnosticsSync, SpecifiesBwdWeightConvSpecialization_Valid) +{ + static_assert(SpecifiesBwdWeightConvSpecialization); +} + +TEST(ConceptDiagnosticsSync, SpecifiesGemmSpecialization_Valid) +{ + static_assert(SpecifiesGemmSpecialization); +} + +TEST(ConceptDiagnosticsSync, SpecifiesNumPrefetchStages_Valid) +{ + static_assert(SpecifiesNumPrefetchStages); +} + +TEST(ConceptDiagnosticsSync, SpecifiesLoopScheduler_Valid) +{ + static_assert(SpecifiesLoopScheduler); +} + +// ============================================================================= +// TILE-SPECIFIC CONCEPTS TESTS +// ============================================================================= + +TEST(ConceptDiagnosticsSync, TileThreadBlockDescriptor_Valid) +{ + static_assert(TileThreadBlockDescriptor); +} + +TEST(ConceptDiagnosticsSync, 
TileTransferDescriptor_Valid) +{ + static_assert(TileTransferDescriptor); +} + +TEST(ConceptDiagnosticsSync, TileBlockGemmDescriptor_Valid) +{ + static_assert(TileBlockGemmDescriptor); +} + +TEST(ConceptDiagnosticsSync, TileOptimizationsDescriptor_Valid) +{ + static_assert(TileOptimizationsDescriptor); +} + +TEST(ConceptDiagnosticsSync, SpecifiesTileThreadBlock_Valid) +{ + static_assert(SpecifiesTileThreadBlock); +} + +TEST(ConceptDiagnosticsSync, SpecifiesTileTransfer_Valid) +{ + static_assert(SpecifiesTileTransfer); +} + +TEST(ConceptDiagnosticsSync, SpecifiesTileBlockGemm_Valid) +{ + static_assert(SpecifiesTileBlockGemm); +} + +TEST(ConceptDiagnosticsSync, SpecifiesTileOptimizations_Valid) +{ + static_assert(SpecifiesTileOptimizations); +} + +TEST(ConceptDiagnosticsSync, SpecifiesTileConvSpecialization_Valid) +{ + static_assert(SpecifiesTileConvSpecialization); +} + +// ============================================================================= +// DL-SPECIFIC CONCEPTS TESTS +// ============================================================================= + +TEST(ConceptDiagnosticsSync, DlThreadConfigDescriptor_Valid) +{ + static_assert(DlThreadConfigDescriptor); +} + +TEST(ConceptDiagnosticsSync, DlThreadClusterDescriptor_Valid) +{ + static_assert(DlThreadClusterDescriptor); +} + +TEST(ConceptDiagnosticsSync, DlBlockTransferDescriptor_Valid) +{ + static_assert(DlBlockTransferDescriptor); +} + +TEST(ConceptDiagnosticsSync, DlEpilogueDescriptor_Valid) +{ + static_assert(DlEpilogueDescriptor); +} + +TEST(ConceptDiagnosticsSync, SpecifiesDlThreadConfig_Valid) +{ + static_assert(SpecifiesDlThreadConfig); +} + +TEST(ConceptDiagnosticsSync, SpecifiesDlThreadCluster_Valid) +{ + static_assert(SpecifiesDlThreadCluster); +} + +// ============================================================================= +// INVALID TYPE TESTS - Test that concepts correctly reject invalid types +// ============================================================================= + 
+namespace invalid_types { + +// Test ThreadBlockDescriptor with missing members +struct MissingBlockSize +{ + struct + { + size_t m, n, k; + } tile_size; +}; + +struct MissingTileSizeM +{ + size_t block_size; + struct + { + size_t n, k; + } tile_size; +}; + +// Test GridwiseXdlGemmDescriptor with missing members +struct MissingMPerXdl +{ + size_t n_per_xdl; + size_t m_xdl_per_wave; + size_t n_xdl_per_wave; +}; + +// Test BlockTransferDescriptor with missing members +struct MissingK0 +{ + size_t m_n; + size_t k1; +}; + +// Test LdsTransferDescriptor with missing members +struct MissingSrcVectorDim +{ + size_t src_scalar_per_vector; + size_t lds_dst_scalar_per_vector; + bool is_direct_load; + bool lds_padding; +}; + +} // namespace invalid_types + +TEST(ConceptDiagnosticsSync, ThreadBlockDescriptor_Invalid) +{ + static_assert(!ThreadBlockDescriptor); + static_assert(!ThreadBlockDescriptor); +} + +TEST(ConceptDiagnosticsSync, GridwiseXdlGemmDescriptor_Invalid) +{ + static_assert(!GridwiseXdlGemmDescriptor); +} + +TEST(ConceptDiagnosticsSync, BlockTransferDescriptor_Invalid) +{ + static_assert(!BlockTransferDescriptor); +} + +TEST(ConceptDiagnosticsSync, LdsTransferDescriptor_Invalid) +{ + static_assert(!LdsTransferDescriptor); +} + +// ============================================================================= +// COMPREHENSIVE ALGORITHM TYPE TESTS +// ============================================================================= + +TEST(ConceptDiagnosticsSync, CompleteAlgorithmTypes) +{ + // Test that complete algorithm types satisfy their concepts + static_assert(ConvAlgorithmDescriptor); + static_assert(ConvAlgorithmDescriptor); + static_assert(ConvAlgorithmDescriptor); + static_assert(ConvAlgorithmDescriptor); + static_assert(ConvAlgorithmDescriptor); + + // Test specific requirements for each algorithm type + static_assert(SpecifiesThreadBlock); + static_assert(SpecifiesGridwiseFwdXdlGemm); + static_assert(SpecifiesFwdConvSpecialization); + 
static_assert(SpecifiesNumPrefetchStages); + + static_assert(SpecifiesTileThreadBlock); + static_assert(SpecifiesTileBlockGemm); + static_assert(SpecifiesTileOptimizations); +} + +// ============================================================================= +// DIAGNOSTIC MESSAGE TESTS +// ============================================================================= + +TEST(ConceptDiagnosticsSync, DiagnosticMessages) +{ + // Test that diagnostics can be called (even if messages may be empty at compile-time) + // The key is that the diagnostic functions exist and compile + std::string diag1 = ck_tile::builder::diagnostics::detailed_diagnostic_SpecifiesThreadBlock(); + std::string diag2 = ck_tile::builder::diagnostics::detailed_diagnostic_SpecifiesGridwiseFwdXdlGemm(); + + // These may be empty depending on the implementation, but they should compile + EXPECT_TRUE(diag1.empty() || contains(diag1, "thread_block") || contains(diag1, "missing")); + EXPECT_TRUE(diag2.empty() || contains(diag2, "gridwise_gemm") || contains(diag2, "missing")); +} + +// ============================================================================= +// CONCEPT COMPLETENESS TESTS +// ============================================================================= + +/** + * @brief Verify that all concepts defined in conv_algorithm_concepts.hpp have tests + * + * This test serves as documentation of which concepts are tested. If new concepts + * are added, this test should be updated to include them. 
+ */ +TEST(ConceptDiagnosticsSync, ConceptCoverage) +{ + // Basic Descriptor Concepts - verify they all exist and can be instantiated + EXPECT_TRUE((ThreadBlockDescriptor)); + EXPECT_TRUE((GridwiseXdlGemmDescriptor)); + EXPECT_TRUE((BlockGemmDescriptor)); + EXPECT_TRUE((GridwiseWmmaGemmDescriptor)); + EXPECT_TRUE((BlockTransferDescriptor)); + EXPECT_TRUE((ThreadClusterDescriptor)); + EXPECT_TRUE((LdsTransferDescriptor)); + EXPECT_TRUE((EpilogueDescriptor)); + EXPECT_TRUE((AccessOrderDescriptor)); + + // Tile Descriptor Concepts + EXPECT_TRUE((TileThreadBlockDescriptor)); + EXPECT_TRUE((TileTransferDescriptor)); + EXPECT_TRUE((TileBlockGemmDescriptor)); + EXPECT_TRUE((TileOptimizationsDescriptor)); + + // DL Descriptor Concepts + EXPECT_TRUE((DlThreadConfigDescriptor)); + EXPECT_TRUE((DlThreadClusterDescriptor)); + EXPECT_TRUE((DlBlockTransferDescriptor)); + EXPECT_TRUE((DlEpilogueDescriptor)); +} + +} // namespace ck_tile::builder::test + +int main(int argc, char** argv) +{ + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index 8e5963b22ea..d176506526f 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -72,6 +72,39 @@ constexpr TransferABC Transfer_4x64x1{ }, }; +constexpr TransferABC BwdTransfer_4x64x1{ + .a = + { + .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 4, + .is_direct_load = false, + .lds_padding = true}, + .block_transfer_access_order = {3, 1, 2}, + .src_access_order = {2, 1, 3}, + }, + .b = + { + .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 4, + .is_direct_load = false, + .lds_padding = true}, + 
.block_transfer_access_order = {3, 1, 2}, + .src_access_order = {2, 1, 3}, + }, + .c = + { + .thread_cluster_dims = + {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, + .epilogue = {.m_xdl_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + }, +}; + constexpr TransferABC Transfer_4x64x1_fp8{ .a = { From a1740c614ba6beda7402f31b6609599e63efdfad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Tue, 23 Dec 2025 10:08:56 -0500 Subject: [PATCH 16/81] Refactor handing of GEMM-K batch template parameter in conv bwd weight factory. --- .../builder/conv_algorithm_concepts.hpp | 15 ++++- .../builder/conv_algorithm_diagnostics.hpp | 21 +++---- .../factory/conv_bwd_weight_xdl_factory.hpp | 2 +- .../helpers/ck/conv_block_transfer.hpp | 6 +- .../test/conv/ck/test_ckb_conv_bwd_weight.cpp | 9 +-- .../test/impl/conv_algorithm_types.hpp | 55 ++++++++++++------- .../test/utils/ckb_conv_test_configs.hpp | 28 ++++++---- .../test/utils/conv_algorithm_type_utils.hpp | 42 ++++++++------ 8 files changed, 109 insertions(+), 69 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 108ccc04253..4d81becfb5a 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -51,12 +51,24 @@ concept GridwiseWmmaGemmDescriptor = requires(T t) { { t.pipeline_version } -> std::convertible_to; }; + +template +concept HasGemmKBatch = requires(T t) { + { t.k_batch_size}; +}; + +// Concept to check if GEMM k batch size is specified. +template +concept GemmKBatchSizeWellDefinedIfProvided = + !HasGemmKBatch || requires(T t) { {t.k_batch_size} -> std::convertible_to; }; + // Concept for vectorized data transfer for convolution input tensors. 
template concept BlockTransferDescriptor = requires(T t) { { t.k0 } -> std::convertible_to; { t.m_n } -> std::convertible_to; { t.k1 } -> std::convertible_to; + GemmKBatchSizeWellDefinedIfProvided; }; // Concept for thread cluster dimensions for GEMM output tensor. @@ -91,6 +103,8 @@ concept EpilogueDescriptor = requires(T t) { template concept AccessOrderDescriptor = requires(T t) { { t.order } -> std::convertible_to>; +} || requires(T t) { + { t.order } -> std::convertible_to>; }; // Concept for thread block dimensions for a GEMM problem for CK Tile (Block @@ -166,7 +180,6 @@ concept GridwiseFwdXdlGemmDescriptor = requires (T t){ // Concept to check if a struct specifies gridwise XDL GEMM info. template concept GridwiseBwdXdlGemmDescriptor = requires (T t){ - { t.k0_per_block } -> std::convertible_to; { t.k1 } -> std::convertible_to; { t.xdl_params } -> GridwiseXdlGemmDescriptor; }; diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp index f97ef7c275a..3b75eb00e9b 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp @@ -181,6 +181,14 @@ consteval auto diagnose_block_transfer(const char* prefix) -> std::string { msg += std::string(" → ") + prefix + ".k1: [✗] (missing member)\n"; } + // k_batch_size is optional - only report if it exists and has wrong type + if constexpr (requires(BT t) { t.k_batch_size; }) { + using KBatchType = decltype(std::declval().k_batch_size); + constexpr bool convertible = std::convertible_to; + msg += std::string(" → ") + prefix + ".k_batch_size (optional): " + std::string(CHECK_MARK(convertible)) + + std::string(get_type_info()) + "\n"; + } + return msg; } @@ -288,7 +296,9 @@ consteval auto diagnose_access_order(const char* prefix) -> std::string { if constexpr (requires(AO t) { t.order; }) { using OrderType 
= decltype(std::declval().order); - constexpr bool convertible = std::convertible_to>; + constexpr bool convertible_3 = std::convertible_to>; + constexpr bool convertible_4 = std::convertible_to>; + constexpr bool convertible = convertible_3 || convertible_4; msg += std::string(" → ") + prefix + ".order: " + std::string(CHECK_MARK(convertible)) + std::string(get_type_info()) + "\n"; } else { @@ -401,15 +411,6 @@ consteval auto detailed_diagnostic_SpecifiesGridwiseBwdXdlGemm() -> std::string msg += " → T::gridwise_gemm member: [✓]\n"; using GG = decltype(T::gridwise_gemm); - if constexpr (requires(GG t) { t.k0_per_block; }) { - using K0Type = decltype(std::declval().k0_per_block); - constexpr bool convertible = std::convertible_to; - msg += " → gridwise_gemm.k0_per_block: " + std::string(CHECK_MARK(convertible)) + - std::string(detail::get_type_info()) + "\n"; - } else { - msg += " → gridwise_gemm.k0_per_block: [✗] (missing member)\n"; - } - if constexpr (requires(GG t) { t.k1; }) { using K1Type = decltype(std::declval().k1); constexpr bool convertible = std::convertible_to; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp index cc3262c07c1..8cb0d5d2ae9 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp @@ -69,7 +69,7 @@ struct ConvBwdWeightXdlFactory BLOCK.block_size, BLOCK.per_block.m, BLOCK.per_block.n, - GRIDWISE_GEMM.k0_per_block, + BLOCK.per_block.k, GRIDWISE_GEMM.k1, XDL_PARAMS.m_per_xdl, XDL_PARAMS.n_per_xdl, diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp index 9729a72ce71..5ee07fa2d18 100644 --- 
a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp @@ -63,9 +63,9 @@ constexpr BwdBlockTransfer SetBwdConvBlockTransfer() auto& lds_cfg = TRANSFER.lds_transfer; return BwdBlockTransfer{ - .thread_cluster_dims = {1, block_xfer.k0, block_xfer.m_n, block_xfer.k1}, - .thread_cluster_order = {0, block_order.order[0], block_order.order[1], block_order.order[2]}, - .src_access_order = {0, src_order.order[0], src_order.order[1], src_order.order[2]}, + .thread_cluster_dims = {block_xfer.k_batch_size, block_xfer.k0, block_xfer.m_n, block_xfer.k1}, + .thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2], block_order.order[3]}, + .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2], src_order.order[3]}, .src_vector_dim = lds_cfg.src_vector_dim, .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp index f626bbb288b..9989a20ad08 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp @@ -21,7 +21,7 @@ constexpr auto SIGNATURE = .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle{} - .with_thread_block(cku::ThreadBlock_256_128x128x32) + .with_thread_block(cku::ThreadBlock_256_128x128x8) .with_gemm_config(cku::BwdGemmParams_Xdl_4x4_per_wave) .with_transfer(cku::BwdTransfer_4x64x1) .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT); @@ -35,11 +35,8 @@ TEST(BwdWeight_2DFp16_CShufV3_GNHWC, Create) cku::run_test({"DeviceGroupedConvBwdWeight_Xdl_CShuffle", expected_transfer_parameters, "Default", - 
"Intrawave", - "v3", - "GNHWC,GKYXC,EmptyTuple,GNHWK", - "PassThrough,PassThrough,PassThrough", - "MNKPadding"}); + "GNHWC,GKYXC,GNHWK", + "PassThrough,PassThrough,PassThrough"}); } // TEST(BwdWeight_2DFp16_CShufV3_GNHWC, EndToEnd) diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 4d5ac2cd9e2..2d2829820d3 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -49,7 +49,6 @@ static_assert(ckb::GridwiseFwdXdlGemmDescriptor); struct GridwiseBwdXdlGemm { - size_t k0_per_block = 0; size_t k1 = 0; XdlParams xdl_params; }; @@ -75,13 +74,25 @@ struct BlockGemm static_assert(ckb::BlockGemmDescriptor); // Describe Aand B block transfer thread cluster lengths. +template struct BlockTransfer { size_t k0; size_t m_n; size_t k1; + size_t k_batch_size; }; -static_assert(ckb::BlockTransferDescriptor); + +// Specialization for forward (IsBwd = false) +template <> +struct BlockTransfer +{ + size_t k0; + size_t m_n; + size_t k1; +}; +static_assert(ckb::BlockTransferDescriptor>); +static_assert(ckb::BlockTransferDescriptor>); // Describe C block transfer thread cluster lengths. 
struct ThreadCluster @@ -111,31 +122,35 @@ struct Epilogue }; static_assert(EpilogueDescriptor); +template struct AccessOrder { - std::array order; + std::array order; }; -static_assert(AccessOrderDescriptor); +static_assert(AccessOrderDescriptor>); +static_assert(AccessOrderDescriptor>); -struct TransferAB +template +struct InputTransfer { - BlockTransfer block_transfer; + BlockTransfer block_transfer; LdsTransfer lds_transfer; - AccessOrder block_transfer_access_order; - AccessOrder src_access_order; + std::conditional_t, AccessOrder<3>> block_transfer_access_order; + std::conditional_t, AccessOrder<3>> src_access_order; }; -struct TransferC +struct OutputTransfer { ThreadCluster thread_cluster_dims; Epilogue epilogue; }; -struct TransferABC +template +struct Transfer { - TransferAB a; - TransferAB b; - TransferC c; + InputTransfer a; + InputTransfer b; + OutputTransfer c; }; // DL-specific descriptors @@ -198,9 +213,10 @@ struct WmmaGemm_ GridwiseWmmaGemm gridwise_gemm; }; +template struct Transfer_ { - TransferABC transfer; + Transfer transfer; }; struct ConvSpecializationFwd_ @@ -380,7 +396,8 @@ struct ConvAlgorithmTemplate : Components... template constexpr auto with_transfer(const T& t) const { - static_assert(std::is_base_of_v); + static_assert(std::is_base_of_v, ConvAlgorithmTemplate> || + std::is_base_of_v, ConvAlgorithmTemplate>); auto result = *this; result.transfer = t; return result; @@ -511,13 +528,13 @@ struct ConvAlgorithmTemplate : Components... 
// Algorithm types using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = - ConvAlgorithmTemplate; + ConvAlgorithmTemplate, ConvSpecializationFwd_, Prefetch_>; using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = - ConvAlgorithmTemplate; + ConvAlgorithmTemplate, ConvSpecializationFwd_, BlockGemm_>; using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle = - ConvAlgorithmTemplate; + ConvAlgorithmTemplate, ConvSpecializationFwd_, Prefetch_>; using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = ConvAlgorithmTemplate; using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle = - ConvAlgorithmTemplate; + ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, TransposeParams_>; } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index d176506526f..956f65f4531 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -39,7 +39,7 @@ constexpr DlTransferABC DlFwdTransfer{.a = .dst_scalar_per_vector = 4}, }}; -constexpr TransferABC Transfer_4x64x1{ +constexpr Transfer<> Transfer_4x64x1{ .a = { .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, @@ -72,28 +72,29 @@ constexpr TransferABC Transfer_4x64x1{ }, }; -constexpr TransferABC BwdTransfer_4x64x1{ +constexpr bool BWD = true; +constexpr Transfer BwdTransfer_4x64x1{ .a = { - .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, + .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1}, .lds_transfer = {.src_vector_dim = 2, .src_scalar_per_vector = 2, .lds_dst_scalar_per_vector = 4, .is_direct_load = false, .lds_padding = true}, - .block_transfer_access_order = {3, 1, 2}, - .src_access_order = {2, 1, 3}, + .block_transfer_access_order = {0, 3, 1, 2}, + .src_access_order = {0, 2, 1, 3}, }, .b = { - .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, + .block_transfer = {.k0 = 4, 
.m_n = 64, .k1 = 1, .k_batch_size = 1}, .lds_transfer = {.src_vector_dim = 2, .src_scalar_per_vector = 2, .lds_dst_scalar_per_vector = 4, .is_direct_load = false, .lds_padding = true}, - .block_transfer_access_order = {3, 1, 2}, - .src_access_order = {2, 1, 3}, + .block_transfer_access_order = {0, 3, 1, 2}, + .src_access_order = {0, 2, 1, 3}, }, .c = { @@ -105,7 +106,7 @@ constexpr TransferABC BwdTransfer_4x64x1{ }, }; -constexpr TransferABC Transfer_4x64x1_fp8{ +constexpr Transfer<> Transfer_4x64x1_fp8{ .a = { .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, @@ -138,7 +139,7 @@ constexpr TransferABC Transfer_4x64x1_fp8{ }, }; -constexpr TransferABC Transfer_4x16x1{ +constexpr Transfer<> Transfer_4x16x1{ .a = { .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1}, @@ -172,7 +173,7 @@ constexpr TransferABC Transfer_4x16x1{ }, }; -constexpr TransferABC Transfer_4x32x1{ +constexpr Transfer<> Transfer_4x32x1{ .a = { .block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1}, @@ -206,7 +207,7 @@ constexpr TransferABC Transfer_4x32x1{ }; constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_4x4_per_wave{ - .k0_per_block = 8, .k1 = 8, + .k1 = 8, .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}}; constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_4x4_per_wave{ @@ -244,6 +245,9 @@ constexpr ThreadBlock ThreadBlock_256_128x128x32{.block_size = 256, constexpr ThreadBlock ThreadBlock_256_128x128x16{.block_size = 256, .tile_size = {.m = 128, .n = 128, .k = 16}}; +constexpr ThreadBlock ThreadBlock_256_128x128x8{.block_size = 256, + .tile_size = {.m = 128, .n = 128, .k = 8}}; + constexpr ThreadBlock ThreadBlock_64_64x32x32{.block_size = 64, .tile_size = {.m = 64, .n = 32, .k = 32}}; diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index c3afe2bd4e0..f7096f27f81 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ 
b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -89,7 +89,7 @@ template <> inline std::string to_string(GridwiseBwdXdlGemm t) { std::ostringstream oss; - oss << t.k0_per_block << "," << t.k1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl << "," + oss << t.k1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl << "," << t.xdl_params.m_xdl_per_wave << "," << t.xdl_params.n_xdl_per_wave; return oss.str(); } @@ -120,10 +120,17 @@ inline std::string to_string(BlockGemm t) return oss.str(); } -template <> -inline std::string to_string(BlockTransfer t) +template +inline std::string to_string(BlockTransfer t) { - return array_to_seq(std::array{t.k0, t.m_n, t.k1}); + if constexpr (IsBwd) + { + return array_to_seq(std::array{t.k_batch_size, t.k0, t.m_n, t.k1}); + } + else + { + return array_to_seq(std::array{t.k0, t.m_n, t.k1}); + } } template <> @@ -143,14 +150,14 @@ inline std::string to_string(LdsTransfer t) return oss.str(); } -template <> -inline std::string to_string(AccessOrder t) +template +inline std::string to_string(AccessOrder t) { return array_to_seq(t.order); } -template <> -inline std::string to_string(TransferAB t) +template +inline std::string to_string(InputTransfer t) { std::ostringstream oss; oss << to_string(t.block_transfer) << "," << to_string(t.block_transfer_access_order) << "," @@ -161,7 +168,7 @@ inline std::string to_string(TransferAB t) } template <> -inline std::string to_string(TransferC t) +inline std::string to_string(OutputTransfer t) { std::ostringstream oss; oss << t.epilogue.m_xdl_per_wave_per_shuffle << "," << t.epilogue.n_per_wave_per_shuffle << "," @@ -169,8 +176,8 @@ inline std::string to_string(TransferC t) return oss.str(); } -template <> -inline std::string to_string(TransferABC t) +template +inline std::string to_string(Transfer t) { std::ostringstream oss; oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c); @@ -260,8 +267,8 @@ inline std::string 
to_string(WmmaGemm_ t) return to_string(t.gridwise_gemm); } -template <> -inline std::string to_string(Transfer_ t) +template +inline std::string to_string(Transfer_ t) { return to_string(t.transfer); } @@ -323,7 +330,7 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast(t)); + << "," << to_string(static_cast>(t)); return oss.str(); } @@ -333,7 +340,7 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast(t)); + << "," << to_string(static_cast>(t)); return oss.str(); } @@ -343,7 +350,7 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast(t)); + << "," << to_string(static_cast>(t)); return oss.str(); } @@ -371,8 +378,9 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast(t)); + << "," << to_string(static_cast>(t)); return oss.str(); } From 77e10c7b08f7bea336c3c836411df32fcf75295d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Tue, 23 Dec 2025 10:27:38 -0500 Subject: [PATCH 17/81] Concept improvements. --- .../builder/conv_algorithm_concepts.hpp | 25 +++++++++++-------- .../builder/factory/conv_algorithms.hpp | 6 ++--- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 4d81becfb5a..4feb09cfe93 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -52,23 +52,20 @@ concept GridwiseWmmaGemmDescriptor = requires(T t) { }; +// Concept for vectorized data transfer for convolution input tensors. 
template -concept HasGemmKBatch = requires(T t) { - { t.k_batch_size}; +concept BlockTransferDescriptor = requires(T t) { + { t.k0 } -> std::convertible_to; + { t.m_n } -> std::convertible_to; + { t.k1 } -> std::convertible_to; }; -// Concept to check if GEMM k batch size is specified. template -concept GemmKBatchSizeWellDefinedIfProvided = - !HasGemmKBatch || requires(T t) { {t.k_batch_size} -> std::convertible_to; }; - -// Concept for vectorized data transfer for convolution input tensors. -template -concept BlockTransferDescriptor = requires(T t) { +concept BlockTransferDescriptorBwd = requires(T t) { { t.k0 } -> std::convertible_to; { t.m_n } -> std::convertible_to; { t.k1 } -> std::convertible_to; - GemmKBatchSizeWellDefinedIfProvided; + { t.k_batch_size } -> std::convertible_to;; }; // Concept for thread cluster dimensions for GEMM output tensor. @@ -210,6 +207,14 @@ concept SpecifiesBlockTransfer = requires(T t) { { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; }; +// Concept to check if a struct specifies convolution input and output block transfer info (Bwd direction). +template +concept SpecifiesBlockTransferBwd = requires(T t) { + { T::transfer.a.block_transfer } -> BlockTransferDescriptorBwd; + { T::transfer.b.block_transfer } -> BlockTransferDescriptorBwd; + { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; +}; + // Concept to check if a struct specifies convolution scalar per vector infor for A, B and C. 
template concept SpecifiesTileTransfer = requires(T t) { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index 4f268b05e92..71e3d03da36 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -242,7 +242,7 @@ template struct BwdXdlAlgorithm { CHECK_CONCEPT(T, ConvAlgorithmDescriptor) CHECK_CONCEPT(T, SpecifiesThreadBlock) - CHECK_CONCEPT(T, SpecifiesBlockTransfer) + CHECK_CONCEPT(T, SpecifiesBlockTransferBwd) CHECK_CONCEPT(T, SpecifiesLdsTransfer) CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder) CHECK_CONCEPT(T, SpecifiesSourceAccessOrder) @@ -252,7 +252,7 @@ struct BwdXdlAlgorithm { static constexpr bool c1 = c_ConvAlgorithmDescriptor; static constexpr bool c2 = c_SpecifiesThreadBlock; - static constexpr bool c3 = c_SpecifiesBlockTransfer; + static constexpr bool c3 = c_SpecifiesBlockTransferBwd; static constexpr bool c4 = c_SpecifiesLdsTransfer; static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder; static constexpr bool c6 = c_SpecifiesSourceAccessOrder; @@ -269,7 +269,7 @@ struct BwdXdlAlgorithm { "Concepts for BwdXdl Algorithm:\n") + DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + - DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + + DIAGNOSTIC_LINE(SpecifiesBlockTransferBwd) + DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + From ff2fdd8acc919b806d5d58725facac87992798b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Mon, 29 Dec 2025 04:26:06 -0500 Subject: [PATCH 18/81] Improve concept diagnostics. 
--- .../builder/conv_algorithm_concepts.hpp | 2 +- .../builder/conv_algorithm_diagnostics.hpp | 235 +++++++++++------- 2 files changed, 147 insertions(+), 90 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 4feb09cfe93..ac279a4de8a 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -65,7 +65,7 @@ concept BlockTransferDescriptorBwd = requires(T t) { { t.k0 } -> std::convertible_to; { t.m_n } -> std::convertible_to; { t.k1 } -> std::convertible_to; - { t.k_batch_size } -> std::convertible_to;; + { t.k_batch_size } -> std::convertible_to; }; // Concept for thread cluster dimensions for GEMM output tensor. diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp index 3b75eb00e9b..6528e0fd444 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp @@ -69,7 +69,7 @@ consteval auto diagnose_thread_block_descriptor() -> std::string { using BlockSizeType = decltype(std::declval().block_size); constexpr bool convertible = std::convertible_to; msg += " → thread_block.block_size: " + std::string(CHECK_MARK(convertible)) + - std::string(get_type_info()) + "\n"; + (convertible ? "" : std::string(get_type_info())) + "\n"; } else { msg += " → thread_block.block_size: [✗] (missing member)\n"; } @@ -78,7 +78,7 @@ consteval auto diagnose_thread_block_descriptor() -> std::string { using TileMType = decltype(std::declval().tile_size.m); constexpr bool convertible = std::convertible_to; msg += " → thread_block.tile_size.m: " + std::string(CHECK_MARK(convertible)) + - std::string(get_type_info()) + "\n"; + (convertible ? 
"" : std::string(get_type_info())) + "\n"; } else { msg += " → thread_block.tile_size.m: [✗] (missing member)\n"; } @@ -87,7 +87,7 @@ consteval auto diagnose_thread_block_descriptor() -> std::string { using TileNType = decltype(std::declval().tile_size.n); constexpr bool convertible = std::convertible_to; msg += " → thread_block.tile_size.n: " + std::string(CHECK_MARK(convertible)) + - std::string(get_type_info()) + "\n"; + (convertible ? "" : std::string(get_type_info())) + "\n"; } else { msg += " → thread_block.tile_size.n: [✗] (missing member)\n"; } @@ -96,7 +96,7 @@ consteval auto diagnose_thread_block_descriptor() -> std::string { using TileKType = decltype(std::declval().tile_size.k); constexpr bool convertible = std::convertible_to; msg += " → thread_block.tile_size.k: " + std::string(CHECK_MARK(convertible)) + - std::string(get_type_info()) + "\n"; + (convertible ? "" : std::string(get_type_info())) + "\n"; } else { msg += " → thread_block.tile_size.k: [✗] (missing member)\n"; } @@ -114,7 +114,7 @@ consteval auto diagnose_xdl_params() -> std::string { using MPerXdlType = decltype(std::declval().m_per_xdl); constexpr bool convertible = std::convertible_to; msg += " → xdl_params.m_per_xdl: " + std::string(CHECK_MARK(convertible)) + - std::string(get_type_info()) + "\n"; + (convertible ? "" : std::string(get_type_info())) + "\n"; } else { msg += " → xdl_params.m_per_xdl: [✗] (missing member)\n"; } @@ -123,7 +123,7 @@ consteval auto diagnose_xdl_params() -> std::string { using NPerXdlType = decltype(std::declval().n_per_xdl); constexpr bool convertible = std::convertible_to; msg += " → xdl_params.n_per_xdl: " + std::string(CHECK_MARK(convertible)) + - std::string(get_type_info()) + "\n"; + (convertible ? 
"" : std::string(get_type_info())) + "\n"; } else { msg += " → xdl_params.n_per_xdl: [✗] (missing member)\n"; } @@ -132,7 +132,7 @@ consteval auto diagnose_xdl_params() -> std::string { using MXdlPerWaveType = decltype(std::declval().m_xdl_per_wave); constexpr bool convertible = std::convertible_to; msg += " → xdl_params.m_xdl_per_wave: " + std::string(CHECK_MARK(convertible)) + - std::string(get_type_info()) + "\n"; + (convertible ? "" : std::string(get_type_info())) + "\n"; } else { msg += " → xdl_params.m_xdl_per_wave: [✗] (missing member)\n"; } @@ -141,7 +141,7 @@ consteval auto diagnose_xdl_params() -> std::string { using NXdlPerWaveType = decltype(std::declval().n_xdl_per_wave); constexpr bool convertible = std::convertible_to; msg += " → xdl_params.n_xdl_per_wave: " + std::string(CHECK_MARK(convertible)) + - std::string(get_type_info()) + "\n"; + (convertible ? "" : std::string(get_type_info())) + "\n"; } else { msg += " → xdl_params.n_xdl_per_wave: [✗] (missing member)\n"; } @@ -158,7 +158,7 @@ consteval auto diagnose_block_transfer(const char* prefix) -> std::string { using K0Type = decltype(std::declval().k0); constexpr bool convertible = std::convertible_to; msg += std::string(" → ") + prefix + ".k0: " + std::string(CHECK_MARK(convertible)) + - std::string(get_type_info()) + "\n"; + (convertible ? "" : std::string(get_type_info())) + "\n"; } else { msg += std::string(" → ") + prefix + ".k0: [✗] (missing member)\n"; } @@ -167,7 +167,7 @@ consteval auto diagnose_block_transfer(const char* prefix) -> std::string { using MNType = decltype(std::declval().m_n); constexpr bool convertible = std::convertible_to; msg += std::string(" → ") + prefix + ".m_n: " + std::string(CHECK_MARK(convertible)) + - std::string(get_type_info()) + "\n"; + (convertible ? 
"" : std::string(get_type_info())) + "\n"; } else { msg += std::string(" → ") + prefix + ".m_n: [✗] (missing member)\n"; } @@ -176,17 +176,54 @@ consteval auto diagnose_block_transfer(const char* prefix) -> std::string { using K1Type = decltype(std::declval().k1); constexpr bool convertible = std::convertible_to; msg += std::string(" → ") + prefix + ".k1: " + std::string(CHECK_MARK(convertible)) + - std::string(get_type_info()) + "\n"; + (convertible ? "" : std::string(get_type_info())) + "\n"; } else { msg += std::string(" → ") + prefix + ".k1: [✗] (missing member)\n"; } - // k_batch_size is optional - only report if it exists and has wrong type + return msg; +} + +// BlockTransferDescriptorBwd diagnostics (requires k_batch_size) +template +consteval auto diagnose_block_transfer_bwd(const char* prefix) -> std::string { + std::string msg; + + if constexpr (requires(BT t) { t.k0; }) { + using K0Type = decltype(std::declval().k0); + constexpr bool convertible = std::convertible_to; + msg += std::string(" → ") + prefix + ".k0: " + std::string(CHECK_MARK(convertible)) + + (convertible ? "" : std::string(get_type_info())) + "\n"; + } else { + msg += std::string(" → ") + prefix + ".k0: [✗] (missing member)\n"; + } + + if constexpr (requires(BT t) { t.m_n; }) { + using MNType = decltype(std::declval().m_n); + constexpr bool convertible = std::convertible_to; + msg += std::string(" → ") + prefix + ".m_n: " + std::string(CHECK_MARK(convertible)) + + (convertible ? "" : std::string(get_type_info())) + "\n"; + } else { + msg += std::string(" → ") + prefix + ".m_n: [✗] (missing member)\n"; + } + + if constexpr (requires(BT t) { t.k1; }) { + using K1Type = decltype(std::declval().k1); + constexpr bool convertible = std::convertible_to; + msg += std::string(" → ") + prefix + ".k1: " + std::string(CHECK_MARK(convertible)) + + (convertible ? 
"" : std::string(get_type_info())) + "\n"; + } else { + msg += std::string(" → ") + prefix + ".k1: [✗] (missing member)\n"; + } + + // k_batch_size is required for Bwd descriptor if constexpr (requires(BT t) { t.k_batch_size; }) { using KBatchType = decltype(std::declval().k_batch_size); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".k_batch_size (optional): " + std::string(CHECK_MARK(convertible)) + - std::string(get_type_info()) + "\n"; + msg += std::string(" → ") + prefix + ".k_batch_size: " + std::string(CHECK_MARK(convertible)) + + (convertible ? "" : std::string(get_type_info())) + "\n"; + } else { + msg += std::string(" → ") + prefix + ".k_batch_size: [✗] (missing member)\n"; } return msg; @@ -201,7 +238,7 @@ consteval auto diagnose_lds_transfer(const char* prefix) -> std::string { using SrcVectorDimType = decltype(std::declval().src_vector_dim); constexpr bool convertible = std::convertible_to; msg += std::string(" → ") + prefix + ".src_vector_dim: " + std::string(CHECK_MARK(convertible)) + - std::string(get_type_info()) + "\n"; + (convertible ? "" : std::string(get_type_info())) + "\n"; } else { msg += std::string(" → ") + prefix + ".src_vector_dim: [✗] (missing member)\n"; } @@ -210,7 +247,7 @@ consteval auto diagnose_lds_transfer(const char* prefix) -> std::string { using SrcScalarType = decltype(std::declval().src_scalar_per_vector); constexpr bool convertible = std::convertible_to; msg += std::string(" → ") + prefix + ".src_scalar_per_vector: " + std::string(CHECK_MARK(convertible)) + - std::string(get_type_info()) + "\n"; + (convertible ? 
"" : std::string(get_type_info())) + "\n"; } else { msg += std::string(" → ") + prefix + ".src_scalar_per_vector: [✗] (missing member)\n"; } @@ -219,7 +256,7 @@ consteval auto diagnose_lds_transfer(const char* prefix) -> std::string { using LdsDstScalarType = decltype(std::declval().lds_dst_scalar_per_vector); constexpr bool convertible = std::convertible_to; msg += std::string(" → ") + prefix + ".lds_dst_scalar_per_vector: " + std::string(CHECK_MARK(convertible)) + - std::string(get_type_info()) + "\n"; + (convertible ? "" : std::string(get_type_info())) + "\n"; } else { msg += std::string(" → ") + prefix + ".lds_dst_scalar_per_vector: [✗] (missing member)\n"; } @@ -228,7 +265,7 @@ consteval auto diagnose_lds_transfer(const char* prefix) -> std::string { using IsDirectLoadType = decltype(std::declval().is_direct_load); constexpr bool convertible = std::convertible_to; msg += std::string(" → ") + prefix + ".is_direct_load: " + std::string(CHECK_MARK(convertible)) + - std::string(get_type_info()) + "\n"; + (convertible ? "" : std::string(get_type_info())) + "\n"; } else { msg += std::string(" → ") + prefix + ".is_direct_load: [✗] (missing member)\n"; } @@ -237,7 +274,7 @@ consteval auto diagnose_lds_transfer(const char* prefix) -> std::string { using LdsPaddingType = decltype(std::declval().lds_padding); constexpr bool convertible = std::convertible_to; msg += std::string(" → ") + prefix + ".lds_padding: " + std::string(CHECK_MARK(convertible)) + - std::string(get_type_info()) + "\n"; + (convertible ? 
"" : std::string(get_type_info())) + "\n"; } else { msg += std::string(" → ") + prefix + ".lds_padding: [✗] (missing member)\n"; } @@ -254,7 +291,7 @@ consteval auto diagnose_thread_cluster(const char* prefix) -> std::string { using MBlockType = decltype(std::declval().m_block); constexpr bool convertible = std::convertible_to; msg += std::string(" → ") + prefix + ".m_block: " + std::string(CHECK_MARK(convertible)) + - std::string(get_type_info()) + "\n"; + (convertible ? "" : std::string(get_type_info())) + "\n"; } else { msg += std::string(" → ") + prefix + ".m_block: [✗] (missing member)\n"; } @@ -263,7 +300,7 @@ consteval auto diagnose_thread_cluster(const char* prefix) -> std::string { using MWaveType = decltype(std::declval().m_wave_per_xdl); constexpr bool convertible = std::convertible_to; msg += std::string(" → ") + prefix + ".m_wave_per_xdl: " + std::string(CHECK_MARK(convertible)) + - std::string(get_type_info()) + "\n"; + (convertible ? "" : std::string(get_type_info())) + "\n"; } else { msg += std::string(" → ") + prefix + ".m_wave_per_xdl: [✗] (missing member)\n"; } @@ -272,7 +309,7 @@ consteval auto diagnose_thread_cluster(const char* prefix) -> std::string { using NBlockType = decltype(std::declval().n_block); constexpr bool convertible = std::convertible_to; msg += std::string(" → ") + prefix + ".n_block: " + std::string(CHECK_MARK(convertible)) + - std::string(get_type_info()) + "\n"; + (convertible ? "" : std::string(get_type_info())) + "\n"; } else { msg += std::string(" → ") + prefix + ".n_block: [✗] (missing member)\n"; } @@ -281,7 +318,7 @@ consteval auto diagnose_thread_cluster(const char* prefix) -> std::string { using NWaveType = decltype(std::declval().n_wave_per_xdl); constexpr bool convertible = std::convertible_to; msg += std::string(" → ") + prefix + ".n_wave_per_xdl: " + std::string(CHECK_MARK(convertible)) + - std::string(get_type_info()) + "\n"; + (convertible ? 
"" : std::string(get_type_info())) + "\n"; } else { msg += std::string(" → ") + prefix + ".n_wave_per_xdl: [✗] (missing member)\n"; } @@ -300,7 +337,7 @@ consteval auto diagnose_access_order(const char* prefix) -> std::string { constexpr bool convertible_4 = std::convertible_to>; constexpr bool convertible = convertible_3 || convertible_4; msg += std::string(" → ") + prefix + ".order: " + std::string(CHECK_MARK(convertible)) + - std::string(get_type_info()) + "\n"; + (convertible ? "" : std::string(get_type_info())) + "\n"; } else { msg += std::string(" → ") + prefix + ".order: [✗] (missing member)\n"; } @@ -317,7 +354,7 @@ consteval auto diagnose_epilogue(const char* prefix) -> std::string { using MXdlType = decltype(std::declval().m_xdl_per_wave_per_shuffle); constexpr bool convertible = std::convertible_to; msg += std::string(" → ") + prefix + ".m_xdl_per_wave_per_shuffle: " + std::string(CHECK_MARK(convertible)) + - std::string(get_type_info()) + "\n"; + (convertible ? "" : std::string(get_type_info())) + "\n"; } else { msg += std::string(" → ") + prefix + ".m_xdl_per_wave_per_shuffle: [✗] (missing member)\n"; } @@ -326,7 +363,7 @@ consteval auto diagnose_epilogue(const char* prefix) -> std::string { using NPerWaveType = decltype(std::declval().n_per_wave_per_shuffle); constexpr bool convertible = std::convertible_to; msg += std::string(" → ") + prefix + ".n_per_wave_per_shuffle: " + std::string(CHECK_MARK(convertible)) + - std::string(get_type_info()) + "\n"; + (convertible ? 
"" : std::string(get_type_info())) + "\n"; } else { msg += std::string(" → ") + prefix + ".n_per_wave_per_shuffle: [✗] (missing member)\n"; } @@ -335,7 +372,7 @@ consteval auto diagnose_epilogue(const char* prefix) -> std::string { using ScalarType = decltype(std::declval().scalar_per_vector); constexpr bool convertible = std::convertible_to; msg += std::string(" → ") + prefix + ".scalar_per_vector: " + std::string(CHECK_MARK(convertible)) + - std::string(get_type_info()) + "\n"; + (convertible ? "" : std::string(get_type_info())) + "\n"; } else { msg += std::string(" → ") + prefix + ".scalar_per_vector: [✗] (missing member)\n"; } @@ -353,8 +390,8 @@ consteval auto detailed_diagnostic_ConvAlgorithmDescriptor() -> std::string { template consteval auto detailed_diagnostic_SpecifiesThreadBlock() -> std::string { - if constexpr (!requires { T::thread_block; }) { - return " → T::thread_block member: [✗] (not found)\n"; + if constexpr (!requires { { T::thread_block } -> ThreadBlockDescriptor; }) { + return " → T::thread_block member: [✗] (missing or wrong type)\n"; } else { return " → T::thread_block member: [✓]\n" + detail::diagnose_thread_block_descriptor(); @@ -365,8 +402,8 @@ template consteval auto detailed_diagnostic_SpecifiesGridwiseFwdXdlGemm() -> std::string { std::string msg; - if constexpr (!requires { T::gridwise_gemm; }) { - return " → T::gridwise_gemm member: [✗] (missing member)\n"; + if constexpr (!requires(T t) { { t.gridwise_gemm } -> GridwiseFwdXdlGemmDescriptor; }) { + return " → T::gridwise_gemm member: [✗] (missing or wrong type)\n"; } msg += " → T::gridwise_gemm member: [✓]\n"; @@ -376,7 +413,7 @@ consteval auto detailed_diagnostic_SpecifiesGridwiseFwdXdlGemm() -> std::string using AK1Type = decltype(std::declval().ak1); constexpr bool convertible = std::convertible_to; msg += " → gridwise_gemm.ak1: " + std::string(CHECK_MARK(convertible)) + - std::string(detail::get_type_info()) + "\n"; + (convertible ? 
"" : std::string(detail::get_type_info())) + "\n"; } else { msg += " → gridwise_gemm.ak1: [✗] (missing member)\n"; } @@ -385,7 +422,7 @@ consteval auto detailed_diagnostic_SpecifiesGridwiseFwdXdlGemm() -> std::string using BK1Type = decltype(std::declval().bk1); constexpr bool convertible = std::convertible_to; msg += " → gridwise_gemm.bk1: " + std::string(CHECK_MARK(convertible)) + - std::string(detail::get_type_info()) + "\n"; + (convertible ? "" : std::string(detail::get_type_info())) + "\n"; } else { msg += " → gridwise_gemm.bk1: [✗] (missing member)\n"; } @@ -404,8 +441,8 @@ template consteval auto detailed_diagnostic_SpecifiesGridwiseBwdXdlGemm() -> std::string { std::string msg; - if constexpr (!requires { T::gridwise_gemm; }) { - return " → T::gridwise_gemm member: [✗] (missing member)\n"; + if constexpr (!requires(T t) { { t.gridwise_gemm } -> GridwiseBwdXdlGemmDescriptor; }) { + return " → T::gridwise_gemm member: [✗] (missing or wrong type)\n"; } msg += " → T::gridwise_gemm member: [✓]\n"; @@ -415,7 +452,7 @@ consteval auto detailed_diagnostic_SpecifiesGridwiseBwdXdlGemm() -> std::string using K1Type = decltype(std::declval().k1); constexpr bool convertible = std::convertible_to; msg += " → gridwise_gemm.k1: " + std::string(CHECK_MARK(convertible)) + - std::string(detail::get_type_info()) + "\n"; + (convertible ? 
"" : std::string(detail::get_type_info())) + "\n"; } else { msg += " → gridwise_gemm.k1: [✗] (missing member)\n"; } @@ -441,30 +478,54 @@ consteval auto detailed_diagnostic_SpecifiesBlockTransfer() -> std::string { return msg; } - constexpr bool has_a = requires { T::transfer.a; }; - constexpr bool has_b = requires { T::transfer.b; }; - constexpr bool has_c = requires { T::transfer.c; }; - + constexpr bool has_a = requires { { T::transfer.a.block_transfer } -> BlockTransferDescriptor; }; msg += " → T::transfer.a: " + std::string(CHECK_MARK(has_a)) + "\n"; + if constexpr (!has_a) { + msg += " → T::transfer.a.block_transfer: [✗] (missing or wrong type)\n"; + } + + constexpr bool has_b = requires { { T::transfer.b.block_transfer } -> BlockTransferDescriptor; }; msg += " → T::transfer.b: " + std::string(CHECK_MARK(has_b)) + "\n"; + if constexpr (!has_b) { + msg += " → T::transfer.b.block_transfer: [✗] (missing or wrong type)\n"; + } + + constexpr bool has_c = requires { { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; }; msg += " → T::transfer.c: " + std::string(CHECK_MARK(has_c)) + "\n"; + if constexpr (!has_c) { + msg += " → T::transfer.c.thread_cluster_dims: [✗] (missing or wrong type)\n"; + } - if constexpr (has_a && requires { T::transfer.a.block_transfer; }) { - msg += detail::diagnose_block_transfer("transfer.a.block_transfer"); - } else if constexpr (has_a) { - msg += " → T::transfer.a.block_transfer: [✗] (missing)\n"; + return msg; +} + +template +consteval auto detailed_diagnostic_SpecifiesBlockTransferBwd() -> std::string { + std::string msg; + + constexpr bool has_transfer = requires { T::transfer; }; + msg += " → T::transfer member: " + std::string(CHECK_MARK(has_transfer)) + "\n"; + + if constexpr (!has_transfer) { + return msg; } - if constexpr (has_b && requires { T::transfer.b.block_transfer; }) { - msg += detail::diagnose_block_transfer("transfer.b.block_transfer"); - } else if constexpr (has_b) { - msg += " → 
T::transfer.b.block_transfer: [✗] (missing)\n"; + constexpr bool has_a = requires { { T::transfer.a.block_transfer } -> BlockTransferDescriptorBwd; }; + msg += " → T::transfer.a: " + std::string(CHECK_MARK(has_a)) + "\n"; + if constexpr (!has_a) { + msg += " → T::transfer.a.block_transfer: [✗] (missing or wrong type)\n"; } - if constexpr (has_c && requires { T::transfer.c.thread_cluster_dims; }) { - msg += detail::diagnose_thread_cluster("transfer.c.thread_cluster_dims"); - } else if constexpr (has_c) { - msg += " → T::transfer.c.thread_cluster_dims: [✗] (missing)\n"; + constexpr bool has_b = requires { { T::transfer.b.block_transfer } -> BlockTransferDescriptorBwd; }; + msg += " → T::transfer.b: " + std::string(CHECK_MARK(has_b)) + "\n"; + if constexpr (!has_b) { + msg += " → T::transfer.b.block_transfer: [✗] (missing or wrong type)\n"; + } + + constexpr bool has_c = requires { { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; }; + msg += " → T::transfer.c: " + std::string(CHECK_MARK(has_c)) + "\n"; + if constexpr (!has_c) { + msg += " → T::transfer.c.thread_cluster_dims: [✗] (missing or wrong type)\n"; } return msg; @@ -528,8 +589,8 @@ template consteval auto detailed_diagnostic_SpecifiesBlockGemm() -> std::string { std::string msg; - if constexpr (!requires { T::block_gemm; }) { - return " → T::block_gemm: [✗] (missing member)\n"; + if constexpr (!requires { { T::block_gemm } -> BlockGemmDescriptor; }) { + return " → T::block_gemm: [✗] (missing or wrong type)\n"; } msg += " → T::block_gemm member: [✓]\n"; @@ -538,7 +599,7 @@ consteval auto detailed_diagnostic_SpecifiesBlockGemm() -> std::string { using PipelineType = decltype(T::block_gemm.pipeline_version); constexpr bool convertible = std::convertible_to; msg += " → block_gemm.pipeline_version: " + std::string(CHECK_MARK(convertible)) + - std::string(detail::get_type_info()) + "\n"; + (convertible ? 
"" : std::string(detail::get_type_info())) + "\n"; } else { msg += " → block_gemm.pipeline_version: [✗] (missing member)\n"; } @@ -547,7 +608,7 @@ consteval auto detailed_diagnostic_SpecifiesBlockGemm() -> std::string { using SchedulerType = decltype(T::block_gemm.scheduler); constexpr bool convertible = std::convertible_to; msg += " → block_gemm.scheduler: " + std::string(CHECK_MARK(convertible)) + - std::string(detail::get_type_info()) + "\n"; + (convertible ? "" : std::string(detail::get_type_info())) + "\n"; } else { msg += " → block_gemm.scheduler: [✗] (missing member)\n"; } @@ -561,7 +622,7 @@ consteval auto detailed_diagnostic_SpecifiesFwdConvSpecialization() -> std::stri using FwdSpecType = decltype(T::fwd_specialization); constexpr bool convertible = std::convertible_to; return " → T::fwd_specialization: " + std::string(CHECK_MARK(convertible)) + - std::string(detail::get_type_info()) + "\n"; + (convertible ? "" : std::string(detail::get_type_info())) + "\n"; } else { return " → T::fwd_specialization: [✗] (missing member)\n"; } @@ -573,7 +634,7 @@ consteval auto detailed_diagnostic_SpecifiesBwdWeightConvSpecialization() -> std using BwdSpecType = decltype(T::bwd_weight_specialization); constexpr bool convertible = std::convertible_to; return " → T::bwd_weight_specialization: " + std::string(CHECK_MARK(convertible)) + - std::string(detail::get_type_info()) + "\n"; + (convertible ? "" : std::string(detail::get_type_info())) + "\n"; } else { return " → T::bwd_weight_specialization: [✗] (missing member)\n"; } @@ -585,7 +646,7 @@ consteval auto detailed_diagnostic_SpecifiesGemmSpecialization() -> std::string using GemmSpecType = decltype(T::gemm_specialization); constexpr bool convertible = std::convertible_to; return " → T::gemm_specialization: " + std::string(CHECK_MARK(convertible)) + - std::string(detail::get_type_info()) + "\n"; + (convertible ? 
"" : std::string(detail::get_type_info())) + "\n"; } else { return " → T::gemm_specialization: [✗] (missing member)\n"; } @@ -597,7 +658,7 @@ consteval auto detailed_diagnostic_SpecifiesNumPrefetchStages() -> std::string { using NumPrefetchType = decltype(T::num_gemm_k_prefetch_stages); constexpr bool convertible = std::convertible_to; return " → T::num_gemm_k_prefetch_stages: " + std::string(CHECK_MARK(convertible)) + - std::string(detail::get_type_info()) + "\n"; + (convertible ? "" : std::string(detail::get_type_info())) + "\n"; } else { return " → T::num_gemm_k_prefetch_stages: [✗] (missing member)\n"; } @@ -609,7 +670,7 @@ consteval auto detailed_diagnostic_SpecifiesNumGroupsToMerge() -> std::string { using NumGroupsType = decltype(T::num_groups_to_merge); constexpr bool convertible = std::convertible_to; return " → T::num_groups_to_merge: " + std::string(CHECK_MARK(convertible)) + - std::string(detail::get_type_info()) + "\n"; + (convertible ? "" : std::string(detail::get_type_info())) + "\n"; } else { return " → T::num_groups_to_merge: [✗] (missing member)\n"; } @@ -621,7 +682,7 @@ consteval auto detailed_diagnostic_SpecifiesLoopScheduler() -> std::string { using LoopSchedulerType = decltype(T::loop_scheduler); constexpr bool convertible = std::convertible_to; return " → T::loop_scheduler: " + std::string(CHECK_MARK(convertible)) + - std::string(detail::get_type_info()) + "\n"; + (convertible ? "" : std::string(detail::get_type_info())) + "\n"; } else { return " → T::loop_scheduler: [✗] (missing member)\n"; } @@ -634,7 +695,7 @@ consteval auto detailed_diagnostic_SpecifiesLargeTensorSupport() -> std::string using SpecType = decltype(T::specialization); constexpr bool convertible = std::convertible_to; msg += " → T::specialization: " + std::string(CHECK_MARK(convertible)) + - std::string(detail::get_type_info()) + "\n"; + (convertible ? 
"" : std::string(detail::get_type_info())) + "\n"; if constexpr (convertible) { constexpr bool is_large_tensor = (T::specialization == ConvAlgorithmSpecialization::LARGE_TENSOR); @@ -655,7 +716,7 @@ consteval auto detailed_diagnostic_SpecifiesTransposeTransfer() -> std::string { using SrcType = decltype(T::max_transpose_transfer_src_scalar_per_vector); constexpr bool convertible = std::convertible_to; msg += " → T::max_transpose_transfer_src_scalar_per_vector: " + std::string(CHECK_MARK(convertible)) + - std::string(detail::get_type_info()) + "\n"; + (convertible ? "" : std::string(detail::get_type_info())) + "\n"; } else { msg += " → T::max_transpose_transfer_src_scalar_per_vector: [✗] (missing member)\n"; } @@ -664,7 +725,7 @@ consteval auto detailed_diagnostic_SpecifiesTransposeTransfer() -> std::string { using DstType = decltype(T::max_transpose_transfer_dst_scalar_per_vector); constexpr bool convertible = std::convertible_to; msg += " → T::max_transpose_transfer_dst_scalar_per_vector: " + std::string(CHECK_MARK(convertible)) + - std::string(detail::get_type_info()) + "\n"; + (convertible ? 
"" : std::string(detail::get_type_info())) + "\n"; } else { msg += " → T::max_transpose_transfer_dst_scalar_per_vector: [✗] (missing member)\n"; } @@ -675,7 +736,7 @@ consteval auto detailed_diagnostic_SpecifiesTransposeTransfer() -> std::string { template consteval auto detailed_diagnostic_SpecifiesGridwiseWmmaGemm() -> std::string { std::string msg; - constexpr bool has_gridwise_gemm = requires { T::gridwise_gemm; }; + constexpr bool has_gridwise_gemm = requires(T t) { { t.gridwise_gemm } -> GridwiseWmmaGemmDescriptor; }; msg += " → T::gridwise_gemm member: " + std::string(CHECK_MARK(has_gridwise_gemm)) + "\n"; if constexpr (!has_gridwise_gemm) { @@ -703,8 +764,8 @@ consteval auto detailed_diagnostic_SpecifiesGridwiseWmmaGemm() -> std::string { // Tile-specific diagnostics template consteval auto detailed_diagnostic_SpecifiesTileThreadBlock() -> std::string { - if constexpr (!requires { T::thread_block; }) { - return " → T::thread_block member: [✗] (not found)\n"; + if constexpr (!requires { { T::thread_block } -> TileThreadBlockDescriptor; }) { + return " → T::thread_block member: [✗] (missing or wrong type)\n"; } else { using TB = decltype(T::thread_block); std::string msg = " → T::thread_block member: [✓]\n"; @@ -745,7 +806,7 @@ consteval auto detailed_diagnostic_SpecifiesTileTransfer() -> std::string { template consteval auto detailed_diagnostic_SpecifiesTileBlockGemm() -> std::string { std::string msg; - constexpr bool has_block_gemm = requires { T::block_gemm; }; + constexpr bool has_block_gemm = requires { { T::block_gemm } -> TileBlockGemmDescriptor; }; msg += " → T::block_gemm member: " + std::string(CHECK_MARK(has_block_gemm)) + "\n"; if constexpr (!has_block_gemm) { @@ -781,7 +842,7 @@ consteval auto detailed_diagnostic_SpecifiesTileBlockGemm() -> std::string { template consteval auto detailed_diagnostic_SpecifiesTileOptimizations() -> std::string { std::string msg; - constexpr bool has_optimizations = requires { T::optimizations; }; + constexpr bool 
has_optimizations = requires { { T::optimizations } -> TileOptimizationsDescriptor; }; msg += " → T::optimizations member: " + std::string(CHECK_MARK(has_optimizations)) + "\n"; if constexpr (!has_optimizations) { @@ -804,7 +865,7 @@ consteval auto detailed_diagnostic_SpecifiesTileOptimizations() -> std::string { template consteval auto detailed_diagnostic_SpecifiesDlThreadConfig() -> std::string { std::string msg; - constexpr bool has_thread_config = requires { T::thread_config; }; + constexpr bool has_thread_config = requires { { T::thread_config } -> DlThreadConfigDescriptor; }; msg += " → T::thread_config member: " + std::string(CHECK_MARK(has_thread_config)) + "\n"; if constexpr (!has_thread_config) { @@ -830,7 +891,7 @@ consteval auto detailed_diagnostic_SpecifiesDlThreadConfig() -> std::string { template consteval auto detailed_diagnostic_SpecifiesDlThreadCluster() -> std::string { std::string msg; - constexpr bool has_thread_cluster = requires { T::thread_cluster; }; + constexpr bool has_thread_cluster = requires { { T::thread_cluster } -> DlThreadClusterDescriptor; }; msg += " → T::thread_cluster member: " + std::string(CHECK_MARK(has_thread_cluster)) + "\n"; if constexpr (!has_thread_cluster) { @@ -926,7 +987,7 @@ consteval auto detailed_diagnostic_SpecifiesTileConvSpecialization() -> std::str using SpecType = decltype(T::specialization); constexpr bool convertible = std::convertible_to; return " → T::specialization: " + std::string(CHECK_MARK(convertible)) + - std::string(detail::get_type_info()) + "\n"; + (convertible ? 
"" : std::string(detail::get_type_info())) + "\n"; } else { return " → T::specialization: [✗] (missing member)\n"; } @@ -943,26 +1004,22 @@ consteval auto detailed_diagnostic_SpecifiesLdsTransfer() -> std::string { return msg; } - constexpr bool has_a = requires { T::transfer.a; }; - constexpr bool has_b = requires { T::transfer.b; }; - constexpr bool has_c = requires { T::transfer.c; }; - - if constexpr (has_a && requires { T::transfer.a.lds_transfer; }) { - msg += detail::diagnose_lds_transfer("transfer.a.lds_transfer"); - } else if constexpr (has_a) { - msg += " → T::transfer.a.lds_transfer: [✗] (missing member)\n"; + constexpr bool has_a = requires { { T::transfer.a.lds_transfer } -> LdsTransferDescriptor; }; + msg += " → T::transfer.a: " + std::string(CHECK_MARK(has_a)) + "\n"; + if constexpr (!has_a) { + msg += " → T::transfer.a.lds_transfer: [✗] (missing or wrong type)\n"; } - if constexpr (has_b && requires { T::transfer.b.lds_transfer; }) { - msg += detail::diagnose_lds_transfer("transfer.b.lds_transfer"); - } else if constexpr (has_b) { - msg += " → T::transfer.b.lds_transfer: [✗] (missing member)\n"; + constexpr bool has_b = requires { { T::transfer.b.lds_transfer } -> LdsTransferDescriptor; }; + msg += " → T::transfer.b: " + std::string(CHECK_MARK(has_b)) + "\n"; + if constexpr (!has_b) { + msg += " → T::transfer.b.lds_transfer: [✗] (missing or wrong type)\n"; } - if constexpr (has_c && requires { T::transfer.c.epilogue; }) { - msg += detail::diagnose_epilogue("transfer.c.epilogue"); - } else if constexpr (has_c) { - msg += " → T::transfer.c.epilogue: [✗] (missing member)\n"; + constexpr bool has_c = requires { { T::transfer.c.epilogue } -> EpilogueDescriptor; }; + msg += " → T::transfer.c: " + std::string(CHECK_MARK(has_c)) + "\n"; + if constexpr (!has_c) { + msg += " → T::transfer.c.epilogue: [✗] (missing or wrong type)\n"; } return msg; From 8c80e005bd4159c42e8c4af719a8f5e59630ae3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> 
Date: Mon, 29 Dec 2025 04:53:19 -0500 Subject: [PATCH 19/81] Introduve a common size type for concepts. --- .../builder/conv_algorithm_concepts.hpp | 110 +++++++++--------- 1 file changed, 57 insertions(+), 53 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index ac279a4de8a..cdbe805cdd5 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -15,22 +15,26 @@ namespace ck_tile::builder { /* Descriptors for individual elements of the algorithm description */ /********************************************************************/ +// Common concept for size-related fields +template +concept SizeType = std::unsigned_integral>; + // Concept for thread block dimensions for a GEMM problem. template concept ThreadBlockDescriptor = requires(T t) { - { t.block_size } -> std::convertible_to; - { t.tile_size.m } -> std::convertible_to; - { t.tile_size.n } -> std::convertible_to; - { t.tile_size.k } -> std::convertible_to; + { t.block_size } -> SizeType; + { t.tile_size.m } -> SizeType; + { t.tile_size.n } -> SizeType; + { t.tile_size.k } -> SizeType; }; // Concept for parameters that describe a gridwise XDL GEMM problem. template concept GridwiseXdlGemmDescriptor = requires(T t) { - { t.m_per_xdl } -> std::convertible_to; - { t.n_per_xdl } -> std::convertible_to; - { t.m_xdl_per_wave } -> std::convertible_to; - { t.n_xdl_per_wave } -> std::convertible_to; + { t.m_per_xdl } -> SizeType; + { t.n_per_xdl } -> SizeType; + { t.m_xdl_per_wave } -> SizeType; + { t.n_xdl_per_wave } -> SizeType; }; // Concept for parameter that describe block GEMM problem. @@ -43,11 +47,11 @@ concept BlockGemmDescriptor = requires(T t) { // Concept for parameters that describe a gridwise WMMA GEMM problem. 
template concept GridwiseWmmaGemmDescriptor = requires(T t) { - { t.k1 } -> std::convertible_to; - { t.m_per_wmma } -> std::convertible_to; - { t.n_per_wmma } -> std::convertible_to; - { t.m_wmma_per_wave } -> std::convertible_to; - { t.n_wmma_per_wave } -> std::convertible_to; + { t.k1 } -> SizeType; + { t.m_per_wmma } -> SizeType; + { t.n_per_wmma } -> SizeType; + { t.m_wmma_per_wave } -> SizeType; + { t.n_wmma_per_wave } -> SizeType; { t.pipeline_version } -> std::convertible_to; }; @@ -55,34 +59,34 @@ concept GridwiseWmmaGemmDescriptor = requires(T t) { // Concept for vectorized data transfer for convolution input tensors. template concept BlockTransferDescriptor = requires(T t) { - { t.k0 } -> std::convertible_to; - { t.m_n } -> std::convertible_to; - { t.k1 } -> std::convertible_to; + { t.k0 } -> SizeType; + { t.m_n } -> SizeType; + { t.k1 } -> SizeType; }; template concept BlockTransferDescriptorBwd = requires(T t) { - { t.k0 } -> std::convertible_to; - { t.m_n } -> std::convertible_to; - { t.k1 } -> std::convertible_to; - { t.k_batch_size } -> std::convertible_to; + { t.k0 } -> SizeType; + { t.m_n } -> SizeType; + { t.k1 } -> SizeType; + { t.k_batch_size } -> SizeType; }; // Concept for thread cluster dimensions for GEMM output tensor. template concept ThreadClusterDescriptor = requires(T t) { - { t.m_block } -> std::convertible_to; - { t.m_wave_per_xdl } -> std::convertible_to; - { t.n_block } -> std::convertible_to; - { t.n_wave_per_xdl } -> std::convertible_to; + { t.m_block } -> SizeType; + { t.m_wave_per_xdl } -> SizeType; + { t.n_block } -> SizeType; + { t.n_wave_per_xdl } -> SizeType; }; // Concept for the LDS transfer for the convolution input tensors. 
template concept LdsTransferDescriptor = requires(T t) { - { t.src_vector_dim } -> std::convertible_to; - { t.src_scalar_per_vector } -> std::convertible_to; - { t.lds_dst_scalar_per_vector } -> std::convertible_to; + { t.src_vector_dim } -> SizeType; + { t.src_scalar_per_vector } -> SizeType; + { t.lds_dst_scalar_per_vector } -> SizeType; { t.is_direct_load } -> std::convertible_to; { t.lds_padding } -> std::convertible_to; }; @@ -91,9 +95,9 @@ concept LdsTransferDescriptor = requires(T t) { // LDS). template concept EpilogueDescriptor = requires(T t) { - { t.m_xdl_per_wave_per_shuffle } -> std::convertible_to; - { t.n_per_wave_per_shuffle } -> std::convertible_to; - { t.scalar_per_vector } -> std::convertible_to; + { t.m_xdl_per_wave_per_shuffle } -> SizeType; + { t.n_per_wave_per_shuffle } -> SizeType; + { t.scalar_per_vector } -> SizeType; }; // Concept for the thread cluster access order @@ -108,18 +112,18 @@ concept AccessOrderDescriptor = requires(T t) { // size is deduced from block gemm structure). template concept TileThreadBlockDescriptor = requires(T t) { - { t.tile_size.m } -> std::convertible_to; - { t.tile_size.n } -> std::convertible_to; - { t.tile_size.k } -> std::convertible_to; + { t.tile_size.m } -> SizeType; + { t.tile_size.n } -> SizeType; + { t.tile_size.k } -> SizeType; }; // Concept for thread block dimensions for a GEMM problem for CK Tile (Block // size is deduced from block gemm structure). template concept TileTransferDescriptor = requires(T t) { - { t.a_scalar_per_vector } -> std::convertible_to; - { t.b_scalar_per_vector } -> std::convertible_to; - { t.c_scalar_per_vector } -> std::convertible_to; + { t.a_scalar_per_vector } -> SizeType; + { t.b_scalar_per_vector } -> SizeType; + { t.c_scalar_per_vector } -> SizeType; }; // Concept to check if struct specifies block GEMM (CK Tile). @@ -169,15 +173,15 @@ concept SpecifiesTileThreadBlock = requires { // Concept to check if a struct specifies gridwise XDL GEMM info. 
template concept GridwiseFwdXdlGemmDescriptor = requires (T t){ - { t.ak1 } -> std::convertible_to; - { t.bk1 } -> std::convertible_to; + { t.ak1 } -> SizeType; + { t.bk1 } -> SizeType; { t.xdl_params } -> GridwiseXdlGemmDescriptor; }; // Concept to check if a struct specifies gridwise XDL GEMM info. template concept GridwiseBwdXdlGemmDescriptor = requires (T t){ - { t.k1 } -> std::convertible_to; + { t.k1 } -> SizeType; { t.xdl_params } -> GridwiseXdlGemmDescriptor; }; @@ -218,9 +222,9 @@ concept SpecifiesBlockTransferBwd = requires(T t) { // Concept to check if a struct specifies convolution scalar per vector infor for A, B and C. template concept SpecifiesTileTransfer = requires(T t) { - { T::transfer.a_scalar_per_vector } -> std::convertible_to; - { T::transfer.b_scalar_per_vector } -> std::convertible_to; - { T::transfer.c_scalar_per_vector } -> std::convertible_to; + { T::transfer.a_scalar_per_vector } -> SizeType; + { T::transfer.b_scalar_per_vector } -> SizeType; + { T::transfer.c_scalar_per_vector } -> SizeType; }; // Concept to check if a struct specifies LDS transfer info for tensors A, B, and C. 
@@ -297,12 +301,12 @@ concept SpecifiesGemmSpecialization = requires { template concept SpecifiesNumPrefetchStages = requires { - { T::num_gemm_k_prefetch_stages } -> std::convertible_to; + { T::num_gemm_k_prefetch_stages } -> SizeType; }; template concept SpecifiesNumGroupsToMerge = requires { - { T::num_groups_to_merge } -> std::convertible_to; + { T::num_groups_to_merge } -> SizeType; }; template @@ -318,8 +322,8 @@ concept SpecifiesLargeTensorSupport = requires { template concept SpecifiesTransposeTransfer = requires { - { T::max_transpose_transfer_src_scalar_per_vector } -> std::convertible_to; - { T::max_transpose_transfer_dst_scalar_per_vector } -> std::convertible_to; + { T::max_transpose_transfer_src_scalar_per_vector } -> SizeType; + { T::max_transpose_transfer_dst_scalar_per_vector } -> SizeType; }; /******************************************** */ @@ -329,11 +333,11 @@ concept SpecifiesTransposeTransfer = requires { // Concept for DL thread configuration template concept DlThreadConfigDescriptor = requires(T t) { - { t.k0_per_block } -> std::convertible_to; - { t.k1 } -> std::convertible_to; - { t.m1_per_thread } -> std::convertible_to; - { t.n1_per_thread } -> std::convertible_to; - { t.k_per_thread } -> std::convertible_to; + { t.k0_per_block } -> SizeType; + { t.k1 } -> SizeType; + { t.m1_per_thread } -> SizeType; + { t.n1_per_thread } -> SizeType; + { t.k_per_thread } -> SizeType; }; // Concept for DL thread cluster @@ -359,8 +363,8 @@ concept DlBlockTransferDescriptor = requires(T t) { template concept DlEpilogueDescriptor = requires(T t) { { t.src_dst_access_order } -> std::convertible_to>; - { t.src_dst_vector_dim } -> std::convertible_to; - { t.dst_scalar_per_vector } -> std::convertible_to; + { t.src_dst_vector_dim } -> SizeType; + { t.dst_scalar_per_vector } -> SizeType; }; // Concept to check if algorithm specifies DL thread config From 30a968687701c5332fb50860cf6522a33f42229a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> 
Date: Mon, 29 Dec 2025 04:56:22 -0500 Subject: [PATCH 20/81] Update compiletime diagnostics to use the size type. --- .../builder/conv_algorithm_diagnostics.hpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp index 6528e0fd444..2340e19e614 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp @@ -67,7 +67,7 @@ consteval auto diagnose_thread_block_descriptor() -> std::string { if constexpr (requires(TB t) { t.block_size; }) { using BlockSizeType = decltype(std::declval().block_size); - constexpr bool convertible = std::convertible_to; + constexpr bool convertible = SizeType; msg += " → thread_block.block_size: " + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(get_type_info())) + "\n"; } else { @@ -76,7 +76,7 @@ consteval auto diagnose_thread_block_descriptor() -> std::string { if constexpr (requires(TB t) { t.tile_size.m; }) { using TileMType = decltype(std::declval().tile_size.m); - constexpr bool convertible = std::convertible_to; + constexpr bool convertible = SizeType; msg += " → thread_block.tile_size.m: " + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(get_type_info())) + "\n"; } else { @@ -85,7 +85,7 @@ consteval auto diagnose_thread_block_descriptor() -> std::string { if constexpr (requires(TB t) { t.tile_size.n; }) { using TileNType = decltype(std::declval().tile_size.n); - constexpr bool convertible = std::convertible_to; + constexpr bool convertible = SizeType; msg += " → thread_block.tile_size.n: " + std::string(CHECK_MARK(convertible)) + (convertible ? 
"" : std::string(get_type_info())) + "\n"; } else { @@ -94,7 +94,7 @@ consteval auto diagnose_thread_block_descriptor() -> std::string { if constexpr (requires(TB t) { t.tile_size.k; }) { using TileKType = decltype(std::declval().tile_size.k); - constexpr bool convertible = std::convertible_to; + constexpr bool convertible = SizeType; msg += " → thread_block.tile_size.k: " + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(get_type_info())) + "\n"; } else { @@ -112,7 +112,7 @@ consteval auto diagnose_xdl_params() -> std::string { if constexpr (requires(XdlParams t) { t.m_per_xdl; }) { using MPerXdlType = decltype(std::declval().m_per_xdl); - constexpr bool convertible = std::convertible_to; + constexpr bool convertible = SizeType; msg += " → xdl_params.m_per_xdl: " + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(get_type_info())) + "\n"; } else { @@ -121,7 +121,7 @@ consteval auto diagnose_xdl_params() -> std::string { if constexpr (requires(XdlParams t) { t.n_per_xdl; }) { using NPerXdlType = decltype(std::declval().n_per_xdl); - constexpr bool convertible = std::convertible_to; + constexpr bool convertible = SizeType; msg += " → xdl_params.n_per_xdl: " + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(get_type_info())) + "\n"; } else { @@ -130,7 +130,7 @@ consteval auto diagnose_xdl_params() -> std::string { if constexpr (requires(XdlParams t) { t.m_xdl_per_wave; }) { using MXdlPerWaveType = decltype(std::declval().m_xdl_per_wave); - constexpr bool convertible = std::convertible_to; + constexpr bool convertible = SizeType; msg += " → xdl_params.m_xdl_per_wave: " + std::string(CHECK_MARK(convertible)) + (convertible ? 
"" : std::string(get_type_info())) + "\n"; } else { @@ -139,7 +139,7 @@ consteval auto diagnose_xdl_params() -> std::string { if constexpr (requires(XdlParams t) { t.n_xdl_per_wave; }) { using NXdlPerWaveType = decltype(std::declval().n_xdl_per_wave); - constexpr bool convertible = std::convertible_to; + constexpr bool convertible = SizeType; msg += " → xdl_params.n_xdl_per_wave: " + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(get_type_info())) + "\n"; } else { From 027d943b2fe67cfd300388e2a9a7a0fd9daf36fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Mon, 29 Dec 2025 05:06:39 -0500 Subject: [PATCH 21/81] Update conv specialization enum. --- .../builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp | 2 +- .../builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp | 2 +- .../builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp | 2 +- .../builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp | 4 ++-- .../conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp | 2 +- .../builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp | 4 ++-- .../builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp | 2 +- .../builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp | 2 +- .../builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp | 2 +- .../test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp | 4 ++-- .../builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp | 2 +- .../builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp | 2 +- .../builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp | 2 +- experimental/builder/test/conv/ck/test_conv_traits.cpp | 6 +++--- experimental/builder/test/test_conv_description.cpp | 2 +- experimental/builder/test/unit_conv_tuning_params.cpp | 2 +- 16 files changed, 21 insertions(+), 21 deletions(-) diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp index 8c59dd21b16..3eca32c8970 100644 --- 
a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp @@ -33,7 +33,7 @@ TEST(FwdConvInstances, .with_thread_block(ThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(Transfer_4x64x1) - .with_fwd_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0, + .with_fwd_specializations(ConvSpecialization::FILTER_1X1_STRIDE1_PAD0, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v2_intrawave); diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp index 7ab3ac605b8..e543ce6fa0e 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp @@ -30,7 +30,7 @@ TEST(FwdConvInstances, .with_thread_block(ThreadBlock_64_64x32x32) .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) .with_transfer(Transfer_4x16x1) - .with_fwd_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_prefetch_config(1, 2, PipelineScheduler::DEFAULT); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp index d8fbd3827e0..c87ffbd0663 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp @@ -32,7 +32,7 @@ TEST(FwdConvInstances, .with_thread_block(ThreadBlock_128_64x64x64) .with_gemm_config(FwdGemmParams_Wmma_2x1_per_wave) .with_transfer(Transfer_4x32x1) - .with_fwd_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_prefetch_config(1, 0, 
PipelineScheduler::DEFAULT); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp index 5f3bdfe4140..0ce530c1f85 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp @@ -30,7 +30,7 @@ TEST(FwdConvInstances, .with_thread_block(ThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(Transfer_4x64x1) - .with_fwd_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v1_intrawave); using Builder = ConvBuilder; @@ -67,7 +67,7 @@ TEST(FwdConvInstances, .with_thread_block(ThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(Transfer_4x64x1) - .with_fwd_specializations(ConvFwdSpecialization::FILTER_3x3, GemmSpecialization::MNKPadding) + .with_fwd_specializations(ConvSpecialization::FILTER_3x3, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v5_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp index f6403b312cc..56efb70f24c 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp @@ -35,7 +35,7 @@ TEST(FwdConvInstances, .with_thread_block(ThreadBlock_64_64x32x32) .with_gemm_config(FwdGemmParams_Xdl_2x2_per_wave) .with_transfer(Transfer_4x16x1) - .with_fwd_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_prefetch_config(1, 1, 
PipelineScheduler::DEFAULT); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp index a49f55e6d67..b740ac8704b 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp @@ -26,7 +26,7 @@ TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Ins constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{} .with_thread_block(ThreadBlock_256_128x128x16) - .with_fwd_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_dl_thread_config(DlThreadConfig_16x2x4x4x1) .with_dl_thread_cluster(DlThreadCluster_8x2) .with_dl_transfer(DlFwdTransfer); @@ -60,7 +60,7 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{} .with_thread_block(ThreadBlock_256_128x128x16) - .with_fwd_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, + .with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0, GemmSpecialization::MNKPadding) .with_dl_thread_config(DlThreadConfig_16x2x4x4x1) .with_dl_thread_cluster(DlThreadCluster_8x2) diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp index 7b543e0e3ee..0f09ce6e8e5 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp @@ -24,7 +24,7 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xd .with_thread_block(cku::ThreadBlock_256_256x256x32) .with_gemm_config(cku::FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(cku::Transfer_4x64x1) - 
.with_fwd_specializations(ckb::ConvFwdSpecialization::DEFAULT, + .with_fwd_specializations(ckb::ConvSpecialization::DEFAULT, ckb::GemmSpecialization::MNKPadding) .with_block_gemm(cku::BlockGemmDesc_v3_intrawave); diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp index a3f493c1d71..a8f97c417e3 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp @@ -29,7 +29,7 @@ TEST(FwdConvInstances, .with_thread_block(ThreadBlock_256_128x128x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(Transfer_4x64x1) - .with_fwd_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0, + .with_fwd_specializations(ConvSpecialization::FILTER_1X1_STRIDE1_PAD0, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v4_intrawave); diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp index 279c942ba90..c95d48d7125 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp @@ -30,7 +30,7 @@ TEST(FwdConvInstances, .with_thread_block(ThreadBlock_256_256x128x32) .with_gemm_config(FwdGemmParams_Xdl_4x2_per_wave) .with_transfer(Transfer_4x64x1_fp8) - .with_fwd_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp index cd9655f0b45..4ad3283883e 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp +++ 
b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp @@ -30,7 +30,7 @@ TEST(FwdConvInstances, .with_thread_block(ThreadBlock_256_256x128x32) .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) .with_transfer(Transfer_4x16x1) - .with_fwd_specializations(ConvFwdSpecialization::DEFAULT, + .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)}; @@ -67,7 +67,7 @@ TEST( .with_thread_block(ThreadBlock_128_128x128x32) .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) .with_transfer(Transfer_4x16x1) - .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, + .with_specializations(ConvSpecialization::FILTER_1X1_PAD0, GemmSpecialization::MNKPadding) .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)}; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp index b29b4471c34..9e6ca00e581 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp @@ -30,7 +30,7 @@ TEST(FwdConvInstances, .with_thread_block(ThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(Transfer_4x64x1) - .with_fwd_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v3_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp index 9c4b9d4ec0a..5e4e84543f1 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp @@ -30,7 +30,7 @@ TEST(FwdConvInstances, .with_thread_block(ThreadBlock_256_128x128x32) 
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(Transfer_4x64x1) - .with_fwd_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, + .with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v4_intrawave); diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp index ed0ec0e3c1d..2c2e80b1a7c 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp @@ -30,7 +30,7 @@ TEST(FwdConvInstances, .with_thread_block(ThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(Transfer_4x64x1) - .with_fwd-specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, + .with_fwd-specializations(ConvSpecialization::FILTER_1X1_PAD0, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v1_intrawave); diff --git a/experimental/builder/test/conv/ck/test_conv_traits.cpp b/experimental/builder/test/conv/ck/test_conv_traits.cpp index d5661ad67bf..b3a76e4e11d 100644 --- a/experimental/builder/test/conv/ck/test_conv_traits.cpp +++ b/experimental/builder/test/conv/ck/test_conv_traits.cpp @@ -101,7 +101,7 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction) // Verify specializations EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); - EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT); + EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT); // Verify algorithm information EXPECT_EQ(Traits::thread_block_size, 256); @@ -229,7 +229,7 @@ TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction) // Verify specializations EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); - EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT); + 
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT); // Verify algorithm information EXPECT_EQ(Traits::thread_block_size, 256); @@ -313,7 +313,7 @@ TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction) // Verify specializations EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); - EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT); + EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT); // Verify algorithm information EXPECT_EQ(Traits::thread_block_size, 256); diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index 158cb2668f8..98c8ebbab15 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -148,7 +148,7 @@ struct DefaultAlgorithm }, }; - ckb::ConvFwdSpecialization fwd_specialization = ckb::ConvFwdSpecialization::DEFAULT; + ckb::ConvFwdSpecialization fwd_specialization = ckb::ConvSpecialization::DEFAULT; ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default; ckb::test::BlockGemm block_gemm{.pipeline_version = ckb::PipelineVersion::V4, .scheduler = ckb::PipelineScheduler::INTRAWAVE}; diff --git a/experimental/builder/test/unit_conv_tuning_params.cpp b/experimental/builder/test/unit_conv_tuning_params.cpp index b35a1ced55a..2ea181ffb0f 100644 --- a/experimental/builder/test/unit_conv_tuning_params.cpp +++ b/experimental/builder/test/unit_conv_tuning_params.cpp @@ -79,7 +79,7 @@ TEST(ConvTuningParams, AssignsFwdConvSpecialization) constexpr struct Algorithm { ckb::ConvFwdSpecialization fwd_specialization = - ckb::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0; + ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0; } kAlgorithm; constexpr auto conv_spec = SetFwdConvSpecialization(); From 3bd0f050819607d0207f44c1114a693042f336b0 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Mon, 29 Dec 2025 05:31:35 -0500 Subject: [PATCH 22/81] Fix fwd conv builder tests. --- .../builder/factory/conv_bwd_weight_xdl_factory.hpp | 2 +- .../builder/factory/conv_fwd_large_tensor_factory.hpp | 9 +++++---- .../ck_tile/builder/factory/conv_fwd_v3_factory.hpp | 9 +++++---- .../ck_tile/builder/factory/conv_fwd_xdl_factory.hpp | 9 +++++---- .../builder/include/ck_tile/builder/testing/conv_fwd.hpp | 3 +-- .../conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp | 2 +- .../builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp | 2 +- experimental/builder/test/impl/conv_algorithm_types.hpp | 2 +- 8 files changed, 20 insertions(+), 18 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp index 8cb0d5d2ae9..6ad5820daba 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp @@ -35,7 +35,7 @@ struct ConvBwdWeightXdlFactory static constexpr auto BLOCK = internal::SetThreadBlockInfo(); static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; - static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp index fca36386974..62547dbe326 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp @@ -42,6 +42,7 @@ struct ConvFwdLargeTensorFactory static constexpr 
auto LOOP_SCHEDULER = internal::SetLoopScheduler(); static constexpr auto BLOCK = internal::SetThreadBlockInfo(); static constexpr auto GRIDWISE_GEMM = BASE_ALGORITHM.gridwise_gemm; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; static constexpr auto A_BLOCK_TRANSFER = internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -84,10 +85,10 @@ struct ConvFwdLargeTensorFactory BLOCK.per_block.k, GRIDWISE_GEMM.ak1, GRIDWISE_GEMM.bk1, - GRIDWISE_GEMM.m_per_xdl, - GRIDWISE_GEMM.n_per_xdl, - GRIDWISE_GEMM.m_xdl_per_wave, - GRIDWISE_GEMM.n_xdl_per_wave, + XDL_PARAMS.m_per_xdl, + XDL_PARAMS.n_per_xdl, + XDL_PARAMS.m_xdl_per_wave, + XDL_PARAMS.n_xdl_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp index 47891869cce..30fd555dd98 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp @@ -43,6 +43,7 @@ struct ConvFwdXdlV3Factory static constexpr auto BLOCK = internal::SetThreadBlockInfo(); static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; static constexpr auto A_BLOCK_TRANSFER = internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -84,10 +85,10 @@ struct ConvFwdXdlV3Factory BLOCK.per_block.k, GRIDWISE_GEMM.ak1, GRIDWISE_GEMM.bk1, - GRIDWISE_GEMM.m_per_xdl, - GRIDWISE_GEMM.n_per_xdl, - GRIDWISE_GEMM.m_xdl_per_wave, - GRIDWISE_GEMM.n_xdl_per_wave, + XDL_PARAMS.m_per_xdl, + XDL_PARAMS.n_per_xdl, + XDL_PARAMS.m_xdl_per_wave, + XDL_PARAMS.n_xdl_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp 
b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp index 695f1546143..16baf4fbced 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp @@ -39,6 +39,7 @@ struct ConvFwdXdlFactory static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); static constexpr auto BLOCK = internal::SetThreadBlockInfo(); static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; static constexpr auto A_BLOCK_TRANSFER = internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -80,10 +81,10 @@ struct ConvFwdXdlFactory BLOCK.per_block.k, GRIDWISE_GEMM.ak1, GRIDWISE_GEMM.bk1, - GRIDWISE_GEMM.m_per_xdl, - GRIDWISE_GEMM.n_per_xdl, - GRIDWISE_GEMM.m_xdl_per_wave, - GRIDWISE_GEMM.n_xdl_per_wave, + XDL_PARAMS.m_per_xdl, + XDL_PARAMS.n_per_xdl, + XDL_PARAMS.m_xdl_per_wave, + XDL_PARAMS.n_xdl_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp b/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp index f329a8a4d3d..c8f6451ec22 100644 --- a/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp @@ -62,8 +62,7 @@ struct Args using Ops = factory::internal::ElementwiseOps; // TODO: We shouldn't need to call into an internal namespace here. 
- using Layouts = - factory::internal::ConvTensorLayouts; + using Layouts = factory::internal::ConvTensorLayouts; ConvTensorLengths lengths; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp index 4ad3283883e..e7ad9060147 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp @@ -67,7 +67,7 @@ TEST( .with_thread_block(ThreadBlock_128_128x128x32) .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) .with_transfer(Transfer_4x16x1) - .with_specializations(ConvSpecialization::FILTER_1X1_PAD0, + .with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0, GemmSpecialization::MNKPadding) .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)}; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp index 2c2e80b1a7c..6b69ca89a69 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp @@ -30,7 +30,7 @@ TEST(FwdConvInstances, .with_thread_block(ThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(Transfer_4x64x1) - .with_fwd-specializations(ConvSpecialization::FILTER_1X1_PAD0, + .with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v1_intrawave); diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 2d2829820d3..e3c947f5414 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -379,7 +379,7 @@ struct ConvAlgorithmTemplate : Components... 
{ result.gridwise_gemm = gemm; } - if constexpr(std::is_base_of_v) + else if constexpr(std::is_base_of_v) { result.gridwise_gemm = gemm; } From 52086b350af2393d536b98afcc11732e61ec8792 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Mon, 29 Dec 2025 05:44:51 -0500 Subject: [PATCH 23/81] Fix smoke tests. --- .../factory/helpers/ck/conv_tensor_layout.hpp | 2 +- .../factory/helpers/ck/conv_tuning_params.hpp | 1 + .../builder/test/test_conv_description.cpp | 18 ++++---- .../builder/test/unit_conv_tensor_layout.cpp | 46 +++++++++---------- .../builder/test/unit_conv_tuning_params.cpp | 2 +- 5 files changed, 36 insertions(+), 33 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp index 22c026c28f7..566524f3a07 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp @@ -208,7 +208,7 @@ consteval auto GetAuxiliaryTensorLayouts() SPATIAL_DIM>{}; } -template +template requires(!HasElementwiseOpWithAuxiliaryOperands) consteval auto GetAuxiliaryTensorLayouts() { diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp index 92a7b48ddd9..f3cf3bcc3cc 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp" #include 
"ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index 98c8ebbab15..8e5411fabbd 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -107,14 +107,16 @@ struct DefaultAlgorithm ckb::test::ThreadBlock thread_block{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; - ckb::test::GridwiseXdlGemm gridwise_gemm{.ak1 = 8, - .bk1 = 8, - .m_per_xdl = 16, - .n_per_xdl = 16, - .m_xdl_per_wave = 4, - .n_xdl_per_wave = 4}; + ckb::test::GridwiseFwdXdlGemm gridwise_gemm{.ak1 = 8, + .bk1 = 8, + .xdl_params = + { + .m_per_xdl = 16, + .n_per_xdl = 16, + .m_xdl_per_wave = 4, + .n_xdl_per_wave = 4}}; - ckb::test::TransferABC transfer{ + ckb::test::Transfer<> transfer{ .a = { .block_transfer = {.k0 = 1, .m_n = 128, .k1 = 2}, @@ -148,7 +150,7 @@ struct DefaultAlgorithm }, }; - ckb::ConvFwdSpecialization fwd_specialization = ckb::ConvSpecialization::DEFAULT; + ckb::ConvSpecialization fwd_specialization = ckb::ConvSpecialization::DEFAULT; ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default; ckb::test::BlockGemm block_gemm{.pipeline_version = ckb::PipelineVersion::V4, .scheduler = ckb::PipelineScheduler::INTRAWAVE}; diff --git a/experimental/builder/test/unit_conv_tensor_layout.cpp b/experimental/builder/test/unit_conv_tensor_layout.cpp index ce31f419338..8c1ba5562eb 100644 --- a/experimental/builder/test/unit_conv_tensor_layout.cpp +++ b/experimental/builder/test/unit_conv_tensor_layout.cpp @@ -38,7 +38,7 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NWGC_GKXC_NWGK) .weight = {.config = {.layout = GKXC}}, .output = {.config = {.layout = NWGK}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); 
EXPECT_TRUE((std::is_same_v)); @@ -57,7 +57,7 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKXC_NGKW) .weight = {.config = {.layout = GKXC}}, .output = {.config = {.layout = NGKW}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); @@ -76,7 +76,7 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_GNWC_GKXC_GNWK) .weight = {.config = {.layout = GKXC}}, .output = {.config = {.layout = GNWK}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); @@ -95,7 +95,7 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKCX_NGKW) .weight = {.config = {.layout = GKCX}}, .output = {.config = {.layout = NGKW}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); @@ -114,7 +114,7 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKYXC_NGKHW) .weight = {.config = {.layout = GKYXC}}, .output = {.config = {.layout = NGKHW}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); @@ -133,7 +133,7 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NHWGC_GKYXC_NHWGK) .weight = {.config = {.layout = GKYXC}}, .output = {.config = {.layout = NHWGK}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); @@ -152,7 +152,7 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_GNHWC_GKYXC_GNHWK) .weight = {.config = {.layout = GKYXC}}, .output = {.config = {.layout = GNHWK}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); @@ -171,7 +171,7 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKCYX_NGKHW) .weight = {.config = {.layout 
= GKCYX}}, .output = {.config = {.layout = NGKHW}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); @@ -190,7 +190,7 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_NGCDHW_GKCZYX_NGKDHW) .weight = {.config = {.layout = GKCZYX}}, .output = {.config = {.layout = NGKDHW}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); @@ -209,7 +209,7 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_NDHWGC_GKZYXC_NDHWGK) .weight = {.config = {.layout = GKZYXC}}, .output = {.config = {.layout = NDHWGK}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); @@ -228,7 +228,7 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_GNDHWC_GKZYXC_GNDHWK) .weight = {.config = {.layout = GKZYXC}}, .output = {.config = {.layout = GNDHWK}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); @@ -273,7 +273,7 @@ TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_K_Layout) static constexpr std::array aux_configs = { MockAuxiliaryTensorConfig{.layout = G_K_strided}}; - using AuxLayouts = AuxiliaryTensorLayouts; + using AuxLayouts = AuxiliaryTensorLayouts; EXPECT_EQ(AuxLayouts::Size, 1); using ExpectedType = ck::Tuple; @@ -287,7 +287,7 @@ TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithGC_Layout) static constexpr std::array aux_configs = { MockAuxiliaryTensorConfig{.layout = TensorLayout::GC}}; - using AuxLayouts = AuxiliaryTensorLayouts; + using AuxLayouts = AuxiliaryTensorLayouts; EXPECT_EQ(AuxLayouts::Size, 1); using ExpectedType = ck::Tuple; @@ -301,7 +301,7 @@ TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_C_Layout) static constexpr std::array aux_configs = { 
MockAuxiliaryTensorConfig{.layout = G_C_strided}}; - using AuxLayouts = AuxiliaryTensorLayouts; + using AuxLayouts = AuxiliaryTensorLayouts; EXPECT_EQ(AuxLayouts::Size, 1); using ExpectedType = ck::Tuple; @@ -316,7 +316,7 @@ TEST(AuxiliaryTensorLayoutIntegration, TwoAuxiliaryTensors) MockAuxiliaryTensorConfig{.layout = TensorLayout::G_K_strided}, MockAuxiliaryTensorConfig{.layout = GC}}; - using AuxLayouts = AuxiliaryTensorLayouts; + using AuxLayouts = AuxiliaryTensorLayouts; EXPECT_EQ(AuxLayouts::Size, 2); using ExpectedType = @@ -333,7 +333,7 @@ TEST(AuxiliaryTensorLayoutIntegration, ThreeAuxiliaryTensors) MockAuxiliaryTensorConfig{.layout = GC}, MockAuxiliaryTensorConfig{.layout = G_C_strided}}; - using AuxLayouts = AuxiliaryTensorLayouts; + using AuxLayouts = AuxiliaryTensorLayouts; EXPECT_EQ(AuxLayouts::Size, 3); using ExpectedType = ck::Tuple aux_configs = { MockAuxiliaryTensorConfig{.layout = G_K_strided}}; - using AuxLayouts = AuxiliaryTensorLayouts; + using AuxLayouts = AuxiliaryTensorLayouts; EXPECT_EQ(AuxLayouts::Size, 1); using ExpectedType = ck::Tuple; @@ -363,7 +363,7 @@ TEST(AuxiliaryTensorLayoutIntegration, WorksWith3DConvolution) static constexpr std::array aux_configs = { MockAuxiliaryTensorConfig{.layout = GC}}; - using AuxLayouts = AuxiliaryTensorLayouts; + using AuxLayouts = AuxiliaryTensorLayouts; EXPECT_EQ(AuxLayouts::Size, 1); using ExpectedType = ck::Tuple; @@ -387,7 +387,7 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasG_K) .operation = OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); @@ -414,7 +414,7 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasGC) .operation = OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); 
EXPECT_TRUE((std::is_same_v)); @@ -442,7 +442,7 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithTwoAuxiliaryTensors) .operation = OutputOp{.elementwise_operation = ElementwiseOperation::SCALEADD_SCALEADD_RELU}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); @@ -470,7 +470,7 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv1DWithBias) .operation = OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); @@ -497,7 +497,7 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv3DWithBias) .operation = OutputOp{.elementwise_operation = ElementwiseOperation::BIAS_BNORM_CLAMP}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); diff --git a/experimental/builder/test/unit_conv_tuning_params.cpp b/experimental/builder/test/unit_conv_tuning_params.cpp index 2ea181ffb0f..ee1388a77f7 100644 --- a/experimental/builder/test/unit_conv_tuning_params.cpp +++ b/experimental/builder/test/unit_conv_tuning_params.cpp @@ -78,7 +78,7 @@ TEST(ConvTuningParams, AssignsFwdConvSpecialization) { constexpr struct Algorithm { - ckb::ConvFwdSpecialization fwd_specialization = + ckb::ConvSpecialization fwd_specialization = ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0; } kAlgorithm; constexpr auto conv_spec = SetFwdConvSpecialization(); From 9926d942e9ddc82d22daa24957b0b7bd4f56ec8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Mon, 29 Dec 2025 07:03:55 -0500 Subject: [PATCH 24/81] Separate bwd weigth and bwd data tests into separate targets. 
--- experimental/builder/test/CMakeLists.txt | 12 +++-- ...test_ckb_conv_bwd_weight_xdl_cshuffle.cpp} | 44 +------------------ 2 files changed, 10 insertions(+), 46 deletions(-) rename experimental/builder/test/conv/ck/{test_ckb_conv_bwd_weight.cpp => test_ckb_conv_bwd_weight_xdl_cshuffle.cpp} (56%) diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index d0f645e9059..8ffdf3f5435 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -137,15 +137,20 @@ add_ck_builder_test(test_ckb_build_fwd_instances conv/ck/test_ckb_conv_fwd_3d_fp16.cpp conv/ck/test_ckb_conv_fwd_3d_fp32.cpp conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp - conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp - conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp) + ) target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility) add_ck_builder_test(test_ckb_build_bwd_weight_instances - conv/ck/test_ckb_conv_bwd_weight.cpp + conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp + conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp ) target_link_libraries(test_ckb_build_bwd_weight_instances PRIVATE utility) +add_ck_builder_test(test_ckb_build_bwd_data_instances + conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp + ) +target_link_libraries(test_ckb_build_bwd_data_instances PRIVATE utility) + ################################################################################ # FACTORY TESTS - Expensive Regression Tests (Full MIOpen Kernel Set) @@ -199,6 +204,7 @@ set(CKB_REGRESSION_TESTS test_ckb_instance_string test_ckb_build_fwd_instances test_ckb_build_bwd_weight_instances + test_ckb_build_bwd_data_instances test_ckb_testing_utils # test_ckb_factory_grouped_convolution_forward_convscale # test_ckb_factory_grouped_convolution_forward_scaleadd_ab diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp similarity 
index 56% rename from experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp index 9989a20ad08..ad11eba6938 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp @@ -4,7 +4,6 @@ #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" #include "utils/conv_algorithm_type_utils.hpp" -//#include "ck_tile/builder/testing/conv_bwd_ck.hpp" #include "ck_tile/host/device_prop.hpp" namespace ckb = ck_tile::builder; @@ -29,7 +28,7 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CSh using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; -TEST(BwdWeight_2DFp16_CShufV3_GNHWC, Create) +TEST(BwdWeight_2DFp16_CShuffle_GNHWC, Create) { const auto expected_transfer_parameters = to_string(ALGORITHM); cku::run_test({"DeviceGroupedConvBwdWeight_Xdl_CShuffle", @@ -38,44 +37,3 @@ TEST(BwdWeight_2DFp16_CShufV3_GNHWC, Create) "GNHWC,GKYXC,GNHWK", "PassThrough,PassThrough,PassThrough"}); } - -// TEST(BwdWeight_2DFp16_CShufV3_GNHWC, EndToEnd) -// { -// if(!ck_tile::get_device_name().starts_with("gfx9")) -// { -// GTEST_SKIP() << "unsupported architecture"; -// } - -// ckt::Args args = { -// .lengths = -// { -// .batch_size = 16, -// .groups = 1, -// .input_channels = 32, -// .output_channels = 48, -// .image = -// { -// .width = 56, -// .height = 64, -// }, -// .filter = -// { -// .width = 3, -// .height = 5, -// }, -// }, -// .filter_strides = {.width = 1, .height = 1}, -// .filter_dilation = {.width = 1, .height = 1}, -// .input_left_pad = {.width = 0, .height = 0}, -// .input_right_pad = {.width = 0, .height = 0}, -// .a_elementwise_op = {}, -// .b_elementwise_op = {}, -// .cde_elementwise_op = {}, -// }; - -// auto inputs = alloc_inputs(args); -// auto outputs = alloc_outputs(args); - -// auto conv = Instance{}; -// 
ckt::run(conv, args, inputs.get(), outputs.get()); -// } From 277981bc9b95ca881c144b000a105b8ae8cb8766 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Mon, 29 Dec 2025 07:15:08 -0500 Subject: [PATCH 25/81] Clean-up CK Tile builder tests. --- .../test_ckb_conv_bwd_data_2d_fp16_v3.cpp | 12 ++++++------ .../test_ckb_conv_bwd_weight_2d_fp16_v3.cpp | 12 ++++++------ .../ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp | 4 ++-- .../test/utils/ckb_conv_tile_test_configs.hpp | 18 +++++++++--------- 4 files changed, 23 insertions(+), 23 deletions(-) diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp index ad31fc52bcf..677a3043715 100644 --- a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp @@ -8,9 +8,9 @@ namespace { using namespace ck_tile::builder::test_utils; -TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC) +TEST(BwdDataConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC) { - constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + constexpr ConvSignature BwdDataConvSignature{.spatial_dim = 2, .direction = ConvDirection::BACKWARD_DATA, .data_type = DataType::FP16, .accumulation_data_type = DataType::FP32, @@ -18,16 +18,16 @@ TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP1 .weight = {.config = {.layout = TensorLayout::GKYXC}}, .output = {.config = {.layout = TensorLayout::NHWGK}}}; - constexpr auto FwdConvAlgorithm = + constexpr auto BwdDataConvAlgorithm = ConvAlgorithm_Tile_GroupedConvolutionKernel{} .with_tile_specializations(TileConvSpecialization::DEFAULT) - .with_tile_thread_block(FwdTileThreadBlock_64x64x64) + .with_tile_thread_block(TileThreadBlock_64x64x64) .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave) - 
.with_tile_transfer(FwdTileTransfer_4x4x4) + .with_tile_transfer(TileTransfer_4x4x4) .with_tile_optimizations(TileOptimizations{ .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); - using Builder = ConvBuilder; + using Builder = ConvBuilder; run_ck_tile_test({ "grouped_convolution_backward_data", "fp16", diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp index 47908e0e5b7..f3de0bb762b 100644 --- a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp @@ -8,9 +8,9 @@ namespace { using namespace ck_tile::builder::test_utils; -TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC) +TEST(BwdWeightConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC) { - constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + constexpr ConvSignature BwdWeightConvSignature{.spatial_dim = 2, .direction = ConvDirection::BACKWARD_WEIGHT, .data_type = DataType::FP16, .accumulation_data_type = DataType::FP32, @@ -18,16 +18,16 @@ TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP1 .weight = {.config = {.layout = TensorLayout::GKYXC}}, .output = {.config = {.layout = TensorLayout::NHWGK}}}; - constexpr auto FwdConvAlgorithm = + constexpr auto BwdWeightConvAlgorithm = ConvAlgorithm_Tile_GroupedConvolutionKernel{} .with_tile_specializations(TileConvSpecialization::DEFAULT) - .with_tile_thread_block(FwdTileThreadBlock_64x64x64) + .with_tile_thread_block(TileThreadBlock_64x64x64) .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave) - .with_tile_transfer(FwdTileTransfer_4x4x4) + .with_tile_transfer(TileTransfer_4x4x4) .with_tile_optimizations(TileOptimizations{ .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); - using Builder 
= ConvBuilder; + using Builder = ConvBuilder; run_ck_tile_test({ "grouped_convolution_backward_weight", "fp16", diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp index 083d9d99552..9a8a4ce753e 100644 --- a/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp @@ -21,9 +21,9 @@ TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP1 constexpr auto FwdConvAlgorithm = ConvAlgorithm_Tile_GroupedConvolutionKernel{} .with_tile_specializations(TileConvSpecialization::DEFAULT) - .with_tile_thread_block(FwdTileThreadBlock_64x64x64) + .with_tile_thread_block(TileThreadBlock_64x64x64) .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave) - .with_tile_transfer(FwdTileTransfer_4x4x4) + .with_tile_transfer(TileTransfer_4x4x4) .with_tile_optimizations(TileOptimizations{ .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); diff --git a/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp index 377234dd19a..41a12508543 100644 --- a/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp @@ -12,35 +12,35 @@ namespace ck_tile::builder::test_utils { using namespace ck_tile::builder; using namespace test; -constexpr TileTransfer FwdTileTransfer_1x1x1{ +constexpr TileTransfer TileTransfer_1x1x1{ .a_scalar_per_vector = 1, .b_scalar_per_vector = 1, .c_scalar_per_vector = 1, }; -constexpr TileTransfer FwdTileTransfer_4x4x4{ +constexpr TileTransfer TileTransfer_4x4x4{ .a_scalar_per_vector = 4, .b_scalar_per_vector = 4, .c_scalar_per_vector = 4, }; -constexpr TileTransfer FwdTileTransfer_8x8x8{ +constexpr TileTransfer TileTransfer_8x8x8{ .a_scalar_per_vector = 8, .b_scalar_per_vector = 8, 
.c_scalar_per_vector = 8, }; -constexpr TileThreadBlock FwdTileThreadBlock_256x256x32{.tile_size = {.m = 256, .n = 256, .k = 32}}; +constexpr TileThreadBlock TileThreadBlock_256x256x32{.tile_size = {.m = 256, .n = 256, .k = 32}}; -constexpr TileThreadBlock FwdTileThreadBlock_256x128x32{.tile_size = {.m = 256, .n = 128, .k = 32}}; +constexpr TileThreadBlock TileThreadBlock_256x128x32{.tile_size = {.m = 256, .n = 128, .k = 32}}; -constexpr TileThreadBlock FwdTileThreadBlock_128x128x32{.tile_size = {.m = 128, .n = 128, .k = 32}}; +constexpr TileThreadBlock TileThreadBlock_128x128x32{.tile_size = {.m = 128, .n = 128, .k = 32}}; -constexpr TileThreadBlock FwdTileThreadBlock_128x128x16{.tile_size = {.m = 128, .n = 128, .k = 16}}; +constexpr TileThreadBlock TileThreadBlock_128x128x16{.tile_size = {.m = 128, .n = 128, .k = 16}}; -constexpr TileThreadBlock FwdTileThreadBlock_64x32x32{.tile_size = {.m = 64, .n = 32, .k = 32}}; +constexpr TileThreadBlock TileThreadBlock_64x32x32{.tile_size = {.m = 64, .n = 32, .k = 32}}; -constexpr TileThreadBlock FwdTileThreadBlock_64x64x64{.tile_size = {.m = 64, .n = 64, .k = 64}}; +constexpr TileThreadBlock TileThreadBlock_64x64x64{.tile_size = {.m = 64, .n = 64, .k = 64}}; constexpr TileBlockGemm TileBlockGemmDesc_16x16_v1_intrawave = { .warps = {.m = 2, .n = 2, .k = 1}, From 80f44824f592b35df1fb6db0db1efa8c81bca09c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Mon, 29 Dec 2025 09:12:14 -0500 Subject: [PATCH 26/81] Add bwd weight XDL CShuffle V3 factory. 
--- .../ck_tile/builder/conv_algorithm_limits.hpp | 16 +-- .../builder/factory/conv_algorithms.hpp | 61 ++++++++++- .../factory/conv_bwd_weight_xdl_factory.hpp | 8 +- .../conv_bwd_weight_xdl_v3_factory.hpp | 103 ++++++++++++++++++ .../builder/factory/conv_dispatcher.hpp | 24 ++++ .../factory/conv_fwd_large_tensor_factory.hpp | 8 +- .../builder/factory/conv_fwd_v3_factory.hpp | 8 +- .../builder/factory/conv_fwd_wmma_factory.hpp | 8 +- .../builder/factory/conv_fwd_xdl_factory.hpp | 8 +- .../helpers/ck/conv_block_transfer.hpp | 44 ++++++-- experimental/builder/test/CMakeLists.txt | 1 + ...st_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp | 43 ++++++++ .../test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp | 2 +- .../test/impl/conv_algorithm_types.hpp | 33 +++--- .../test/utils/ckb_conv_test_configs.hpp | 43 +++++++- .../test/utils/conv_algorithm_type_utils.hpp | 37 +++++-- 16 files changed, 377 insertions(+), 70 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp create mode 100644 experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp index f60e7703a3e..d35897fc78d 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp @@ -29,18 +29,20 @@ concept OutputVectorTransferLimits = requires { // Limits for access order. Must be a permutation of {0, 1, 2}. template -concept AccessOrderLimits = requires { +concept AccessOrderLimits3D = requires { requires((Value[0] != Value[1]) && (Value[0] != Value[2]) && (Value[1] != Value[2]) && (Value[0] >= 0 && Value[0] < 3) && (Value[1] >= 0 && Value[1] < 3) && - (Value[2] >= 0 && Value[2] < 3)); + (Value[2] >= 0 && Value[2] < 3) && (Value.Size() == 3)); }; -// Limits for access order. 
Must be a permutation of {1, 2, 3} for the last three elements. +// Limits for access order. Must be a permutation of {0, 1, 2, 3}. template -concept BwdAccessOrderLimits = requires { - requires((Value[1] != Value[2]) && (Value[1] != Value[3]) && (Value[2] != Value[3]) && - (Value[1] >= 1 && Value[1] < 4) && (Value[2] >= 1 && Value[2] < 4) && - (Value[3] >= 1 && Value[3] < 4)) && (Value[0] == 0); +concept AccessOrderLimits4D = requires { + requires((Value[0] != Value[1]) && (Value[0] != Value[2]) && (Value[0] != Value[3]) && + (Value[1] != Value[2]) && (Value[1] != Value[3]) && (Value[2] != Value[3]) && + (Value[0] >= 0 && Value[0] < 4) && (Value[1] >= 0 && Value[1] < 4) && + (Value[2] >= 0 && Value[2] < 4) && (Value[3] >= 0 && Value[3] < 4) && + (Value.Size() == 4)); }; } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index 71e3d03da36..e537b7ba99b 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -279,6 +279,47 @@ struct BwdXdlAlgorithm { } }; +template +struct BwdXdlV3Algorithm { + CHECK_CONCEPT(T, ConvAlgorithmDescriptor) + CHECK_CONCEPT(T, SpecifiesThreadBlock) + CHECK_CONCEPT(T, SpecifiesBlockTransferBwd) + CHECK_CONCEPT(T, SpecifiesLdsTransfer) + CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder) + CHECK_CONCEPT(T, SpecifiesSourceAccessOrder) + CHECK_CONCEPT(T, SpecifiesGridwiseBwdXdlGemm) + CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization) + CHECK_CONCEPT(T, SpecifiesBlockGemm) + + static constexpr bool c1 = c_ConvAlgorithmDescriptor; + static constexpr bool c2 = c_SpecifiesThreadBlock; + static constexpr bool c3 = c_SpecifiesBlockTransferBwd; + static constexpr bool c4 = c_SpecifiesLdsTransfer; + static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder; + static constexpr bool c6 = 
c_SpecifiesSourceAccessOrder; + static constexpr bool c7 = c_SpecifiesGridwiseBwdXdlGemm; + static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; + static constexpr bool c9 = c_SpecifiesBlockGemm; + + static consteval bool is_valid() { + return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9; + } + + static consteval auto message() -> std::string { + return std::string("\n=== Backward XDL V3 Algorithm Diagnostic (closest match) ===\n" + "Concepts for BwdXdlV3 Algorithm:\n") + + DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + + DIAGNOSTIC_LINE(SpecifiesBlockTransferBwd) + + DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + + DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + + DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + + DIAGNOSTIC_LINE(SpecifiesGridwiseBwdXdlGemm) + + DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) + + DIAGNOSTIC_LINE(SpecifiesBlockGemm); + } +}; + template consteval int count_matches_fwd_xdl_v3() { using Alg = FwdXdlV3Algorithm; @@ -309,6 +350,12 @@ consteval int count_matches_bwd_xdl() { return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9; } +template +consteval int count_matches_bwd_xdl_v3() { + using Alg = BwdXdlV3Algorithm; + return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9; +} + template consteval int count_matches_large_tensor() { using Alg = LargeTensorAlgorithm; @@ -368,12 +415,20 @@ consteval void diagnose_fwd_algorithm_signature() template consteval void diagnose_bwd_weight_algorithm_signature() { - constexpr int xdl_matches = count_matches_fwd_xdl(); - constexpr int max_matches = xdl_matches; + constexpr int xdl_matches = count_matches_bwd_xdl(); + constexpr int xdl_v3_matches = count_matches_fwd_xdl_v3(); + + constexpr int max_matches = xdl_v3_matches > xdl_matches ? 
xdl_v3_matches : xdl_matches; + if constexpr (max_matches == xdl_matches) { using Alg = BwdXdlAlgorithm; static_assert(Alg::is_valid(), Alg::message()); - } else { + } + else if constexpr (max_matches == xdl_v3_matches) { + using Alg = BwdXdlV3Algorithm; + static_assert(Alg::is_valid(), Alg::message()); + } + else { // This should never happen static_assert(false, "Internal Error: No matching algorithm variant found for diagnostics."); } diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp index 6ad5820daba..8790121ed93 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp @@ -47,10 +47,10 @@ struct ConvBwdWeightXdlFactory static_assert(InputVectorTransferLimits); static_assert(InputVectorTransferLimits); static_assert(OutputVectorTransferLimits); - static_assert(BwdAccessOrderLimits); - static_assert(BwdAccessOrderLimits); - static_assert(BwdAccessOrderLimits); - static_assert(BwdAccessOrderLimits); + static_assert(AccessOrderLimits4D); + static_assert(AccessOrderLimits4D); + static_assert(AccessOrderLimits4D); + static_assert(AccessOrderLimits4D); // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle< diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp new file mode 100644 index 00000000000..14121be940a --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp @@ -0,0 +1,103 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle instance +// of a grouped forward convolution kernel. +template + requires ConvDirectionIsBackwardWeight +struct ConvBwdWeightXdlV3Factory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::BwdWeightConvTensorDataTypes; + using Ops = internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto A_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + static constexpr auto BLOCK_GEMM = internal::SetBlockGemm(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. 
+ static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); + static_assert(InputVectorTransferLimits, "Invalid B block transfer config"); + static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); + static_assert(AccessOrderLimits3D, "Invalid A thread cluster access order"); + static_assert(AccessOrderLimits3D, "Invalid B thread cluster access order"); + static_assert(AccessOrderLimits3D, "Invalid A source access order"); + static_assert(AccessOrderLimits3D, "Invalid B source access order"); + + // The forward convolution kernel class instance. + using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + XDL_PARAMS.m_per_xdl, + XDL_PARAMS.n_per_xdl, + XDL_PARAMS.m_xdl_per_wave, + XDL_PARAMS.n_xdl_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + BLOCK_GEMM.scheduler, + BLOCK_GEMM.pipeline_version, + typename Types::InComputeType, + typename Types::WeiComputeType>; +}; + +} // namespace ck_tile::builder::factory diff --git 
a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index adbf12992e5..01c0fb9c56b 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -49,6 +49,12 @@ #pragma once +// Disable pragma message warnings for factory selection diagnostics +#ifdef __clang__ + #pragma clang diagnostic push + #pragma clang diagnostic ignored "-W#pragma-messages" +#endif + #include "ck_tile/builder/conv_signature_concepts.hpp" #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/types.hpp" @@ -64,6 +70,7 @@ #include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp" #include "ck_tile/builder/factory/conv_tile_factory.hpp" #include "ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp" namespace ck_tile::builder::factory { @@ -96,28 +103,34 @@ constexpr auto make_conv_instance() // CK Tile supports common factory for each direction if constexpr(TileAlgorithm::is_valid()) { + #pragma message("[CK Builder] Using ConvTileFactory...") return typename ConvTileFactory::Instance{}; } else if constexpr(ConvDirectionIsForward) { if constexpr(FwdXdlV3Algorithm::is_valid()) { + #pragma message("[CK Builder] Using ConvFwdXdlV3Factory...") return typename ConvFwdXdlV3Factory::Instance{}; } else if constexpr(FwdXdlAlgorithm::is_valid()) { + #pragma message("[CK Builder] Using ConvFwdXdlFactory...") return typename ConvFwdXdlFactory::Instance{}; } else if constexpr(FwdWmmaAlgorithm::is_valid()) { + #pragma message("[CK Builder] Using ConvFwdWmmaFactory...") return typename ConvFwdWmmaFactory::Instance{}; } else if constexpr(FwdDlAlgorithm::is_valid()) { + #pragma message("[CK Builder] Using ConvFwdDlFactory...") return typename ConvFwdDlFactory::Instance{}; } else if 
constexpr(LargeTensorAlgorithm::is_valid()) { + #pragma message("[CK Builder] Using ConvFwdLargeTensorFactory...") return typename ConvFwdLargeTensorFactory::Instance{}; } else @@ -135,8 +148,14 @@ constexpr auto make_conv_instance() { if constexpr (BwdXdlAlgorithm::is_valid()) { + #pragma message("[CK Builder] Using ConvBwdWeightXdlFactory...") return typename ConvBwdWeightXdlFactory::Instance{}; } + else if constexpr (BwdXdlV3Algorithm::is_valid()) + { + #pragma message("[CK Builder] Using ConvBwdWeightXdlV3Factory...") + return typename ConvBwdWeightXdlV3Factory::Instance{}; + } else { diagnose_bwd_weight_algorithm_signature(); @@ -152,3 +171,8 @@ constexpr auto make_conv_instance() } } // namespace ck_tile::builder::factory + +// Re-enable pragma message warnings +#ifdef __clang__ + #pragma clang diagnostic pop +#endif diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp index 62547dbe326..456c567aa06 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp @@ -54,10 +54,10 @@ struct ConvFwdLargeTensorFactory static_assert(InputVectorTransferLimits); static_assert(InputVectorTransferLimits); static_assert(OutputVectorTransferLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); // The forward convolution kernel class instance with large tensor support. 
using Instance = diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp index 30fd555dd98..e34f39965f2 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp @@ -56,10 +56,10 @@ struct ConvFwdXdlV3Factory static_assert(InputVectorTransferLimits); static_assert(InputVectorTransferLimits); static_assert(OutputVectorTransferLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp index 1fb3942df0a..dbaa8651eb2 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp @@ -52,10 +52,10 @@ struct ConvFwdWmmaFactory static_assert(InputVectorTransferLimits); static_assert(InputVectorTransferLimits); static_assert(OutputVectorTransferLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); // The forward convolution kernel class instance. 
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp index 16baf4fbced..cebf5a0c3a1 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp @@ -51,10 +51,10 @@ struct ConvFwdXdlFactory static_assert(InputVectorTransferLimits); static_assert(InputVectorTransferLimits); static_assert(OutputVectorTransferLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp index 5ee07fa2d18..25cf773694f 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp @@ -62,16 +62,40 @@ constexpr BwdBlockTransfer SetBwdConvBlockTransfer() auto& src_order = TRANSFER.src_access_order; auto& lds_cfg = TRANSFER.lds_transfer; - return BwdBlockTransfer{ - .thread_cluster_dims = {block_xfer.k_batch_size, block_xfer.k0, block_xfer.m_n, block_xfer.k1}, - .thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2], block_order.order[3]}, - .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2], src_order.order[3]}, 
- .src_vector_dim = lds_cfg.src_vector_dim, - .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, - .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, - .is_direct_load = lds_cfg.is_direct_load, - .lds_padding = lds_cfg.lds_padding, - }; + constexpr auto array_length = block_order.order.size(); + static_assert(block_order.order.size() == src_order.order.size(), + "Mismatched size between block order and src order"); + + if constexpr (array_length == 3) + { + return BwdBlockTransfer{ + .thread_cluster_dims = {block_xfer.k_batch_size, block_xfer.k0, block_xfer.m_n, block_xfer.k1}, + .thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2]}, + .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]}, + .src_vector_dim = lds_cfg.src_vector_dim, + .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, + .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, + .is_direct_load = lds_cfg.is_direct_load, + .lds_padding = lds_cfg.lds_padding, + }; + } + else if constexpr (array_length == 4) + { + return BwdBlockTransfer{ + .thread_cluster_dims = {block_xfer.k_batch_size, block_xfer.k0, block_xfer.m_n, block_xfer.k1}, + .thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2], block_order.order[3]}, + .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2], src_order.order[3]}, + .src_vector_dim = lds_cfg.src_vector_dim, + .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, + .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, + .is_direct_load = lds_cfg.is_direct_load, + .lds_padding = lds_cfg.lds_padding, + }; + } + else + { + static_assert(false, "Internal error: Unsupported array length"); + } } // Block transfer parameters for C tensor. 
diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 8ffdf3f5435..8105a41bf57 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -142,6 +142,7 @@ target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility) add_ck_builder_test(test_ckb_build_bwd_weight_instances conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp + conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp ) target_link_libraries(test_ckb_build_bwd_weight_instances PRIVATE utility) diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp new file mode 100644 index 00000000000..2dfa6e57719 --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp @@ -0,0 +1,43 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; +using enum ck_tile::builder::TensorLayout; + +constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = 1, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::BF16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = NGCW}}, + .weight = {.config = {.layout = GKXC}}, + .output = {.config = {.layout = NGKW}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3{} + .with_thread_block(cku::ThreadBlock_64_32x32x32) + .with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave) + .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) + .with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0) + .with_block_gemm(cku::BlockGemmDesc_v2_intrawave); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdWeight_1DBf16_CShuffle_V3, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + cku::run_test({"DeviceGroupedConvBwdWeight_Xdl_CShuffleV3", + expected_transfer_parameters, + "FILTER_1X1_STRIDE1_PAD0", + "NGCW,GKXC,NGKW", + "PassThrough,PassThrough,PassThrough", + "Intrawave", + "v2"}); +} diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp index c87ffbd0663..22911a1a26c 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp @@ -25,7 +25,7 @@ TEST(FwdConvInstances, .accumulation_data_type = INT32, .input = {.config = {.layout = GNWC}}, .weight = {.config = {.layout = GKXC}}, - .output = {.config = {.layout = 
GNWK}}}; + .output = {.config = {.layout = GNWK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle{} diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index e3c947f5414..b045d185e2a 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -74,7 +74,7 @@ struct BlockGemm static_assert(ckb::BlockGemmDescriptor); // Describe Aand B block transfer thread cluster lengths. -template +template struct BlockTransfer { size_t k0; @@ -83,16 +83,16 @@ struct BlockTransfer size_t k_batch_size; }; -// Specialization for forward (IsBwd = false) +// Specialization for ThreadSliceLength == 3 template <> -struct BlockTransfer +struct BlockTransfer<3> { size_t k0; size_t m_n; size_t k1; }; static_assert(ckb::BlockTransferDescriptor>); -static_assert(ckb::BlockTransferDescriptor>); +static_assert(ckb::BlockTransferDescriptor>); // Describe C block transfer thread cluster lengths. 
struct ThreadCluster @@ -130,13 +130,13 @@ struct AccessOrder static_assert(AccessOrderDescriptor>); static_assert(AccessOrderDescriptor>); -template +template struct InputTransfer { - BlockTransfer block_transfer; + BlockTransfer block_transfer; LdsTransfer lds_transfer; - std::conditional_t, AccessOrder<3>> block_transfer_access_order; - std::conditional_t, AccessOrder<3>> src_access_order; + AccessOrder block_transfer_access_order; + AccessOrder src_access_order; }; struct OutputTransfer @@ -145,11 +145,11 @@ struct OutputTransfer Epilogue epilogue; }; -template +template struct Transfer { - InputTransfer a; - InputTransfer b; + InputTransfer a; + InputTransfer b; OutputTransfer c; }; @@ -213,10 +213,10 @@ struct WmmaGemm_ GridwiseWmmaGemm gridwise_gemm; }; -template +template struct Transfer_ { - Transfer transfer; + Transfer transfer; }; struct ConvSpecializationFwd_ @@ -397,7 +397,7 @@ struct ConvAlgorithmTemplate : Components... constexpr auto with_transfer(const T& t) const { static_assert(std::is_base_of_v, ConvAlgorithmTemplate> || - std::is_base_of_v, ConvAlgorithmTemplate>); + std::is_base_of_v, ConvAlgorithmTemplate>); auto result = *this; result.transfer = t; return result; @@ -553,6 +553,9 @@ using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate; using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle = - ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, TransposeParams_>; + ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, TransposeParams_>; + +using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 = + ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, BlockGemm_>; } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index 956f65f4531..7b5807ef233 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -72,8 +72,7 
@@ constexpr Transfer<> Transfer_4x64x1{ }, }; -constexpr bool BWD = true; -constexpr Transfer BwdTransfer_4x64x1{ +constexpr Transfer<4> BwdTransfer_4x64x1{ .a = { .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1}, @@ -106,6 +105,39 @@ constexpr Transfer BwdTransfer_4x64x1{ }, }; +constexpr Transfer<> BwdTransfer_4x8x1_4x16x1_v3{ + .a = + { + .block_transfer = {.k0 = 4, .m_n = 8, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 1, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 2, + .is_direct_load = false, + .lds_padding = false}, + .block_transfer_access_order = {2, 0, 1}, + .src_access_order = {1, 0, 2}, + }, + .b = + { + .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 1, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 2, + .is_direct_load = false, + .lds_padding = false}, + .block_transfer_access_order = {2, 0, 1}, + .src_access_order = {1, 0, 2}, + }, + .c = + { + .thread_cluster_dims = + {.m_block = 1, .m_wave_per_xdl = 8, .n_block = 1, .n_wave_per_xdl = 8}, + .epilogue = {.m_xdl_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 2}, + }, +}; + constexpr Transfer<> Transfer_4x64x1_fp8{ .a = { @@ -210,6 +242,10 @@ constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_4x4_per_wave{ .k1 = 8, .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}}; +constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_1x1_per_wave{ + .k1 = 8, + .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 1, .n_xdl_per_wave = 1}}; + constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_4x4_per_wave{ .ak1 = 8, .bk1 = 8, .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}}; @@ -251,6 +287,9 @@ constexpr ThreadBlock ThreadBlock_256_128x128x8{.block_size = 256, constexpr ThreadBlock ThreadBlock_64_64x32x32{.block_size = 64, .tile_size = {.m = 64, .n = 32, .k = 32}}; +constexpr ThreadBlock 
ThreadBlock_64_32x32x32{.block_size = 64, + .tile_size = {.m = 32, .n = 32, .k = 32}}; + constexpr ThreadBlock ThreadBlock_128_128x128x32{.block_size = 128, .tile_size = {.m = 128, .n = 128, .k = 32}}; diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index f7096f27f81..cf13f39391a 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -120,17 +120,21 @@ inline std::string to_string(BlockGemm t) return oss.str(); } -template -inline std::string to_string(BlockTransfer t) +template +inline std::string to_string(BlockTransfer t) { - if constexpr (IsBwd) + if constexpr (ThreadSliceDim == 4) { return array_to_seq(std::array{t.k_batch_size, t.k0, t.m_n, t.k1}); } - else + else if constexpr (ThreadSliceDim == 3) { return array_to_seq(std::array{t.k0, t.m_n, t.k1}); } + else + { + static_assert(ThreadSliceDim == 3 || ThreadSliceDim == 4, "Unsupported ThreadSliceDim"); + } } template <> @@ -156,8 +160,8 @@ inline std::string to_string(AccessOrder t) return array_to_seq(t.order); } -template -inline std::string to_string(InputTransfer t) +template +inline std::string to_string(InputTransfer t) { std::ostringstream oss; oss << to_string(t.block_transfer) << "," << to_string(t.block_transfer_access_order) << "," @@ -176,8 +180,8 @@ inline std::string to_string(OutputTransfer t) return oss.str(); } -template -inline std::string to_string(Transfer t) +template +inline std::string to_string(Transfer t) { std::ostringstream oss; oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c); @@ -267,8 +271,8 @@ inline std::string to_string(WmmaGemm_ t) return to_string(t.gridwise_gemm); } -template -inline std::string to_string(Transfer_ t) +template +inline std::string to_string(Transfer_ t) { return to_string(t.transfer); } @@ -378,9 +382,18 @@ inline std::string to_string(t)) << "," << 
to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + << "," << to_string(static_cast>(t)); + return oss.str(); +} + +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); return oss.str(); } From a83790e9dae3f2061baae58f5717c75a427ed727 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Mon, 29 Dec 2025 09:30:58 -0500 Subject: [PATCH 27/81] Build conv bwd weigth v3 instances successfully. --- .../builder/conv_algorithm_concepts.hpp | 11 +++++---- .../builder/conv_algorithm_diagnostics.hpp | 14 +++++++---- .../builder/factory/conv_algorithms.hpp | 12 +++++----- .../helpers/ck/conv_block_transfer.hpp | 24 +++++++++---------- 4 files changed, 32 insertions(+), 29 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index cdbe805cdd5..447bbdad5e3 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -65,7 +65,7 @@ concept BlockTransferDescriptor = requires(T t) { }; template -concept BlockTransferDescriptorBwd = requires(T t) { +concept BlockTransferDescriptor4D = requires(T t) { { t.k0 } -> SizeType; { t.m_n } -> SizeType; { t.k1 } -> SizeType; @@ -211,11 +211,12 @@ concept SpecifiesBlockTransfer = requires(T t) { { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; }; -// Concept to check if a struct specifies convolution input and output block transfer info (Bwd direction). +// Concept to check if a struct specifies convolution input and output block transfer info +// for 4D thread slices. 
template -concept SpecifiesBlockTransferBwd = requires(T t) { - { T::transfer.a.block_transfer } -> BlockTransferDescriptorBwd; - { T::transfer.b.block_transfer } -> BlockTransferDescriptorBwd; +concept SpecifiesBlockTransfer4D = requires(T t) { + { T::transfer.a.block_transfer } -> BlockTransferDescriptor4D; + { T::transfer.b.block_transfer } -> BlockTransferDescriptor4D; { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; }; diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp index 2340e19e614..6613d2d7367 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp @@ -184,9 +184,9 @@ consteval auto diagnose_block_transfer(const char* prefix) -> std::string { return msg; } -// BlockTransferDescriptorBwd diagnostics (requires k_batch_size) +// BlockTransferDescriptor4D diagnostics (requires k_batch_size) template -consteval auto diagnose_block_transfer_bwd(const char* prefix) -> std::string { +consteval auto diagnose_block_transfer_4d(const char* prefix) -> std::string { std::string msg; if constexpr (requires(BT t) { t.k0; }) { @@ -500,7 +500,7 @@ consteval auto detailed_diagnostic_SpecifiesBlockTransfer() -> std::string { } template -consteval auto detailed_diagnostic_SpecifiesBlockTransferBwd() -> std::string { +consteval auto detailed_diagnostic_SpecifiesBlockTransfer4D() -> std::string { std::string msg; constexpr bool has_transfer = requires { T::transfer; }; @@ -510,16 +510,20 @@ consteval auto detailed_diagnostic_SpecifiesBlockTransferBwd() -> std::string { return msg; } - constexpr bool has_a = requires { { T::transfer.a.block_transfer } -> BlockTransferDescriptorBwd; }; + constexpr bool has_a = requires { { T::transfer.a.block_transfer } -> BlockTransferDescriptor4D; }; msg += " → T::transfer.a: " + 
std::string(CHECK_MARK(has_a)) + "\n"; if constexpr (!has_a) { msg += " → T::transfer.a.block_transfer: [✗] (missing or wrong type)\n"; + } else { + msg += detail::diagnose_block_transfer_4d("transfer.a.block_transfer"); } - constexpr bool has_b = requires { { T::transfer.b.block_transfer } -> BlockTransferDescriptorBwd; }; + constexpr bool has_b = requires { { T::transfer.b.block_transfer } -> BlockTransferDescriptor4D; }; msg += " → T::transfer.b: " + std::string(CHECK_MARK(has_b)) + "\n"; if constexpr (!has_b) { msg += " → T::transfer.b.block_transfer: [✗] (missing or wrong type)\n"; + } else { + msg += detail::diagnose_block_transfer_4d("transfer.b.block_transfer"); } constexpr bool has_c = requires { { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; }; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index e537b7ba99b..bf7f0248fda 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -242,7 +242,7 @@ template struct BwdXdlAlgorithm { CHECK_CONCEPT(T, ConvAlgorithmDescriptor) CHECK_CONCEPT(T, SpecifiesThreadBlock) - CHECK_CONCEPT(T, SpecifiesBlockTransferBwd) + CHECK_CONCEPT(T, SpecifiesBlockTransfer4D) CHECK_CONCEPT(T, SpecifiesLdsTransfer) CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder) CHECK_CONCEPT(T, SpecifiesSourceAccessOrder) @@ -252,7 +252,7 @@ struct BwdXdlAlgorithm { static constexpr bool c1 = c_ConvAlgorithmDescriptor; static constexpr bool c2 = c_SpecifiesThreadBlock; - static constexpr bool c3 = c_SpecifiesBlockTransferBwd; + static constexpr bool c3 = c_SpecifiesBlockTransfer4D; static constexpr bool c4 = c_SpecifiesLdsTransfer; static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder; static constexpr bool c6 = c_SpecifiesSourceAccessOrder; @@ -269,7 +269,7 @@ struct BwdXdlAlgorithm { "Concepts for 
BwdXdl Algorithm:\n") + DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + - DIAGNOSTIC_LINE(SpecifiesBlockTransferBwd) + + DIAGNOSTIC_LINE(SpecifiesBlockTransfer4D) + DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + @@ -283,7 +283,7 @@ template struct BwdXdlV3Algorithm { CHECK_CONCEPT(T, ConvAlgorithmDescriptor) CHECK_CONCEPT(T, SpecifiesThreadBlock) - CHECK_CONCEPT(T, SpecifiesBlockTransferBwd) + CHECK_CONCEPT(T, SpecifiesBlockTransfer) CHECK_CONCEPT(T, SpecifiesLdsTransfer) CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder) CHECK_CONCEPT(T, SpecifiesSourceAccessOrder) @@ -293,7 +293,7 @@ struct BwdXdlV3Algorithm { static constexpr bool c1 = c_ConvAlgorithmDescriptor; static constexpr bool c2 = c_SpecifiesThreadBlock; - static constexpr bool c3 = c_SpecifiesBlockTransferBwd; + static constexpr bool c3 = c_SpecifiesBlockTransfer; static constexpr bool c4 = c_SpecifiesLdsTransfer; static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder; static constexpr bool c6 = c_SpecifiesSourceAccessOrder; @@ -310,7 +310,7 @@ struct BwdXdlV3Algorithm { "Concepts for BwdXdlV3 Algorithm:\n") + DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + - DIAGNOSTIC_LINE(SpecifiesBlockTransferBwd) + + DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp index 25cf773694f..69facce41ba 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp @@ -12,9 +12,9 @@ namespace 
ck_tile::builder::factory::internal { // Block transfer parameters for A or B tensor. struct BlockTransfer { - ck::Array thread_cluster_dims = {0, 0, 0}; // k0, m, k1 - ck::Array thread_cluster_order = {0, 0, 0}; - ck::Array src_access_order = {0, 0, 0}; + ck::Array thread_cluster_dims{}; // k0, m, k1 + ck::Array thread_cluster_order{}; + ck::Array src_access_order{}; size_t src_vector_dim = 0; size_t src_scalar_per_vector = 0; size_t lds_dst_scalar_per_vector = 0; @@ -22,15 +22,15 @@ struct BlockTransfer bool lds_padding = false; }; +template struct BwdBlockTransfer { - ck::Array thread_cluster_dims = {0, 0, 0, 0}; - ck::Array thread_cluster_order = {0, 0, 0, 0}; - ck::Array src_access_order = {0, 0, 0, 0}; + ck::Array thread_cluster_dims{}; + ck::Array thread_cluster_order{}; + ck::Array src_access_order{}; size_t src_vector_dim = 0; size_t src_scalar_per_vector = 0; size_t lds_dst_scalar_per_vector = 0; - bool is_direct_load = false; bool lds_padding = false; }; @@ -55,7 +55,7 @@ constexpr BlockTransfer SetFwdConvBlockTransfer() } template -constexpr BwdBlockTransfer SetBwdConvBlockTransfer() +constexpr auto SetBwdConvBlockTransfer() { auto& block_xfer = TRANSFER.block_transfer; auto& block_order = TRANSFER.block_transfer_access_order; @@ -68,27 +68,25 @@ constexpr BwdBlockTransfer SetBwdConvBlockTransfer() if constexpr (array_length == 3) { - return BwdBlockTransfer{ - .thread_cluster_dims = {block_xfer.k_batch_size, block_xfer.k0, block_xfer.m_n, block_xfer.k1}, + return BwdBlockTransfer<3>{ + .thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1}, .thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2]}, .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]}, .src_vector_dim = lds_cfg.src_vector_dim, .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, - .is_direct_load = lds_cfg.is_direct_load, .lds_padding = 
lds_cfg.lds_padding, }; } else if constexpr (array_length == 4) { - return BwdBlockTransfer{ + return BwdBlockTransfer<4>{ .thread_cluster_dims = {block_xfer.k_batch_size, block_xfer.k0, block_xfer.m_n, block_xfer.k1}, .thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2], block_order.order[3]}, .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2], src_order.order[3]}, .src_vector_dim = lds_cfg.src_vector_dim, .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, - .is_direct_load = lds_cfg.is_direct_load, .lds_padding = lds_cfg.lds_padding, }; } From ab88cee0eb274a7de57d2cae3069aba084d1a015 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Mon, 29 Dec 2025 09:53:07 -0500 Subject: [PATCH 28/81] Add instance traits for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3. --- ...rouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 284 ++++++++++++++++++ ...rouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 24 ++ 2 files changed, 308 insertions(+) create mode 100644 experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp new file mode 100644 index 00000000000..7970b1ced59 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -0,0 +1,284 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "instance_traits.hpp" +#include "instance_traits_util.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" + +// Forward declaration to avoid circular dependency +namespace ck::tensor_operation::device { + +template +struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3; + +} // namespace ck::tensor_operation::device + +namespace ck_tile { +namespace reflect { + +template +struct InstanceTraits> +{ + static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3"; + + static constexpr ck::index_t kNDimSpatial = NDimSpatial; + + using InLayout = InLayout_; + using WeiLayout = WeiLayout_; + using OutLayout = OutLayout_; + + using InDataType = InDataType_; + using WeiDataType = WeiDataType_; + using OutDataType = OutDataType_; + using AccDataType = AccDataType_; + + using InElementwiseOperation = InElementwiseOperation_; + using WeiElementwiseOperation = WeiElementwiseOperation_; + using OutElementwiseOperation = OutElementwiseOperation_; + + static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization; + + static constexpr ck::index_t kBlockSize = BlockSize; + static constexpr ck::index_t kMPerBlock = MPerBlock; + static constexpr ck::index_t kNPerBlock = NPerBlock; + static constexpr ck::index_t kK0PerBlock = K0PerBlock; + static constexpr ck::index_t kK1 = K1; + static constexpr ck::index_t kMPerXDL = MPerXDL; + static constexpr ck::index_t kNPerXDL = NPerXDL; + static constexpr ck::index_t kMXdlPerWave = MXdlPerWave; + static constexpr ck::index_t kNXdlPerWave = NXdlPerWave; + + using ABlockTransferThreadClusterLengths_K0_M_K1 = ABlockTransferThreadClusterLengths_K0_M_K1_; + using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_; + using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_; + static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; + 
static constexpr ck::index_t kABlockTransferSrcScalarPerVector = + ABlockTransferSrcScalarPerVector; + static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 = + ABlockTransferDstScalarPerVector_K1; + static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + + using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_; + using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; + using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; + static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = + BBlockTransferSrcScalarPerVector; + static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 = + BBlockTransferDstScalarPerVector_K1; + static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + + static constexpr ck::index_t kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle; + static constexpr ck::index_t kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle; + + using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; + static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl = + CBlockTransferScalarPerVector_NWaveNPerXdl; + + static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched; + static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer; + + using ComputeTypeA = ComputeTypeA_; + using ComputeTypeB = ComputeTypeB_; + + // Static member function to generate instance string + static std::string instance_string() + { + std::ostringstream oss; + + // Kernel type name + oss << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3"; + + // Template parameters in exact order + oss << "<" << kNDimSpatial; // 1. NDimSpatial + oss << "," << detail::layout_name(); // 2. 
InLayout + oss << "," << detail::layout_name(); // 3. WeiLayout + oss << "," << detail::layout_name(); // 4. OutLayout + oss << "," << detail::type_name(); // 5. InDataType + oss << "," << detail::type_name(); // 6. WeiDataType + oss << "," << detail::type_name(); // 7. OutDataType + oss << "," << detail::type_name(); // 8. AccDataType + oss << "," + << detail::elementwise_op_name(); // 9. InElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 10. + // WeiElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 11. + // OutElementwiseOperation + oss << "," + << detail::conv_bwd_weight_spec_name( + kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization + oss << "," << kBlockSize; // 13. BlockSize + oss << "," << kMPerBlock; // 14. MPerBlock + oss << "," << kNPerBlock; // 15. NPerBlock + oss << "," << kK0PerBlock; // 16. K0PerBlock + oss << "," << kK1; // 17. K1 + oss << "," << kMPerXDL; // 18. MPerXDL + oss << "," << kNPerXDL; // 19. NPerXDL + oss << "," << kMXdlPerWave; // 20. MXdlPerWave + oss << "," << kNXdlPerWave; // 21. NXdlPerWave + oss << "," << detail::sequence_name(); // 22. + oss << "," << detail::sequence_name(); // 23. + oss << "," << detail::sequence_name(); // 24. + oss << "," << kABlockTransferSrcVectorDim; // 25. + oss << "," << kABlockTransferSrcScalarPerVector; // 26. + oss << "," << kABlockTransferDstScalarPerVector_K1; // 27. + oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28. + oss << "," << detail::sequence_name(); // 29. + oss << "," << detail::sequence_name(); // 30. + oss << "," << detail::sequence_name(); // 31. + oss << "," << kBBlockTransferSrcVectorDim; // 32. + oss << "," << kBBlockTransferSrcScalarPerVector; // 33. + oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34. + oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35. + oss << "," << kCShuffleMXdlPerWavePerShuffle; // 36. + oss << "," << kCShuffleNXdlPerWavePerShuffle; // 37. 
+ oss << "," + << detail::sequence_name< + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>(); // 38. + oss << "," << kCBlockTransferScalarPerVector_NWaveNPerXdl; // 39. + oss << "," << detail::pipeline_scheduler_name(kBlkGemmPipeSched); // 40. + oss << "," << detail::pipeline_version_name(kBlkGemmPipelineVer); // 41. + oss << "," << detail::type_name(); // 42. + oss << "," << detail::type_name(); // 43. + oss << ">"; + + return oss.str(); + } +}; + +} // namespace reflect +} // namespace ck_tile diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index a0e06a81d67..1e3177729d3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -29,6 +29,11 @@ #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/flush_cache.hpp" +#ifdef CK_EXPERIMENTAL_BUILDER +#include "ck_tile/builder/reflect/description.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp" +#endif + namespace ck { namespace tensor_operation { namespace device { @@ -1548,6 +1553,25 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 return str.str(); } + +#ifdef CK_EXPERIMENTAL_BUILDER + std::string GetInstanceString() const override + { + static_assert(ck_tile::reflect::HasInstanceTraits, + "Specialization of instance_traits not found. 
Please check that a " + "specialization exists in file " + "ck_tile/builder/reflect/" + "instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp " + "for the given template parameters."); + return ck_tile::reflect::instance_string(); + } + + std::unique_ptr describe() const override + { + return std::make_unique(GetInstanceString()); + } +#endif + }; } // namespace device From 3e16fa072faad6f12ea56f7f13c9d16da9a57150 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Mon, 29 Dec 2025 09:53:47 -0500 Subject: [PATCH 29/81] Test fix. --- .../test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp index 2dfa6e57719..fdc17fba2a8 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp @@ -35,7 +35,7 @@ TEST(BwdWeight_1DBf16_CShuffle_V3, Create) const auto expected_transfer_parameters = to_string(ALGORITHM); cku::run_test({"DeviceGroupedConvBwdWeight_Xdl_CShuffleV3", expected_transfer_parameters, - "FILTER_1X1_STRIDE1_PAD0", + "Filter1x1Stride1Pad0", "NGCW,GKXC,NGKW", "PassThrough,PassThrough,PassThrough", "Intrawave", From 3c1e2b01709d20c24f9c1bae6ad1a6b8fe98188a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Tue, 30 Dec 2025 04:29:38 -0500 Subject: [PATCH 30/81] Add instance traits for bwd weight algorithms. 
--- ...aits_device_grouped_conv_bwd_weight_dl.hpp | 251 +++++++++++++++ ...bwd_weight_multiple_d_wmma_cshuffle_v3.hpp | 289 +++++++++++++++++ ...onv_bwd_weight_multiple_d_xdl_cshuffle.hpp | 278 +++++++++++++++++ ..._bwd_weight_two_stage_wmma_cshuffle_v3.hpp | 294 ++++++++++++++++++ ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 294 ++++++++++++++++++ ..._grouped_conv_bwd_weight_wmma_cshuffle.hpp | 275 ++++++++++++++++ ...ouped_conv_bwd_weight_wmma_cshuffle_v3.hpp | 289 +++++++++++++++++ ...rouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 2 +- .../device_grouped_conv_bwd_weight_dl.hpp | 24 ++ ...bwd_weight_multiple_d_wmma_cshuffle_v3.hpp | 24 ++ ...onv_bwd_weight_multiple_d_xdl_cshuffle.hpp | 24 ++ ..._bwd_weight_two_stage_wmma_cshuffle_v3.hpp | 24 ++ ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 24 ++ ..._grouped_conv_bwd_weight_wmma_cshuffle.hpp | 24 ++ ...ouped_conv_bwd_weight_wmma_cshuffle_v3.hpp | 24 ++ 15 files changed, 2139 insertions(+), 1 deletion(-) create mode 100644 experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_dl.hpp create mode 100644 experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp create mode 100644 experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp create mode 100644 experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp create mode 100644 experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp create mode 100644 experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp create mode 100644 experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp diff --git 
a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_dl.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_dl.hpp new file mode 100644 index 00000000000..b52a4aee80d --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_dl.hpp @@ -0,0 +1,251 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "instance_traits.hpp" +#include "instance_traits_util.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" + +namespace ck::tensor_operation::device { + +template +struct DeviceGroupedConvBwdWeight_Dl; + +} // namespace ck::tensor_operation::device + +namespace ck_tile { +namespace reflect { + +template +struct InstanceTraits> +{ + static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Dl"; + + static constexpr ck::index_t kNDimSpatial = NDimSpatial; + + using InLayout = InLayout_; + using WeiLayout = WeiLayout_; + using OutLayout = OutLayout_; + + using InDataType = InDataType_; + using WeiDataType = WeiDataType_; + using OutDataType = OutDataType_; + using AccDataType = AccDataType_; + + using InElementwiseOperation = InElementwiseOperation_; + using WeiElementwiseOperation = WeiElementwiseOperation_; + using OutElementwiseOperation = OutElementwiseOperation_; + + static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization; + + static constexpr ck::index_t kBlockSize = BlockSize; + static constexpr ck::index_t kMPerBlock = MPerBlock; + static constexpr ck::index_t kNPerBlock = NPerBlock; + static constexpr ck::index_t kK0PerBlock = K0PerBlock; + static constexpr ck::index_t kK1 = K1; + static constexpr ck::index_t kM1PerThread = M1PerThread; + static constexpr ck::index_t kN1PerThread = N1PerThread; + static constexpr ck::index_t kKPerThread = 
KPerThread; + + using M1N1ThreadClusterM1Xs = M1N1ThreadClusterM1Xs_; + using M1N1ThreadClusterN1Xs = M1N1ThreadClusterN1Xs_; + + using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 = ABlockTransferThreadSliceLengths_K0_M0_M1_K1_; + using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 = ABlockTransferThreadClusterLengths_K0_M0_M1_K1_; + using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_; + using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_; + using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1_; + using ABlockTransferSrcVectorTensorContiguousDimOrder = ABlockTransferSrcVectorTensorContiguousDimOrder_; + using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1_; + + using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 = BBlockTransferThreadSliceLengths_K0_N0_N1_K1_; + using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 = BBlockTransferThreadClusterLengths_K0_N0_N1_K1_; + using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; + using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1_; + using BBlockTransferSrcVectorTensorContiguousDimOrder = BBlockTransferSrcVectorTensorContiguousDimOrder_; + using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1_; + + using CThreadTransferSrcDstAccessOrder = CThreadTransferSrcDstAccessOrder_; + static constexpr ck::index_t kCThreadTransferSrcDstVectorDim = CThreadTransferSrcDstVectorDim; + static constexpr ck::index_t kCThreadTransferDstScalarPerVector = CThreadTransferDstScalarPerVector; + + // Static member function to generate instance string + static std::string instance_string() + { + std::ostringstream oss; + + // Kernel type name + oss << "DeviceGroupedConvBwdWeight_Dl"; 
+ + // Template parameters in exact order + oss << "<" << kNDimSpatial; // 1. NDimSpatial + oss << "," << detail::layout_name(); // 2. InLayout + oss << "," << detail::layout_name(); // 3. WeiLayout + oss << "," << detail::layout_name(); // 4. OutLayout + oss << "," << detail::type_name(); // 5. InDataType + oss << "," << detail::type_name(); // 6. WeiDataType + oss << "," << detail::type_name(); // 7. OutDataType + oss << "," << detail::type_name(); // 8. AccDataType + oss << "," + << detail::elementwise_op_name(); // 9. InElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 10. WeiElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 11. OutElementwiseOperation + oss << "," + << detail::conv_bwd_weight_spec_name( + kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization + oss << "," << kBlockSize; // 13. BlockSize + oss << "," << kMPerBlock; // 14. MPerBlock + oss << "," << kNPerBlock; // 15. NPerBlock + oss << "," << kK0PerBlock; // 16. K0PerBlock + oss << "," << kK1; // 17. K1 + oss << "," << kM1PerThread; // 18. M1PerThread + oss << "," << kN1PerThread; // 19. N1PerThread + oss << "," << kKPerThread; // 20. KPerThread + oss << "," << detail::sequence_name(); // 21. + oss << "," << detail::sequence_name(); // 22. + oss << "," << detail::sequence_name(); // 23. + oss << "," << detail::sequence_name(); // 24. + oss << "," << detail::sequence_name(); // 25. + oss << "," << detail::sequence_name(); // 26. + oss << "," << detail::sequence_name(); // 27. + oss << "," << detail::sequence_name(); // 28. + oss << "," << detail::sequence_name(); // 29. + oss << "," << detail::sequence_name(); // 30. + oss << "," << detail::sequence_name(); // 31. + oss << "," << detail::sequence_name(); // 32. + oss << "," << detail::sequence_name(); // 33. + oss << "," << detail::sequence_name(); // 34. + oss << "," << detail::sequence_name(); // 35. + oss << "," << detail::sequence_name(); // 36. 
+ oss << "," << detail::sequence_name(); // 37. + oss << "," << kCThreadTransferSrcDstVectorDim; // 38. + oss << "," << kCThreadTransferDstScalarPerVector; // 39. + oss << ">"; + + return oss.str(); + } +}; + +} // namespace reflect +} // namespace ck_tile diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp new file mode 100644 index 00000000000..63021bde699 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -0,0 +1,289 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "instance_traits.hpp" +#include "instance_traits_util.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" + +namespace ck::tensor_operation::device { + +template +struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3; + +} // namespace ck::tensor_operation::device + +namespace ck_tile { +namespace reflect { + +template +struct InstanceTraits> +{ + static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3"; + + static constexpr ck::index_t kNDimSpatial = NDimSpatial; + + using InLayout = InLayout_; + using WeiLayout = WeiLayout_; + using OutLayout = OutLayout_; + using DsLayout = DsLayout_; + + using InDataType = InDataType_; + using WeiDataType = WeiDataType_; + using OutDataType = OutDataType_; + using AccDataType = AccDataType_; + using DsDataType = DsDataType_; + + using InElementwiseOperation = InElementwiseOperation_; + using WeiElementwiseOperation = WeiElementwiseOperation_; + using OutElementwiseOperation = OutElementwiseOperation_; + + static constexpr auto kConvBackwardWeightSpecialization = 
ConvBackwardWeightSpecialization; + + static constexpr ck::index_t kBlockSize = BlockSize; + static constexpr ck::index_t kMPerBlock = MPerBlock; + static constexpr ck::index_t kNPerBlock = NPerBlock; + static constexpr ck::index_t kKPerBlock = KPerBlock; + static constexpr ck::index_t kABK1 = ABK1; + static constexpr ck::index_t kMPerWmma = MPerWmma; + static constexpr ck::index_t kNPerWmma = NPerWmma; + static constexpr ck::index_t kMRepeat = MRepeat; + static constexpr ck::index_t kNRepeat = NRepeat; + static constexpr ck::index_t kCShuffleMRepeatPerShuffle = CShuffleMRepeatPerShuffle; + static constexpr ck::index_t kCShuffleNRepeatPerShuffle = CShuffleNRepeatPerShuffle; + static constexpr ck::index_t kCShuffleBlockTransferScalarPerVector_NPerBlock = CShuffleBlockTransferScalarPerVector_NPerBlock; + + using ABlockTransferThreadClusterLengths_AK0_M_AK1 = ABlockTransferThreadClusterLengths_AK0_M_AK1_; + using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_; + using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_; + static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; + static constexpr ck::index_t kABlockTransferSrcScalarPerVector = + ABlockTransferSrcScalarPerVector; + static constexpr ck::index_t kABlockTransferDstScalarPerVector_AK1 = + ABlockTransferDstScalarPerVector_AK1; + static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + + using BBlockTransferThreadClusterLengths_BK0_N_BK1 = BBlockTransferThreadClusterLengths_BK0_N_BK1_; + using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; + using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; + static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = + BBlockTransferSrcScalarPerVector; + static constexpr ck::index_t kBBlockTransferDstScalarPerVector_BK1 = + 
BBlockTransferDstScalarPerVector_BK1; + static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + + using CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; + + static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched; + static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer; + + using ComputeTypeA = ComputeTypeA_; + using ComputeTypeB = ComputeTypeB_; + + // Static member function to generate instance string + static std::string instance_string() + { + std::ostringstream oss; + + // Kernel type name + oss << "DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3"; + + // Template parameters in exact order + oss << "<" << kNDimSpatial; // 1. NDimSpatial + oss << "," << detail::layout_name(); // 2. InLayout + oss << "," << detail::layout_name(); // 3. WeiLayout + oss << "," << detail::layout_name(); // 4. OutLayout + oss << "," << detail::tuple_name(); // 5. DsLayout + oss << "," << detail::type_name(); // 6. InDataType + oss << "," << detail::type_name(); // 7. WeiDataType + oss << "," << detail::type_name(); // 8. OutDataType + oss << "," << detail::type_name(); // 9. AccDataType + oss << "," << detail::tuple_name(); // 10. DsDataType + oss << "," + << detail::elementwise_op_name(); // 11. InElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 12. WeiElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 13. OutElementwiseOperation + oss << "," + << detail::conv_bwd_weight_spec_name( + kConvBackwardWeightSpecialization); // 14. ConvBackwardWeightSpecialization + oss << "," << kBlockSize; // 15. BlockSize + oss << "," << kMPerBlock; // 16. MPerBlock + oss << "," << kNPerBlock; // 17. NPerBlock + oss << "," << kKPerBlock; // 18. KPerBlock + oss << "," << kABK1; // 19. ABK1 + oss << "," << kMPerWmma; // 20. MPerWmma + oss << "," << kNPerWmma; // 21. 
NPerWmma + oss << "," << kMRepeat; // 22. MRepeat + oss << "," << kNRepeat; // 23. NRepeat + oss << "," << detail::sequence_name(); // 24. + oss << "," << detail::sequence_name(); // 25. + oss << "," << detail::sequence_name(); // 26. + oss << "," << kABlockTransferSrcVectorDim; // 27. + oss << "," << kABlockTransferSrcScalarPerVector; // 28. + oss << "," << kABlockTransferDstScalarPerVector_AK1; // 29. + oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 30. + oss << "," << detail::sequence_name(); // 31. + oss << "," << detail::sequence_name(); // 32. + oss << "," << detail::sequence_name(); // 33. + oss << "," << kBBlockTransferSrcVectorDim; // 34. + oss << "," << kBBlockTransferSrcScalarPerVector; // 35. + oss << "," << kBBlockTransferDstScalarPerVector_BK1; // 36. + oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 37. + oss << "," << kCShuffleMRepeatPerShuffle; // 38. + oss << "," << kCShuffleNRepeatPerShuffle; // 39. + oss << "," + << detail::sequence_name< + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>(); // 40. + oss << "," << kCShuffleBlockTransferScalarPerVector_NPerBlock; // 41. + oss << "," << detail::pipeline_scheduler_name(kBlkGemmPipeSched); // 42. + oss << "," << detail::pipeline_version_name(kBlkGemmPipelineVer); // 43. + oss << "," << detail::type_name(); // 44. + oss << "," << detail::type_name(); // 45. 
+ oss << ">"; + + return oss.str(); + } +}; + +} // namespace reflect +} // namespace ck_tile diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp new file mode 100644 index 00000000000..5d33c96c97e --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,278 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "instance_traits.hpp" +#include "instance_traits_util.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" + +namespace ck::tensor_operation::device { + +template +struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle; + +} // namespace ck::tensor_operation::device + +namespace ck_tile { +namespace reflect { + +template +struct InstanceTraits> +{ + static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle"; + + static constexpr ck::index_t kNDimSpatial = NDimSpatial; + + using InLayout = InLayout_; + using WeiLayout = WeiLayout_; + using OutLayout = OutLayout_; + using DsLayout = DsLayout_; + + using InDataType = InDataType_; + using WeiDataType = WeiDataType_; + using OutDataType = OutDataType_; + using AccDataType = AccDataType_; + using DsDataType = DsDataType_; + + using InElementwiseOperation = InElementwiseOperation_; + using WeiElementwiseOperation = WeiElementwiseOperation_; + using OutElementwiseOperation = OutElementwiseOperation_; + + static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization; + + static constexpr ck::index_t kBlockSize = BlockSize; + static constexpr ck::index_t kMPerBlock = MPerBlock; + static constexpr ck::index_t kNPerBlock = 
NPerBlock; + static constexpr ck::index_t kK0PerBlock = K0PerBlock; + static constexpr ck::index_t kK1 = K1; + static constexpr ck::index_t kMPerXDL = MPerXDL; + static constexpr ck::index_t kNPerXDL = NPerXDL; + static constexpr ck::index_t kMXdlPerWave = MXdlPerWave; + static constexpr ck::index_t kNXdlPerWave = NXdlPerWave; + static constexpr ck::index_t kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle; + static constexpr ck::index_t kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle; + static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl = CBlockTransferScalarPerVector_NWaveNPerXdl; + + using ABlockTransferThreadClusterLengths_K0_M_K1 = ABlockTransferThreadClusterLengths_K0_M_K1_; + using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_; + using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_; + static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; + static constexpr ck::index_t kABlockTransferSrcScalarPerVector = + ABlockTransferSrcScalarPerVector; + static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 = + ABlockTransferDstScalarPerVector_K1; + static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + + using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_; + using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; + using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; + static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = + BBlockTransferSrcScalarPerVector; + static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 = + BBlockTransferDstScalarPerVector_K1; + static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + + using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + 
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; + + using ComputeTypeA = ComputeTypeA_; + using ComputeTypeB = ComputeTypeB_; + + // Static member function to generate instance string + static std::string instance_string() + { + std::ostringstream oss; + + // Kernel type name + oss << "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle"; + + // Template parameters in exact order + oss << "<" << kNDimSpatial; // 1. NDimSpatial + oss << "," << detail::layout_name(); // 2. InLayout + oss << "," << detail::layout_name(); // 3. WeiLayout + oss << "," << detail::layout_name(); // 4. OutLayout + oss << "," << detail::tuple_name(); // 5. DsLayout + oss << "," << detail::type_name(); // 6. InDataType + oss << "," << detail::type_name(); // 7. WeiDataType + oss << "," << detail::type_name(); // 8. OutDataType + oss << "," << detail::type_name(); // 9. AccDataType + oss << "," << detail::tuple_name(); // 10. DsDataType + oss << "," + << detail::elementwise_op_name(); // 11. InElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 12. WeiElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 13. OutElementwiseOperation + oss << "," + << detail::conv_bwd_weight_spec_name( + kConvBackwardWeightSpecialization); // 14. ConvBackwardWeightSpecialization + oss << "," << kBlockSize; // 15. BlockSize + oss << "," << kMPerBlock; // 16. MPerBlock + oss << "," << kNPerBlock; // 17. NPerBlock + oss << "," << kK0PerBlock; // 18. K0PerBlock + oss << "," << kK1; // 19. K1 + oss << "," << kMPerXDL; // 20. MPerXDL + oss << "," << kNPerXDL; // 21. NPerXDL + oss << "," << kMXdlPerWave; // 22. MXdlPerWave + oss << "," << kNXdlPerWave; // 23. NXdlPerWave + oss << "," << detail::sequence_name(); // 24. + oss << "," << detail::sequence_name(); // 25. + oss << "," << detail::sequence_name(); // 26. + oss << "," << kABlockTransferSrcVectorDim; // 27. + oss << "," << kABlockTransferSrcScalarPerVector; // 28. 
+ oss << "," << kABlockTransferDstScalarPerVector_K1; // 29. + oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 30. + oss << "," << detail::sequence_name(); // 31. + oss << "," << detail::sequence_name(); // 32. + oss << "," << detail::sequence_name(); // 33. + oss << "," << kBBlockTransferSrcVectorDim; // 34. + oss << "," << kBBlockTransferSrcScalarPerVector; // 35. + oss << "," << kBBlockTransferDstScalarPerVector_K1; // 36. + oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 37. + oss << "," << kCShuffleMXdlPerWavePerShuffle; // 38. + oss << "," << kCShuffleNXdlPerWavePerShuffle; // 39. + oss << "," + << detail::sequence_name< + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>(); // 40. + oss << "," << kCBlockTransferScalarPerVector_NWaveNPerXdl; // 41. + oss << "," << detail::type_name(); // 42. + oss << "," << detail::type_name(); // 43. + oss << ">"; + + return oss.str(); + } +}; + +} // namespace reflect +} // namespace ck_tile diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp new file mode 100644 index 00000000000..67d2229a295 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp @@ -0,0 +1,294 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "instance_traits.hpp" +#include "instance_traits_util.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" + +namespace ck::tensor_operation::device { + +template +struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3; + +} // namespace ck::tensor_operation::device + +namespace ck_tile { +namespace reflect { + +template +struct InstanceTraits> +{ + static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3"; + + static constexpr ck::index_t kNDimSpatial = NDimSpatial; + + using InLayout = InLayout_; + using WeiLayout = WeiLayout_; + using OutLayout = OutLayout_; + + using InDataType = InDataType_; + using WeiDataType = WeiDataType_; + using OutDataType = OutDataType_; + using AccDataType = AccDataType_; + + using InElementwiseOperation = InElementwiseOperation_; + using WeiElementwiseOperation = WeiElementwiseOperation_; + using OutElementwiseOperation = OutElementwiseOperation_; + + static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization; + + static constexpr ck::index_t kBlockSize = BlockSize; + static constexpr ck::index_t kMPerBlock = MPerBlock; + static constexpr ck::index_t kNPerBlock = NPerBlock; + static constexpr ck::index_t kKPerBlock = KPerBlock; + static constexpr ck::index_t kABK1 = ABK1; + static constexpr ck::index_t kMPerWmma = MPerWmma; + static constexpr ck::index_t kNPerWmma = NPerWmma; + static constexpr ck::index_t kMRepeat = MRepeat; + static constexpr ck::index_t kNRepeat = NRepeat; + static constexpr ck::index_t kCShuffleMRepeatPerShuffle = CShuffleMRepeatPerShuffle; + static constexpr ck::index_t kCShuffleNRepeatPerShuffle = CShuffleNRepeatPerShuffle; + static constexpr ck::index_t kCShuffleBlockTransferScalarPerVector_NPerBlock = CShuffleBlockTransferScalarPerVector_NPerBlock; + static constexpr ck::index_t kNumGroupsToMerge = NumGroupsToMerge; + static constexpr 
ck::index_t kTransposeTransferSrcScalarPerVector = TransposeTransferSrcScalarPerVector; + static constexpr ck::index_t kTransposeTransferDstScalarPerVector = TransposeTransferDstScalarPerVector; + + using ABlockTransferThreadClusterLengths_AK0_M_AK1 = ABlockTransferThreadClusterLengths_AK0_M_AK1_; + using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_; + using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_; + static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; + static constexpr ck::index_t kABlockTransferSrcScalarPerVector = + ABlockTransferSrcScalarPerVector; + static constexpr ck::index_t kABlockTransferDstScalarPerVector_AK1 = + ABlockTransferDstScalarPerVector_AK1; + static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + + using BBlockTransferThreadClusterLengths_BK0_N_BK1 = BBlockTransferThreadClusterLengths_BK0_N_BK1_; + using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; + using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; + static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = + BBlockTransferSrcScalarPerVector; + static constexpr ck::index_t kBBlockTransferDstScalarPerVector_BK1 = + BBlockTransferDstScalarPerVector_BK1; + static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + + using CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; + + static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched; + static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer; + + using ComputeTypeA = ComputeTypeA_; + using ComputeTypeB = ComputeTypeB_; + + // Static member function to generate instance string + static std::string instance_string() + { + 
std::ostringstream oss; + + // Kernel type name + oss << "DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3"; + + // Template parameters in exact order + oss << "<" << kNDimSpatial; // 1. NDimSpatial + oss << "," << detail::layout_name(); // 2. InLayout + oss << "," << detail::layout_name(); // 3. WeiLayout + oss << "," << detail::layout_name(); // 4. OutLayout + oss << "," << detail::type_name(); // 5. InDataType + oss << "," << detail::type_name(); // 6. WeiDataType + oss << "," << detail::type_name(); // 7. OutDataType + oss << "," << detail::type_name(); // 8. AccDataType + oss << "," + << detail::elementwise_op_name(); // 9. InElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 10. WeiElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 11. OutElementwiseOperation + oss << "," + << detail::conv_bwd_weight_spec_name( + kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization + oss << "," << kBlockSize; // 13. BlockSize + oss << "," << kMPerBlock; // 14. MPerBlock + oss << "," << kNPerBlock; // 15. NPerBlock + oss << "," << kKPerBlock; // 16. KPerBlock + oss << "," << kABK1; // 17. ABK1 + oss << "," << kMPerWmma; // 18. MPerWmma + oss << "," << kNPerWmma; // 19. NPerWmma + oss << "," << kMRepeat; // 20. MRepeat + oss << "," << kNRepeat; // 21. NRepeat + oss << "," << detail::sequence_name(); // 22. + oss << "," << detail::sequence_name(); // 23. + oss << "," << detail::sequence_name(); // 24. + oss << "," << kABlockTransferSrcVectorDim; // 25. + oss << "," << kABlockTransferSrcScalarPerVector; // 26. + oss << "," << kABlockTransferDstScalarPerVector_AK1; // 27. + oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28. + oss << "," << detail::sequence_name(); // 29. + oss << "," << detail::sequence_name(); // 30. + oss << "," << detail::sequence_name(); // 31. + oss << "," << kBBlockTransferSrcVectorDim; // 32. + oss << "," << kBBlockTransferSrcScalarPerVector; // 33. 
+ oss << "," << kBBlockTransferDstScalarPerVector_BK1; // 34. + oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35. + oss << "," << kCShuffleMRepeatPerShuffle; // 36. + oss << "," << kCShuffleNRepeatPerShuffle; // 37. + oss << "," + << detail::sequence_name< + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>(); // 38. + oss << "," << kCShuffleBlockTransferScalarPerVector_NPerBlock; // 39. + oss << "," << detail::pipeline_scheduler_name(kBlkGemmPipeSched); // 40. + oss << "," << detail::pipeline_version_name(kBlkGemmPipelineVer); // 41. + oss << "," << kNumGroupsToMerge; // 42. + oss << "," << detail::type_name(); // 43. + oss << "," << detail::type_name(); // 44. + oss << "," << kTransposeTransferSrcScalarPerVector; // 45. + oss << "," << kTransposeTransferDstScalarPerVector; // 46. + oss << ">"; + + return oss.str(); + } +}; + +} // namespace reflect +} // namespace ck_tile diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp new file mode 100644 index 00000000000..7c0c2ed9b36 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -0,0 +1,294 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "instance_traits.hpp" +#include "instance_traits_util.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" + +namespace ck::tensor_operation::device { + +template +struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle; + +} // namespace ck::tensor_operation::device + +namespace ck_tile { +namespace reflect { + +template +struct InstanceTraits> +{ + static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle"; + + static constexpr ck::index_t kNDimSpatial = NDimSpatial; + + using InLayout = InLayout_; + using WeiLayout = WeiLayout_; + using OutLayout = OutLayout_; + + using InDataType = InDataType_; + using WeiDataType = WeiDataType_; + using OutDataType = OutDataType_; + using AccDataType = AccDataType_; + + using InElementwiseOperation = InElementwiseOperation_; + using WeiElementwiseOperation = WeiElementwiseOperation_; + using OutElementwiseOperation = OutElementwiseOperation_; + + static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization; + + static constexpr ck::index_t kBlockSize = BlockSize; + static constexpr ck::index_t kMPerBlock = MPerBlock; + static constexpr ck::index_t kNPerBlock = NPerBlock; + static constexpr ck::index_t kKPerBlock = KPerBlock; + static constexpr ck::index_t kK1 = K1; + static constexpr ck::index_t kMPerXDL = MPerXDL; + static constexpr ck::index_t kNPerXDL = NPerXDL; + static constexpr ck::index_t kMXdlPerWave = MXdlPerWave; + static constexpr ck::index_t kNXdlPerWave = NXdlPerWave; + static constexpr ck::index_t kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle; + static constexpr ck::index_t kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle; + static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl = CBlockTransferScalarPerVector_NWaveNPerXdl; + static constexpr ck::index_t kNumGroupsToMerge = NumGroupsToMerge; + static 
constexpr ck::index_t kTransposeTransferSrcScalarPerVector = TransposeTransferSrcScalarPerVector; + static constexpr ck::index_t kTransposeTransferDstScalarPerVector = TransposeTransferDstScalarPerVector; + + using ABlockTransferThreadClusterLengths_K0_M_K1 = ABlockTransferThreadClusterLengths_K0_M_K1_; + using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_; + using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_; + static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; + static constexpr ck::index_t kABlockTransferSrcScalarPerVector = + ABlockTransferSrcScalarPerVector; + static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 = + ABlockTransferDstScalarPerVector_K1; + static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + + using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_; + using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; + using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; + static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = + BBlockTransferSrcScalarPerVector; + static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 = + BBlockTransferDstScalarPerVector_K1; + static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + + using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; + + static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched; + static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer; + + using ComputeTypeA = ComputeTypeA_; + using ComputeTypeB = ComputeTypeB_; + + // Static member function to generate instance string + static std::string instance_string() + { + std::ostringstream oss; + + // 
Kernel type name + oss << "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle"; + + // Template parameters in exact order + oss << "<" << kNDimSpatial; // 1. NDimSpatial + oss << "," << detail::layout_name(); // 2. InLayout + oss << "," << detail::layout_name(); // 3. WeiLayout + oss << "," << detail::layout_name(); // 4. OutLayout + oss << "," << detail::type_name(); // 5. InDataType + oss << "," << detail::type_name(); // 6. WeiDataType + oss << "," << detail::type_name(); // 7. OutDataType + oss << "," << detail::type_name(); // 8. AccDataType + oss << "," + << detail::elementwise_op_name(); // 9. InElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 10. WeiElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 11. OutElementwiseOperation + oss << "," + << detail::conv_bwd_weight_spec_name( + kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization + oss << "," << kBlockSize; // 13. BlockSize + oss << "," << kMPerBlock; // 14. MPerBlock + oss << "," << kNPerBlock; // 15. NPerBlock + oss << "," << kKPerBlock; // 16. KPerBlock + oss << "," << kK1; // 17. K1 + oss << "," << kMPerXDL; // 18. MPerXDL + oss << "," << kNPerXDL; // 19. NPerXDL + oss << "," << kMXdlPerWave; // 20. MXdlPerWave + oss << "," << kNXdlPerWave; // 21. NXdlPerWave + oss << "," << detail::sequence_name(); // 22. + oss << "," << detail::sequence_name(); // 23. + oss << "," << detail::sequence_name(); // 24. + oss << "," << kABlockTransferSrcVectorDim; // 25. + oss << "," << kABlockTransferSrcScalarPerVector; // 26. + oss << "," << kABlockTransferDstScalarPerVector_K1; // 27. + oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28. + oss << "," << detail::sequence_name(); // 29. + oss << "," << detail::sequence_name(); // 30. + oss << "," << detail::sequence_name(); // 31. + oss << "," << kBBlockTransferSrcVectorDim; // 32. + oss << "," << kBBlockTransferSrcScalarPerVector; // 33. 
+ oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34. + oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35. + oss << "," << kCShuffleMXdlPerWavePerShuffle; // 36. + oss << "," << kCShuffleNXdlPerWavePerShuffle; // 37. + oss << "," + << detail::sequence_name< + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>(); // 38. + oss << "," << kCBlockTransferScalarPerVector_NWaveNPerXdl; // 39. + oss << "," << detail::pipeline_scheduler_name(kBlkGemmPipeSched); // 40. + oss << "," << detail::pipeline_version_name(kBlkGemmPipelineVer); // 41. + oss << "," << kNumGroupsToMerge; // 42. + oss << "," << detail::type_name(); // 43. + oss << "," << detail::type_name(); // 44. + oss << "," << kTransposeTransferSrcScalarPerVector; // 45. + oss << "," << kTransposeTransferDstScalarPerVector; // 46. + oss << ">"; + + return oss.str(); + } +}; + +} // namespace reflect +} // namespace ck_tile diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp new file mode 100644 index 00000000000..669c3fccaba --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp @@ -0,0 +1,275 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "instance_traits.hpp" +#include "instance_traits_util.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" + +namespace ck::tensor_operation::device { + +template ::type> +struct DeviceGroupedConvBwdWeight_Wmma_CShuffle; + +} // namespace ck::tensor_operation::device + +namespace ck_tile { +namespace reflect { + +template ::type> +struct InstanceTraits> +{ + static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Wmma_CShuffle"; + + static constexpr ck::index_t kNDimSpatial = NDimSpatial; + + using InLayout = InLayout_; + using WeiLayout = WeiLayout_; + using OutLayout = OutLayout_; + + using InDataType = InDataType_; + using WeiDataType = WeiDataType_; + using OutDataType = OutDataType_; + using AccDataType = AccDataType_; + + using InElementwiseOperation = InElementwiseOperation_; + using WeiElementwiseOperation = WeiElementwiseOperation_; + using OutElementwiseOperation = OutElementwiseOperation_; + + static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization; + + static constexpr ck::index_t kBlockSize = BlockSize; + static constexpr ck::index_t kMPerBlock = MPerBlock; + static constexpr ck::index_t kNPerBlock = NPerBlock; + static constexpr ck::index_t kK0PerBlock = K0PerBlock; + static constexpr ck::index_t kK1 = K1; + static constexpr ck::index_t kMPerWMMA = MPerWMMA; + static constexpr ck::index_t kNPerWMMA = NPerWMMA; + static constexpr ck::index_t kMRepeat = MRepeat; + static constexpr ck::index_t kNRepeat = NRepeat; + static constexpr ck::index_t kCShuffleMRepeatPerShuffle = CShuffleMRepeatPerShuffle; + static constexpr ck::index_t kCShuffleNRepeatPerShuffle = CShuffleNRepeatPerShuffle; + static constexpr ck::index_t kCShuffleBlockTransferScalarPerVector_NPerBlock = CShuffleBlockTransferScalarPerVector_NPerBlock; + static constexpr ck::index_t kNumGemmKPrefetchStage = NumGemmKPrefetchStage; + + using 
ABlockTransferThreadClusterLengths_K0_M_K1 = ABlockTransferThreadClusterLengths_K0_M_K1_; + using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_; + using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_; + static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; + static constexpr ck::index_t kABlockTransferSrcScalarPerVector = + ABlockTransferSrcScalarPerVector; + static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 = + ABlockTransferDstScalarPerVector_K1; + static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + + using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_; + using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; + using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; + static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = + BBlockTransferSrcScalarPerVector; + static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 = + BBlockTransferDstScalarPerVector_K1; + static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + + using CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; + + static constexpr ck::LoopScheduler kLoopSched = LoopSched; + static constexpr ck::PipelineVersion kPipelineVer = PipelineVer; + + // Static member function to generate instance string + static std::string instance_string() + { + std::ostringstream oss; + + // Kernel type name + oss << "DeviceGroupedConvBwdWeight_Wmma_CShuffle"; + + // Template parameters in exact order + oss << "<" << kNDimSpatial; // 1. NDimSpatial + oss << "," << detail::layout_name(); // 2. InLayout + oss << "," << detail::layout_name(); // 3. WeiLayout + oss << "," << detail::layout_name(); // 4. 
OutLayout + oss << "," << detail::type_name(); // 5. InDataType + oss << "," << detail::type_name(); // 6. WeiDataType + oss << "," << detail::type_name(); // 7. OutDataType + oss << "," << detail::type_name(); // 8. AccDataType + oss << "," + << detail::elementwise_op_name(); // 9. InElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 10. WeiElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 11. OutElementwiseOperation + oss << "," + << detail::conv_bwd_weight_spec_name( + kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization + oss << "," << kBlockSize; // 13. BlockSize + oss << "," << kMPerBlock; // 14. MPerBlock + oss << "," << kNPerBlock; // 15. NPerBlock + oss << "," << kK0PerBlock; // 16. K0PerBlock + oss << "," << kK1; // 17. K1 + oss << "," << kMPerWMMA; // 18. MPerWMMA + oss << "," << kNPerWMMA; // 19. NPerWMMA + oss << "," << kMRepeat; // 20. MRepeat + oss << "," << kNRepeat; // 21. NRepeat + oss << "," << detail::sequence_name(); // 22. + oss << "," << detail::sequence_name(); // 23. + oss << "," << detail::sequence_name(); // 24. + oss << "," << kABlockTransferSrcVectorDim; // 25. + oss << "," << kABlockTransferSrcScalarPerVector; // 26. + oss << "," << kABlockTransferDstScalarPerVector_K1; // 27. + oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28. + oss << "," << detail::sequence_name(); // 29. + oss << "," << detail::sequence_name(); // 30. + oss << "," << detail::sequence_name(); // 31. + oss << "," << kBBlockTransferSrcVectorDim; // 32. + oss << "," << kBBlockTransferSrcScalarPerVector; // 33. + oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34. + oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35. + oss << "," << kCShuffleMRepeatPerShuffle; // 36. + oss << "," << kCShuffleNRepeatPerShuffle; // 37. + oss << "," + << detail::sequence_name< + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>(); // 38. 
+ oss << "," << kCShuffleBlockTransferScalarPerVector_NPerBlock; // 39. + oss << "," << kNumGemmKPrefetchStage; // 40. + oss << "," << detail::loop_scheduler_name(kLoopSched); // 41. + oss << "," << detail::pipeline_version_name(kPipelineVer); // 42. + oss << ">"; + + return oss.str(); + } +}; + +} // namespace reflect +} // namespace ck_tile diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp new file mode 100644 index 00000000000..0161b6b6812 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -0,0 +1,289 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "instance_traits.hpp" +#include "instance_traits_util.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" + +namespace ck::tensor_operation::device { + +template +struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3; + +} // namespace ck::tensor_operation::device + +namespace ck_tile { +namespace reflect { + +template +struct InstanceTraits> +{ + static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Wmma_CShuffleV3"; + + static constexpr ck::index_t kNDimSpatial = NDimSpatial; + + using InLayout = InLayout_; + using WeiLayout = WeiLayout_; + using OutLayout = OutLayout_; + + using InDataType = InDataType_; + using WeiDataType = WeiDataType_; + using OutDataType = OutDataType_; + using AccDataType = AccDataType_; + + using InElementwiseOperation = InElementwiseOperation_; + using WeiElementwiseOperation = WeiElementwiseOperation_; + using OutElementwiseOperation = OutElementwiseOperation_; + + static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization; + + 
static constexpr ck::index_t kBlockSize = BlockSize; + static constexpr ck::index_t kMPerBlock = MPerBlock; + static constexpr ck::index_t kNPerBlock = NPerBlock; + static constexpr ck::index_t kKPerBlock = KPerBlock; + static constexpr ck::index_t kABK1 = ABK1; + static constexpr ck::index_t kMPerWmma = MPerWmma; + static constexpr ck::index_t kNPerWmma = NPerWmma; + static constexpr ck::index_t kMRepeat = MRepeat; + static constexpr ck::index_t kNRepeat = NRepeat; + static constexpr ck::index_t kCShuffleMRepeatPerShuffle = CShuffleMRepeatPerShuffle; + static constexpr ck::index_t kCShuffleNRepeatPerShuffle = CShuffleNRepeatPerShuffle; + static constexpr ck::index_t kCShuffleBlockTransferScalarPerVector_NPerBlock = CShuffleBlockTransferScalarPerVector_NPerBlock; + static constexpr ck::index_t kMaxTransposeTransferSrcScalarPerVector = MaxTransposeTransferSrcScalarPerVector; + static constexpr ck::index_t kMaxTransposeTransferDstScalarPerVector = MaxTransposeTransferDstScalarPerVector; + + using ABlockTransferThreadClusterLengths_AK0_M_AK1 = ABlockTransferThreadClusterLengths_AK0_M_AK1_; + using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_; + using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_; + static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; + static constexpr ck::index_t kABlockTransferSrcScalarPerVector = + ABlockTransferSrcScalarPerVector; + static constexpr ck::index_t kABlockTransferDstScalarPerVector_AK1 = + ABlockTransferDstScalarPerVector_AK1; + static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + + using BBlockTransferThreadClusterLengths_BK0_N_BK1 = BBlockTransferThreadClusterLengths_BK0_N_BK1_; + using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; + using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; + static 
constexpr ck::index_t kBBlockTransferSrcScalarPerVector = + BBlockTransferSrcScalarPerVector; + static constexpr ck::index_t kBBlockTransferDstScalarPerVector_BK1 = + BBlockTransferDstScalarPerVector_BK1; + static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + + using CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; + + static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched; + static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer; + + using ComputeTypeA = ComputeTypeA_; + using ComputeTypeB = ComputeTypeB_; + + // Static member function to generate instance string + static std::string instance_string() + { + std::ostringstream oss; + + // Kernel type name + oss << "DeviceGroupedConvBwdWeight_Wmma_CShuffleV3"; + + // Template parameters in exact order + oss << "<" << kNDimSpatial; // 1. NDimSpatial + oss << "," << detail::layout_name(); // 2. InLayout + oss << "," << detail::layout_name(); // 3. WeiLayout + oss << "," << detail::layout_name(); // 4. OutLayout + oss << "," << detail::type_name(); // 5. InDataType + oss << "," << detail::type_name(); // 6. WeiDataType + oss << "," << detail::type_name(); // 7. OutDataType + oss << "," << detail::type_name(); // 8. AccDataType + oss << "," + << detail::elementwise_op_name(); // 9. InElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 10. WeiElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 11. OutElementwiseOperation + oss << "," + << detail::conv_bwd_weight_spec_name( + kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization + oss << "," << kBlockSize; // 13. BlockSize + oss << "," << kMPerBlock; // 14. MPerBlock + oss << "," << kNPerBlock; // 15. NPerBlock + oss << "," << kKPerBlock; // 16. KPerBlock + oss << "," << kABK1; // 17. ABK1 + oss << "," << kMPerWmma; // 18. 
MPerWmma + oss << "," << kNPerWmma; // 19. NPerWmma + oss << "," << kMRepeat; // 20. MRepeat + oss << "," << kNRepeat; // 21. NRepeat + oss << "," << detail::sequence_name(); // 22. + oss << "," << detail::sequence_name(); // 23. + oss << "," << detail::sequence_name(); // 24. + oss << "," << kABlockTransferSrcVectorDim; // 25. + oss << "," << kABlockTransferSrcScalarPerVector; // 26. + oss << "," << kABlockTransferDstScalarPerVector_AK1; // 27. + oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28. + oss << "," << detail::sequence_name(); // 29. + oss << "," << detail::sequence_name(); // 30. + oss << "," << detail::sequence_name(); // 31. + oss << "," << kBBlockTransferSrcVectorDim; // 32. + oss << "," << kBBlockTransferSrcScalarPerVector; // 33. + oss << "," << kBBlockTransferDstScalarPerVector_BK1; // 34. + oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35. + oss << "," << kCShuffleMRepeatPerShuffle; // 36. + oss << "," << kCShuffleNRepeatPerShuffle; // 37. + oss << "," + << detail::sequence_name< + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>(); // 38. + oss << "," << kCShuffleBlockTransferScalarPerVector_NPerBlock; // 39. + oss << "," << detail::pipeline_scheduler_name(kBlkGemmPipeSched); // 40. + oss << "," << detail::pipeline_version_name(kBlkGemmPipelineVer); // 41. + oss << "," << detail::type_name(); // 42. + oss << "," << detail::type_name(); // 43. + oss << "," << kMaxTransposeTransferSrcScalarPerVector; // 44. + oss << "," << kMaxTransposeTransferDstScalarPerVector; // 45. 
+ oss << ">"; + + return oss.str(); + } +}; + +} // namespace reflect +} // namespace ck_tile diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 7970b1ced59..9ba245bf21f 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -271,7 +271,7 @@ struct InstanceTraits(); // 38. oss << "," << kCBlockTransferScalarPerVector_NWaveNPerXdl; // 39. oss << "," << detail::pipeline_scheduler_name(kBlkGemmPipeSched); // 40. - oss << "," << detail::pipeline_version_name(kBlkGemmPipelineVer); // 41. + oss << "," << detail::pipeline_version_name(kBlkGemmPipelineVer); // 41. oss << "," << detail::type_name(); // 42. oss << "," << detail::type_name(); // 43. oss << ">"; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp index b52502eb45f..2edf213be19 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp @@ -20,6 +20,11 @@ #include "ck/host_utility/kernel_launch.hpp" #include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" +#ifdef CK_EXPERIMENTAL_BUILDER +#include "ck_tile/builder/reflect/description.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_dl.hpp" +#endif + namespace ck { namespace tensor_operation { namespace device { @@ -1227,6 +1232,25 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight, + "Specialization of instance_traits not found. 
Please check that a " + "specialization exists in file " + "ck_tile/builder/reflect/" + "instance_traits_device_grouped_conv_bwd_weight_dl.hpp " + "for the given template parameters."); + return ck_tile::reflect::instance_string(); + } + + std::unique_ptr describe() const override + { + return std::make_unique(GetInstanceString()); + } +#endif + }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index 86e8defb83e..702665fe12b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -28,6 +28,11 @@ #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/flush_cache.hpp" +#ifdef CK_EXPERIMENTAL_BUILDER +#include "ck_tile/builder/reflect/description.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp" +#endif + namespace ck { namespace tensor_operation { namespace device { @@ -1251,6 +1256,25 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 "The argument pointer is not an object of " "DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3::Argument structure!"); } + +#ifdef CK_EXPERIMENTAL_BUILDER + std::string GetInstanceString() const override + { + static_assert(ck_tile::reflect::HasInstanceTraits, + "Specialization of instance_traits not found. 
Please check that a " + "specialization exists in file " + "ck_tile/builder/reflect/" + "instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp " + "for the given template parameters."); + return ck_tile::reflect::instance_string(); + } + + std::unique_ptr describe() const override + { + return std::make_unique(GetInstanceString()); + } +#endif + }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index d7392e13a04..759318afe92 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -26,6 +26,11 @@ #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" +#ifdef CK_EXPERIMENTAL_BUILDER +#include "ck_tile/builder/reflect/description.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp" +#endif + namespace ck { namespace tensor_operation { namespace device { @@ -1207,6 +1212,25 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle "The argument pointer is not an object of " "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle::Argument structure!"); } + +#ifdef CK_EXPERIMENTAL_BUILDER + std::string GetInstanceString() const override + { + static_assert(ck_tile::reflect::HasInstanceTraits, + "Specialization of instance_traits not found. 
Please check that a " + "specialization exists in file " + "ck_tile/builder/reflect/" + "instance_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp " + "for the given template parameters."); + return ck_tile::reflect::instance_string(); + } + + std::unique_ptr describe() const override + { + return std::make_unique(GetInstanceString()); + } +#endif + }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp index 37fe0b2c7ba..2e03bd5d152 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp @@ -30,6 +30,11 @@ #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/flush_cache.hpp" +#ifdef CK_EXPERIMENTAL_BUILDER +#include "ck_tile/builder/reflect/description.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp" +#endif + namespace ck { namespace tensor_operation { namespace device { @@ -1571,6 +1576,25 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 "The argument pointer is not an object of " "DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3::Argument structure!"); } + +#ifdef CK_EXPERIMENTAL_BUILDER + std::string GetInstanceString() const override + { + static_assert(ck_tile::reflect::HasInstanceTraits, + "Specialization of instance_traits not found. 
Please check that a " + "specialization exists in file " + "ck_tile/builder/reflect/" + "instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp " + "for the given template parameters."); + return ck_tile::reflect::instance_string(); + } + + std::unique_ptr describe() const override + { + return std::make_unique(GetInstanceString()); + } +#endif + }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index e975534a064..e56c91c2dff 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -30,6 +30,11 @@ #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/flush_cache.hpp" +#ifdef CK_EXPERIMENTAL_BUILDER +#include "ck_tile/builder/reflect/description.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp" +#endif + namespace ck { namespace tensor_operation { namespace device { @@ -2074,6 +2079,25 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle "The argument pointer is not an object of " "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument structure!"); } + +#ifdef CK_EXPERIMENTAL_BUILDER + std::string GetInstanceString() const override + { + static_assert(ck_tile::reflect::HasInstanceTraits, + "Specialization of instance_traits not found. 
Please check that a " + "specialization exists in file " + "ck_tile/builder/reflect/" + "instance_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp " + "for the given template parameters."); + return ck_tile::reflect::instance_string(); + } + + std::unique_ptr describe() const override + { + return std::make_unique(GetInstanceString()); + } +#endif + }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp index c50940da419..a07d05f798b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp @@ -19,6 +19,11 @@ #include "ck/host_utility/kernel_launch.hpp" #include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" +#ifdef CK_EXPERIMENTAL_BUILDER +#include "ck_tile/builder/reflect/description.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp" +#endif + namespace ck { namespace tensor_operation { namespace device { @@ -865,6 +870,25 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle return str.str(); } + +#ifdef CK_EXPERIMENTAL_BUILDER + std::string GetInstanceString() const override + { + static_assert(ck_tile::reflect::HasInstanceTraits, + "Specialization of instance_traits not found. 
Please check that a " + "specialization exists in file " + "ck_tile/builder/reflect/" + "instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp " + "for the given template parameters."); + return ck_tile::reflect::instance_string(); + } + + std::unique_ptr describe() const override + { + return std::make_unique(GetInstanceString()); + } +#endif + }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index 1ab6bc446f8..d83174ad65b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -31,6 +31,11 @@ #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/flush_cache.hpp" +#ifdef CK_EXPERIMENTAL_BUILDER +#include "ck_tile/builder/reflect/description.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp" +#endif + namespace ck { namespace tensor_operation { namespace device { @@ -1422,6 +1427,25 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 "The argument pointer is not an object of " "DeviceGroupedConvBwdWeight_Wmma_CShuffleV3::Argument structure!"); } + +#ifdef CK_EXPERIMENTAL_BUILDER + std::string GetInstanceString() const override + { + static_assert(ck_tile::reflect::HasInstanceTraits, + "Specialization of instance_traits not found. 
Please check that a " + "specialization exists in file " + "ck_tile/builder/reflect/" + "instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp " + "for the given template parameters."); + return ck_tile::reflect::instance_string(); + } + + std::unique_ptr describe() const override + { + return std::make_unique(GetInstanceString()); + } +#endif + }; } // namespace device From adfab9db7e648fd45ceed49ae9b59a8947cfc03f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Tue, 30 Dec 2025 05:52:15 -0500 Subject: [PATCH 31/81] Add unit tests for instance strings. --- ...instance_string_bwd_weight_grp_conv_dl.cpp | 79 ++++++++++++++++ ...bwd_weight_grp_conv_multiple_d_wmma_v3.cpp | 85 +++++++++++++++++ ...ing_bwd_weight_grp_conv_multiple_d_xdl.cpp | 83 +++++++++++++++++ ..._bwd_weight_grp_conv_two_stage_wmma_v3.cpp | 88 ++++++++++++++++++ ...ring_bwd_weight_grp_conv_two_stage_xdl.cpp | 88 ++++++++++++++++++ ...nce_string_bwd_weight_grp_conv_wmma_v3.cpp | 91 +++++++++++++++++++ ...ance_string_bwd_weight_grp_conv_xdl_v3.cpp | 85 +++++++++++++++++ 7 files changed, 599 insertions(+) create mode 100644 experimental/builder/test/test_instance_string_bwd_weight_grp_conv_dl.cpp create mode 100644 experimental/builder/test/test_instance_string_bwd_weight_grp_conv_multiple_d_wmma_v3.cpp create mode 100644 experimental/builder/test/test_instance_string_bwd_weight_grp_conv_multiple_d_xdl.cpp create mode 100644 experimental/builder/test/test_instance_string_bwd_weight_grp_conv_two_stage_wmma_v3.cpp create mode 100644 experimental/builder/test/test_instance_string_bwd_weight_grp_conv_two_stage_xdl.cpp create mode 100644 experimental/builder/test/test_instance_string_bwd_weight_grp_conv_wmma_v3.cpp create mode 100644 experimental/builder/test/test_instance_string_bwd_weight_grp_conv_xdl_v3.cpp diff --git a/experimental/builder/test/test_instance_string_bwd_weight_grp_conv_dl.cpp b/experimental/builder/test/test_instance_string_bwd_weight_grp_conv_dl.cpp 
new file mode 100644 index 00000000000..9758e39901e --- /dev/null +++ b/experimental/builder/test/test_instance_string_bwd_weight_grp_conv_dl.cpp @@ -0,0 +1,79 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/conv_description.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp" + +namespace { + +namespace ckr = ck_tile::reflect; + +// Use the first instance from device_grouped_conv_bwd_weight_dl_f16_instances +using InstanceTuple = ck::tensor_operation::device::instance:: + device_grouped_conv_bwd_weight_dl_f16_instances< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout (InLayout) + ck::tensor_layout::convolution::GKYXC, // BLayout (WeiLayout) + ck::tensor_layout::convolution::GNHWK, // ELayout (OutLayout) + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default>; + +using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type; + +// Expected string based on the generic instance +std::string expected_str = "DeviceGroupedConvBwdWeight_Dl" + "<2" // NDimSpatial + ",GNHWC" // InLayout + ",GKYXC" // WeiLayout + ",GNHWK" // OutLayout + ",fp16" // InDataType + ",fp16" // WeiDataType + ",fp16" // OutDataType + ",fp32" // AccDataType + ",PassThrough" // InElementwiseOperation + ",PassThrough" // WeiElementwiseOperation + ",PassThrough" // OutElementwiseOperation + ",Default" // ConvBackwardWeightSpecialization + ",256" // BlockSize + ",128" // MPerBlock + ",128" // NPerBlock + ",16" // K0PerBlock + ",1" // K1 + ",4" // M1PerThread + ",4" // N1PerThread + ",1" // KPerThread + ",Seq(8,2)" // M1N1ThreadClusterM1Xs + ",Seq(8,2)" // M1N1ThreadClusterN1Xs + ",Seq(1,8,1,1,1)" // ABlockTransferThreadSliceLengths_K0_M0_M1_K1 + 
",Seq(1,2,1,128,1)" // ABlockTransferThreadClusterLengths_K0_M0_M1_K1 + ",Seq(0,2,3,1,4)" // ABlockTransferThreadClusterArrangeOrder + ",Seq(0,2,3,1,4)" // ABlockTransferSrcAccessOrder + ",Seq(1,1,1,1,1)" // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 + ",Seq(0,2,3,1,4)" // ABlockTransferSrcVectorTensorContiguousDimOrder + ",Seq(1,1,1,1,1)" // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 + ",Seq(1,1,1,8,1)" // BBlockTransferThreadSliceLengths_K0_N0_N1_K1 + ",Seq(1,16,1,16,1)" // BBlockTransferThreadClusterLengths_K0_N0_N1_K1 + ",Seq(0,1,4,2,3)" // BBlockTransferThreadClusterArrangeOrder + ",Seq(0,1,4,2,3)" // BBlockTransferSrcAccessOrder + ",Seq(1,1,1,1,1)" // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 + ",Seq(0,1,4,2,3)" // BBlockTransferSrcVectorTensorContiguousDimOrder + ",Seq(1,1,1,1,1)" // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 + ",Seq(0,1,2,3,4,5)" // CThreadTransferSrcDstAccessOrder + ",5" // CThreadTransferSrcDstVectorDim + ",1" // CThreadTransferDstScalarPerVector + ">"; + +TEST(InstanceString, DescribeReturnsCorrectValueForBwdWeightGrpConvDl) +{ + using BaseClass = ck::tensor_operation::device::BaseOperator; + DeviceInstance device_instance; + BaseClass* base_ptr = &device_instance; + + auto desc = base_ptr->describe(); + ASSERT_NE(desc, nullptr); + EXPECT_EQ(desc->instance_string(), expected_str); +} + +} // namespace diff --git a/experimental/builder/test/test_instance_string_bwd_weight_grp_conv_multiple_d_wmma_v3.cpp b/experimental/builder/test/test_instance_string_bwd_weight_grp_conv_multiple_d_wmma_v3.cpp new file mode 100644 index 00000000000..d610ea666da --- /dev/null +++ b/experimental/builder/test/test_instance_string_bwd_weight_grp_conv_multiple_d_wmma_v3.cpp @@ -0,0 +1,85 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/conv_description.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_scale_instance.hpp" + +namespace { + +namespace ckr = ck_tile::reflect; + +// Use the first instance from device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_scale_instances +using InstanceTuple = ck::tensor_operation::device::instance:: + device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_scale_instances< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout (InLayout) + ck::tensor_layout::convolution::GKYXC, // BLayout (WeiLayout) + ck::tensor_layout::convolution::GNHWK, // ELayout (OutLayout) + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default>; + +using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type; + +// Expected string based on the generic instance +std::string expected_str = "DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3" + "<2" // NDimSpatial + ",GNHWC" // InLayout + ",GKYXC" // WeiLayout + ",GNHWK" // OutLayout + ",EmptyTuple" // DsLayout + ",fp16" // InDataType + ",fp16" // WeiDataType + ",fp16" // OutDataType + ",fp32" // AccDataType + ",EmptyTuple" // DsDataType + ",PassThrough" // InElementwiseOperation + ",Scale" // WeiElementwiseOperation + ",PassThrough" // OutElementwiseOperation + ",Default" // ConvBackwardWeightSpecialization + ",64" // BlockSize + ",64" // MPerBlock + ",64" // NPerBlock + ",32" // KPerBlock + ",8" // ABK1 + ",16" // MPerWmma + ",16" // NPerWmma + ",4" // MRepeat + ",2" // NRepeat + ",Seq(4,8,1)" // ABlockTransferThreadClusterLengths_AK0_M_AK1 + ",Seq(2,0,1)" // ABlockTransferThreadClusterArrangeOrder + ",Seq(1,0,2)" // ABlockTransferSrcAccessOrder + ",1" // ABlockTransferSrcVectorDim + ",2" // ABlockTransferSrcScalarPerVector + ",4" 
// ABlockTransferDstScalarPerVector_AK1 + ",true" // ABlockLdsAddExtraM + ",Seq(4,8,1)" // BBlockTransferThreadClusterLengths_BK0_N_BK1 + ",Seq(2,0,1)" // BBlockTransferThreadClusterArrangeOrder + ",Seq(1,0,2)" // BBlockTransferSrcAccessOrder + ",1" // BBlockTransferSrcVectorDim + ",2" // BBlockTransferSrcScalarPerVector + ",4" // BBlockTransferDstScalarPerVector_BK1 + ",true" // BBlockLdsAddExtraN + ",1" // CShuffleMRepeatPerShuffle + ",1" // CShuffleNRepeatPerShuffle + ",Seq(1,16,1,4)" // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + ",2" // CShuffleBlockTransferScalarPerVector_NPerBlock + ",Intrawave" // BlkGemmPipeSched + ",v1" // BlkGemmPipelineVer + ",fp16" // ComputeTypeA + ",fp16" // ComputeTypeB + ">"; + +TEST(InstanceString, DescribeReturnsCorrectValueForBwdWeightGrpConvMultipleDWmmaV3) +{ + using BaseClass = ck::tensor_operation::device::BaseOperator; + DeviceInstance device_instance; + BaseClass* base_ptr = &device_instance; + + auto desc = base_ptr->describe(); + ASSERT_NE(desc, nullptr); + EXPECT_EQ(desc->instance_string(), expected_str); +} + +} // namespace diff --git a/experimental/builder/test/test_instance_string_bwd_weight_grp_conv_multiple_d_xdl.cpp b/experimental/builder/test/test_instance_string_bwd_weight_grp_conv_multiple_d_xdl.cpp new file mode 100644 index 00000000000..d9a6d2fbdf8 --- /dev/null +++ b/experimental/builder/test/test_instance_string_bwd_weight_grp_conv_multiple_d_xdl.cpp @@ -0,0 +1,83 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/conv_description.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_scale_instance.hpp" + +namespace { + +namespace ckr = ck_tile::reflect; + +// Use the first instance from device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_scale_instances +using InstanceTuple = ck::tensor_operation::device::instance:: + device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_scale_instances< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout (InLayout) + ck::tensor_layout::convolution::GKYXC, // BLayout (WeiLayout) + ck::tensor_layout::convolution::GNHWK, // ELayout (OutLayout) + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default>; + +using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type; + +// Expected string based on the generic instance +std::string expected_str = "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle" + "<2" // NDimSpatial + ",GNHWC" // InLayout + ",GKYXC" // WeiLayout + ",GNHWK" // OutLayout + ",EmptyTuple" // DsLayout + ",fp16" // InDataType + ",fp16" // WeiDataType + ",fp16" // OutDataType + ",fp32" // AccDataType + ",EmptyTuple" // DsDataType + ",PassThrough" // InElementwiseOperation + ",Scale" // WeiElementwiseOperation + ",PassThrough" // OutElementwiseOperation + ",Default" // ConvBackwardWeightSpecialization + ",64" // BlockSize + ",64" // MPerBlock + ",64" // NPerBlock + ",4" // K0PerBlock + ",8" // K1 + ",32" // MPerXDL + ",32" // NPerXDL + ",2" // MXdlPerWave + ",2" // NXdlPerWave + ",Seq(1,4,8,2)" // ABlockTransferThreadClusterLengths_K0_M_K1 + ",Seq(0,3,1,2)" // ABlockTransferThreadClusterArrangeOrder + ",Seq(0,2,1,3)" // ABlockTransferSrcAccessOrder + ",2" // ABlockTransferSrcVectorDim + ",2" // ABlockTransferSrcScalarPerVector + ",4" 
// ABlockTransferDstScalarPerVector_K1 + ",true" // ABlockLdsAddExtraM + ",Seq(1,4,8,2)" // BBlockTransferThreadClusterLengths_K0_N_K1 + ",Seq(0,3,1,2)" // BBlockTransferThreadClusterArrangeOrder + ",Seq(0,2,1,3)" // BBlockTransferSrcAccessOrder + ",2" // BBlockTransferSrcVectorDim + ",2" // BBlockTransferSrcScalarPerVector + ",4" // BBlockTransferDstScalarPerVector_K1 + ",true" // BBlockLdsAddExtraN + ",1" // CShuffleMXdlPerWavePerShuffle + ",1" // CShuffleNXdlPerWavePerShuffle + ",Seq(1,16,1,4)" // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + ",2" // CBlockTransferScalarPerVector_NWaveNPerXdl + ",fp16" // ComputeTypeA + ",fp16" // ComputeTypeB + ">"; + +TEST(InstanceString, DescribeReturnsCorrectValueForBwdWeightGrpConvMultipleDXdl) +{ + using BaseClass = ck::tensor_operation::device::BaseOperator; + DeviceInstance device_instance; + BaseClass* base_ptr = &device_instance; + + auto desc = base_ptr->describe(); + ASSERT_NE(desc, nullptr); + EXPECT_EQ(desc->instance_string(), expected_str); +} + +} // namespace diff --git a/experimental/builder/test/test_instance_string_bwd_weight_grp_conv_two_stage_wmma_v3.cpp b/experimental/builder/test/test_instance_string_bwd_weight_grp_conv_two_stage_wmma_v3.cpp new file mode 100644 index 00000000000..476ec7bb639 --- /dev/null +++ b/experimental/builder/test/test_instance_string_bwd_weight_grp_conv_two_stage_wmma_v3.cpp @@ -0,0 +1,88 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/conv_description.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp" + +namespace { + +namespace ckr = ck_tile::reflect; + +// Use the first instance from device_grouped_conv_bwd_weight_two_stage_nhwgc_wmma_c_shuffle_f16_instances +using InstanceTuple = ck::tensor_operation::device::instance:: + device_grouped_conv_bwd_weight_two_stage_nhwgc_wmma_c_shuffle_f16_instances< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout (InLayout) + ck::tensor_layout::convolution::GKYXC, // BLayout (WeiLayout) + ck::tensor_layout::convolution::GNHWK, // ELayout (OutLayout) + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1>; + +using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type; + +// Expected string based on the first instance (BlockSize=32, MPerBlock=16, NPerBlock=16, etc.) 
+std::string expected_str = "DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3" + "<2" // NDimSpatial + ",GNHWC" // InLayout + ",GKYXC" // WeiLayout + ",GNHWK" // OutLayout + ",fp16" // InDataType + ",fp16" // WeiDataType + ",fp16" // OutDataType + ",fp32" // AccDataType + ",PassThrough" // InElementwiseOperation + ",PassThrough" // WeiElementwiseOperation + ",PassThrough" // OutElementwiseOperation + ",Default" // ConvBackwardWeightSpecialization + ",32" // BlockSize + ",16" // MPerBlock + ",16" // NPerBlock + ",32" // KPerBlock + ",8" // ABK1 + ",16" // MPerWmma + ",16" // NPerWmma + ",1" // MRepeat + ",1" // NRepeat + ",Seq(4,8,1)" // ABlockTransferThreadClusterLengths_AK0_M_AK1 + ",Seq(2,0,1)" // ABlockTransferThreadClusterArrangeOrder + ",Seq(1,0,2)" // ABlockTransferSrcAccessOrder + ",1" // ABlockTransferSrcVectorDim + ",1" // ABlockTransferSrcScalarPerVector + ",4" // ABlockTransferDstScalarPerVector_AK1 + ",false" // ABlockLdsAddExtraM + ",Seq(4,8,1)" // BBlockTransferThreadClusterLengths_BK0_N_BK1 + ",Seq(2,0,1)" // BBlockTransferThreadClusterArrangeOrder + ",Seq(1,0,2)" // BBlockTransferSrcAccessOrder + ",1" // BBlockTransferSrcVectorDim + ",1" // BBlockTransferSrcScalarPerVector + ",4" // BBlockTransferDstScalarPerVector_BK1 + ",false" // BBlockLdsAddExtraN + ",1" // CShuffleMRepeatPerShuffle + ",1" // CShuffleNRepeatPerShuffle + ",Seq(1,4,1,8)" // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + ",1" // CShuffleBlockTransferScalarPerVector_NPerBlock + ",Intrawave" // BlkGemmPipeSched + ",v1" // BlkGemmPipelineVer + ",1" // NumGroupsToMerge + ",fp16" // ComputeTypeA + ",fp16" // ComputeTypeB + ",1" // TransposeTransferSrcScalarPerVector + ",1" // TransposeTransferDstScalarPerVector + ">"; + +TEST(InstanceString, DescribeReturnsCorrectValueForBwdWeightGrpConvTwoStageWmmaV3) +{ + using BaseClass = ck::tensor_operation::device::BaseOperator; + DeviceInstance device_instance; + BaseClass* base_ptr = &device_instance; + + auto desc = 
base_ptr->describe(); + ASSERT_NE(desc, nullptr); + EXPECT_EQ(desc->instance_string(), expected_str); +} + +} // namespace diff --git a/experimental/builder/test/test_instance_string_bwd_weight_grp_conv_two_stage_xdl.cpp b/experimental/builder/test/test_instance_string_bwd_weight_grp_conv_two_stage_xdl.cpp new file mode 100644 index 00000000000..7393757e4fb --- /dev/null +++ b/experimental/builder/test/test_instance_string_bwd_weight_grp_conv_two_stage_xdl.cpp @@ -0,0 +1,88 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/conv_description.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" + +namespace { + +namespace ckr = ck_tile::reflect; + +// Use the first instance from device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_generic_instances +using InstanceTuple = ck::tensor_operation::device::instance:: + device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_generic_instances< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout (InLayout) + ck::tensor_layout::convolution::GKYXC, // BLayout (WeiLayout) + ck::tensor_layout::convolution::GNHWK, // ELayout (OutLayout) + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1>; + +using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type; + +// Expected string based on the first instance in the tuple +std::string expected_str = "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle" + "<2" // NDimSpatial + ",GNHWC" // InLayout + ",GKYXC" // WeiLayout + ",GNHWK" // OutLayout + ",fp16" // InDataType + ",fp16" // WeiDataType + ",fp16" // OutDataType + ",fp32" // 
AccDataType + ",PassThrough" // InElementwiseOperation + ",PassThrough" // WeiElementwiseOperation + ",PassThrough" // OutElementwiseOperation + ",Default" // ConvBackwardWeightSpecialization + ",64" // BlockSize + ",16" // MPerBlock + ",16" // NPerBlock + ",32" // KPerBlock + ",8" // K1 + ",16" // MPerXDL + ",16" // NPerXDL + ",1" // MXdlPerWave + ",1" // NXdlPerWave + ",Seq(4,8,1)" // ABlockTransferThreadClusterLengths_K0_M_K1 + ",Seq(2,0,1)" // ABlockTransferThreadClusterArrangeOrder + ",Seq(1,0,2)" // ABlockTransferSrcAccessOrder + ",1" // ABlockTransferSrcVectorDim + ",1" // ABlockTransferSrcScalarPerVector + ",4" // ABlockTransferDstScalarPerVector_K1 + ",false" // ABlockLdsAddExtraM + ",Seq(4,8,1)" // BBlockTransferThreadClusterLengths_K0_N_K1 + ",Seq(2,0,1)" // BBlockTransferThreadClusterArrangeOrder + ",Seq(1,0,2)" // BBlockTransferSrcAccessOrder + ",1" // BBlockTransferSrcVectorDim + ",1" // BBlockTransferSrcScalarPerVector + ",4" // BBlockTransferDstScalarPerVector_K1 + ",false" // BBlockLdsAddExtraN + ",1" // CShuffleMXdlPerWavePerShuffle + ",1" // CShuffleNXdlPerWavePerShuffle + ",Seq(1,8,1,8)" // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + ",1" // CBlockTransferScalarPerVector_NWaveNPerXdl + ",Intrawave" // BlkGemmPipeSched + ",v1" // BlkGemmPipelineVer + ",1" // NumGroupsToMerge + ",fp16" // ComputeTypeA + ",fp16" // ComputeTypeB + ",1" // TransposeTransferSrcScalarPerVector + ",1" // TransposeTransferDstScalarPerVector + ">"; + +TEST(InstanceString, DescribeReturnsCorrectValueForBwdWeightGrpConvTwoStageXdl) +{ + using BaseClass = ck::tensor_operation::device::BaseOperator; + DeviceInstance device_instance; + BaseClass* base_ptr = &device_instance; + + auto desc = base_ptr->describe(); + ASSERT_NE(desc, nullptr); + EXPECT_EQ(desc->instance_string(), expected_str); +} + +} // namespace diff --git a/experimental/builder/test/test_instance_string_bwd_weight_grp_conv_wmma_v3.cpp 
b/experimental/builder/test/test_instance_string_bwd_weight_grp_conv_wmma_v3.cpp new file mode 100644 index 00000000000..1b7f599e807 --- /dev/null +++ b/experimental/builder/test/test_instance_string_bwd_weight_grp_conv_wmma_v3.cpp @@ -0,0 +1,91 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +//#ifdef _NOT_DEFINED_ + +#include +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/conv_description.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp" + +namespace { + +namespace ckr = ck_tile::reflect; + +using InstanceTuple = ck::tensor_operation::device::instance:: + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_instances< + 2, // NDimSpatial + ck::tensor_operation::device::instance::NHWGC, // InLayout + ck::tensor_operation::device::instance::GKYXC, // WeiLayout + ck::tensor_operation::device::instance::NHWGK, // OutLayout + ck::tensor_operation::device::instance:: + ConvBwdWeightDefault>; + +// Expected complete instance string +std::string expected_str = "DeviceGroupedConvBwdWeight_Wmma_CShuffleV3" + "<2" // NDimSpatial + ",NHWGC" // InLayout + ",GKYXC" // WeiLayout + ",NHWGK" // OutLayout + ",fp16" // InDataType + ",fp16" // WeiDataType + ",fp16" // OutDataType + ",fp32" // AccDataType + ",PassThrough" // InElementwiseOperation + ",PassThrough" // WeiElementwiseOperation + ",PassThrough" // OutElementwiseOperation + ",Default" // ConvBackwardWeightSpecialization + ",64" // BlockSize + ",32" // MPerBlock + ",32" // NPerBlock + ",32" // KPerBlock + ",8" // ABK1 + ",16" // MPerWmma + ",16" // NPerWmma + ",2" // MRepeat + ",1" // NRepeat + ",Seq(4,8,1)" // ABlockTransferThreadClusterLengths_AK0_M_AK1 + ",Seq(2,0,1)" // ABlockTransferThreadClusterArrangeOrder + ",Seq(1,0,2)" // ABlockTransferSrcAccessOrder + ",1" // 
ABlockTransferSrcVectorDim + ",2" // ABlockTransferSrcScalarPerVector + ",2" // ABlockTransferDstScalarPerVector_AK1 + ",false" // ABlockLdsAddExtraM + ",Seq(4,16,1)" // BBlockTransferThreadClusterLengths_BK0_N_BK1 + ",Seq(2,0,1)" // BBlockTransferThreadClusterArrangeOrder + ",Seq(1,0,2)" // BBlockTransferSrcAccessOrder + ",1" // BBlockTransferSrcVectorDim + ",2" // BBlockTransferSrcScalarPerVector + ",2" // BBlockTransferDstScalarPerVector_BK1 + ",false" // BBlockLdsAddExtraN + ",1" // CShuffleMRepeatPerShuffle + ",1" // CShuffleNRepeatPerShuffle + ",Seq(1,8,1,8)" // CShuffleBlockTransferClusterLengths + ",2" // CShuffleBlockTransferScalarPerVector_NPerBlock + ",Intrawave" // BlkGemmPipeSched + ",v1" // BlkGemmPipelineVer + ",fp16" // ComputeTypeA + ",fp16" // ComputeTypeB + ",1" // MaxTransposeTransferSrcScalarPerVector + ",1" // MaxTransposeTransferDstScalarPerVector + ">"; + +// Get the first instance from the tuple +using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type; + +// Test describe() through base class pointer for WMMA V3 variant +TEST(InstanceString, DescribeReturnsCorrectValueForBwdWeightGrpConvWmmaV3) +{ + using BaseClass = ck::tensor_operation::device::BaseOperator; + DeviceInstance device_instance; + BaseClass* base_ptr = &device_instance; + + auto desc = base_ptr->describe(); + ASSERT_NE(desc, nullptr); + EXPECT_EQ(desc->instance_string(), expected_str); +} + +} // namespace + +//#endif diff --git a/experimental/builder/test/test_instance_string_bwd_weight_grp_conv_xdl_v3.cpp b/experimental/builder/test/test_instance_string_bwd_weight_grp_conv_xdl_v3.cpp new file mode 100644 index 00000000000..d3b6acbd140 --- /dev/null +++ b/experimental/builder/test/test_instance_string_bwd_weight_grp_conv_xdl_v3.cpp @@ -0,0 +1,85 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/conv_description.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace { + +namespace ckr = ck_tile::reflect; + +using InstanceTuple = ck::tensor_operation::device::instance:: + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_instances< + 2, // NDimSpatial + ck::tensor_operation::device::instance::GNHWC, // InLayout + ck::tensor_operation::device::instance::GKYXC, // WeiLayout + ck::tensor_operation::device::instance::GNHWK, // OutLayout + ck::tensor_operation::device::instance::ConvBwdWeightDefault, // ConvBwdWeightSpecialization + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1>; // BlkGemmPipelineVer + +using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type; + +// Expected complete instance string based on the generic instance +std::string expected_str = "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3" + "<2" // NDimSpatial + ",GNHWC" // InLayout + ",GKYXC" // WeiLayout + ",GNHWK" // OutLayout + ",fp16" // InDataType + ",fp16" // WeiDataType + ",fp16" // OutDataType + ",fp32" // AccDataType + ",PassThrough" // InElementwiseOperation + ",PassThrough" // WeiElementwiseOperation + ",PassThrough" // OutElementwiseOperation + ",Default" // ConvBackwardWeightSpecialization + ",64" // BlockSize + ",32" // MPerBlock + ",32" // NPerBlock + ",32" // K0PerBlock + ",8" // K1 + ",32" // MPerXDL + ",32" // NPerXDL + ",1" // MXdlPerWave + ",1" // NXdlPerWave + ",Seq(4,8,1)" // ABlockTransferThreadClusterLengths_K0_M_K1 + ",Seq(2,0,1)" // ABlockTransferThreadClusterArrangeOrder + ",Seq(1,0,2)" // ABlockTransferSrcAccessOrder + ",1" // ABlockTransferSrcVectorDim + ",2" // ABlockTransferSrcScalarPerVector + ",2" // 
ABlockTransferDstScalarPerVector_K1 + ",false" // ABlockLdsAddExtraM + ",Seq(4,16,1)" // BBlockTransferThreadClusterLengths_K0_N_K1 + ",Seq(2,0,1)" // BBlockTransferThreadClusterArrangeOrder + ",Seq(1,0,2)" // BBlockTransferSrcAccessOrder + ",1" // BBlockTransferSrcVectorDim + ",2" // BBlockTransferSrcScalarPerVector + ",2" // BBlockTransferDstScalarPerVector_K1 + ",false" // BBlockLdsAddExtraN + ",1" // CShuffleMXdlPerWavePerShuffle + ",1" // CShuffleNXdlPerWavePerShuffle + ",Seq(1,8,1,8)" // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + ",2" // CBlockTransferScalarPerVector_NWaveNPerXdl + ",Intrawave" // BlkGemmPipeSched + ",v1" // BlkGemmPipelineVer + ",fp16" // ComputeTypeA + ",fp16" // ComputeTypeB + ">"; + +// Test describe() through base class pointer for XDL V3 variant +TEST(InstanceString, DescribeReturnsCorrectValueForBwdWeightGrpConvXdlV3) +{ + using BaseClass = ck::tensor_operation::device::BaseOperator; + DeviceInstance device_instance; + BaseClass* base_ptr = &device_instance; + + auto desc = base_ptr->describe(); + ASSERT_NE(desc, nullptr); + EXPECT_EQ(desc->instance_string(), expected_str); +} + +} // namespace From 30c10e2544af167664a777e9d463d8bb25848c5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Tue, 30 Dec 2025 05:52:55 -0500 Subject: [PATCH 32/81] Build new instance traits unit tests but exclude WMMA for now. --- experimental/builder/test/CMakeLists.txt | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 8105a41bf57..6a54edab9ae 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -110,7 +110,7 @@ add_ck_builder_test(test_ckb_conv_description # Verifies that GetInstanceString() methods and other functions produce valid kernel code. 
# Tests various convolution types: # - Group convolution (v3, standard, large tensor, WMMA, DL variants) -# - Backward weight group convolution (XDL) +# - Backward weight group convolution (XDL standard, XDL v3, WMMA, DL, multiple D, two-stage variants) # Requires kernel compilation to validate the generated strings through the base class. add_ck_builder_test(test_ckb_instance_string test_instance_string_fwd_grp_conv_v3.cpp @@ -118,7 +118,16 @@ add_ck_builder_test(test_ckb_instance_string test_instance_string_fwd_grp_conv_large_tensor.cpp test_instance_string_fwd_grp_conv_wmma.cpp test_instance_string_fwd_grp_conv_dl.cpp - test_instance_string_bwd_weight_grp_conv_xdl.cpp) + test_instance_string_bwd_weight_grp_conv_xdl.cpp + test_instance_string_bwd_weight_grp_conv_dl.cpp + test_instance_string_bwd_weight_grp_conv_multiple_d_xdl.cpp + test_instance_string_bwd_weight_grp_conv_two_stage_xdl.cpp + test_instance_string_bwd_weight_grp_conv_xdl_v3.cpp + # WMMA variants do not yet compile for backward weight + # test_instance_string_bwd_weight_grp_conv_wmma_v3.cpp + # test_instance_string_bwd_weight_grp_conv_multiple_d_wmma_v3.cpp + # test_instance_string_bwd_weight_grp_conv_two_stage_wmma_v3.cpp + ) # Tests the forward convolution builder across multiple data types and dimensions. # Individual tests are split into separate files to enable parallel compilation. From 75710202ab3466a6eb8d63e2c6903e35ecce484f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Wed, 31 Dec 2025 04:32:28 -0500 Subject: [PATCH 33/81] Added factory for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle. 
--- .../builder/conv_algorithm_concepts.hpp | 17 +++ .../builder/conv_algorithm_diagnostics.hpp | 30 +++++ .../builder/factory/conv_algorithms.hpp | 71 +++++++++++- .../conv_bwd_weight_two_stage_xdl_factory.hpp | 106 ++++++++++++++++++ .../builder/factory/conv_dispatcher.hpp | 24 +--- .../builder/include/ck_tile/builder/types.hpp | 3 +- experimental/builder/test/CMakeLists.txt | 1 + ...conv_bwd_weight_two_stage_xdl_cshuffle.cpp | 44 ++++++++ .../test_ckb_conv_bwd_weight_xdl_cshuffle.cpp | 6 +- .../test/impl/conv_algorithm_types.hpp | 25 ++++- .../test/utils/conv_algorithm_type_utils.hpp | 10 ++ 11 files changed, 310 insertions(+), 27 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp create mode 100644 experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 447bbdad5e3..d554f92422d 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -321,12 +321,29 @@ concept SpecifiesLargeTensorSupport = requires { requires T::specialization == ConvAlgorithmSpecialization::LARGE_TENSOR; }; +template +concept SpecifiesTwoStageSupport = requires { + { T::specialization } -> std::convertible_to; + requires T::specialization == ConvAlgorithmSpecialization::TWO_STAGE; +}; + +template +concept SpecifiesGenericInstance = !requires { + { T::specialization }; +}; + template concept SpecifiesTransposeTransfer = requires { { T::max_transpose_transfer_src_scalar_per_vector } -> SizeType; { T::max_transpose_transfer_dst_scalar_per_vector } -> SizeType; }; + +template +concept SpecifiesGemmBatchOptions = requires { + { T::num_conv_groups_to_merge } -> SizeType; +}; + /******************************************** 
*/ /* DL-specific descriptors and requirements */ /******************************************** */ diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp index 6613d2d7367..74973552247 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp @@ -712,6 +712,36 @@ consteval auto detailed_diagnostic_SpecifiesLargeTensorSupport() -> std::string return msg; } +template +consteval auto detailed_diagnostic_SpecifiesTwoStageSupport() -> std::string { + std::string msg; + if constexpr (requires { T::specialization; }) { + using SpecType = decltype(T::specialization); + constexpr bool convertible = std::convertible_to; + msg += " → T::specialization: " + std::string(CHECK_MARK(convertible)) + + (convertible ? "" : std::string(detail::get_type_info())) + "\n"; + + if constexpr (convertible) { + constexpr bool is_two_stage = (T::specialization == ConvAlgorithmSpecialization::TWO_STAGE); + msg += " → specialization == TWO_STAGE: " + std::string(CHECK_MARK(is_two_stage)) + "\n"; + } + } else { + msg += " → T::specialization: [✗] (missing member)\n"; + } + + return msg; +} + +template +consteval auto detailed_diagnostic_SpecifiesGenericInstance() -> std::string { + std::string msg; + if constexpr (requires { T::specialization; }) { + msg += " → T::specialization: [✗] (member should NOT exist for generic instance)\n"; + msg += " → This concept requires the absence of the specialization member\n"; + } + return msg; +} + template consteval auto detailed_diagnostic_SpecifiesTransposeTransfer() -> std::string { std::string msg; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index bf7f0248fda..312e746f8a7 100644 --- 
a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -290,6 +290,7 @@ struct BwdXdlV3Algorithm { CHECK_CONCEPT(T, SpecifiesGridwiseBwdXdlGemm) CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization) CHECK_CONCEPT(T, SpecifiesBlockGemm) + CHECK_CONCEPT(T, SpecifiesGenericInstance) static constexpr bool c1 = c_ConvAlgorithmDescriptor; static constexpr bool c2 = c_SpecifiesThreadBlock; @@ -300,9 +301,10 @@ struct BwdXdlV3Algorithm { static constexpr bool c7 = c_SpecifiesGridwiseBwdXdlGemm; static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; static constexpr bool c9 = c_SpecifiesBlockGemm; + static constexpr bool c10 = c_SpecifiesGenericInstance; static consteval bool is_valid() { - return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9; + return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10; } static consteval auto message() -> std::string { @@ -316,7 +318,58 @@ struct BwdXdlV3Algorithm { DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + DIAGNOSTIC_LINE(SpecifiesGridwiseBwdXdlGemm) + DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) + - DIAGNOSTIC_LINE(SpecifiesBlockGemm); + DIAGNOSTIC_LINE(SpecifiesBlockGemm) + + DIAGNOSTIC_LINE(SpecifiesGenericInstance); + } +}; + +template +struct BwdTwoStageXdlAlgorithm { + CHECK_CONCEPT(T, ConvAlgorithmDescriptor) + CHECK_CONCEPT(T, SpecifiesThreadBlock) + CHECK_CONCEPT(T, SpecifiesBlockTransfer) + CHECK_CONCEPT(T, SpecifiesLdsTransfer) + CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder) + CHECK_CONCEPT(T, SpecifiesSourceAccessOrder) + CHECK_CONCEPT(T, SpecifiesGridwiseBwdXdlGemm) + CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization) + CHECK_CONCEPT(T, SpecifiesBlockGemm) + CHECK_CONCEPT(T, SpecifiesTransposeTransfer) + CHECK_CONCEPT(T, SpecifiesGemmBatchOptions) + CHECK_CONCEPT(T, SpecifiesTwoStageSupport) + + static constexpr bool c1 = c_ConvAlgorithmDescriptor; + static constexpr bool c2 = 
c_SpecifiesThreadBlock; + static constexpr bool c3 = c_SpecifiesBlockTransfer; + static constexpr bool c4 = c_SpecifiesLdsTransfer; + static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder; + static constexpr bool c6 = c_SpecifiesSourceAccessOrder; + static constexpr bool c7 = c_SpecifiesGridwiseBwdXdlGemm; + static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; + static constexpr bool c9 = c_SpecifiesBlockGemm; + static constexpr bool c10 = c_SpecifiesTransposeTransfer; + static constexpr bool c11 = c_SpecifiesGemmBatchOptions; + static constexpr bool c12 = c_SpecifiesTwoStageSupport; + + static consteval bool is_valid() { + return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10 && c11 && c12; + } + + static consteval auto message() -> std::string { + return std::string("\n=== Backward two stage XDL Algorithm Diagnostic (closest match) ===\n" + "Concepts for BwdTwoStageXdl Algorithm:\n") + + DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + + DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + + DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + + DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + + DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + + DIAGNOSTIC_LINE(SpecifiesGridwiseBwdXdlGemm) + + DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) + + DIAGNOSTIC_LINE(SpecifiesBlockGemm) + + DIAGNOSTIC_LINE(SpecifiesTransposeTransfer) + + DIAGNOSTIC_LINE(SpecifiesGemmBatchOptions) + + DIAGNOSTIC_LINE(SpecifiesTwoStageSupport); } }; @@ -356,6 +409,12 @@ consteval int count_matches_bwd_xdl_v3() { return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9; } +template +consteval int count_matches_bwd_two_stage_xdl() { + using Alg = BwdTwoStageXdlAlgorithm; + return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11 + Alg::c12; +} + template consteval int count_matches_large_tensor() { using Alg = LargeTensorAlgorithm; @@ -417,8 
+476,10 @@ consteval void diagnose_bwd_weight_algorithm_signature() { constexpr int xdl_matches = count_matches_bwd_xdl(); constexpr int xdl_v3_matches = count_matches_fwd_xdl_v3(); + constexpr int two_stage_xdl_matches = count_matches_bwd_two_stage_xdl(); - constexpr int max_matches = xdl_v3_matches > xdl_matches ? xdl_v3_matches : xdl_matches; + constexpr int max1 = xdl_v3_matches > xdl_matches ? xdl_v3_matches : xdl_matches; + constexpr int max_matches = max1 > two_stage_xdl_matches ? max1 : two_stage_xdl_matches; if constexpr (max_matches == xdl_matches) { using Alg = BwdXdlAlgorithm; @@ -428,6 +489,10 @@ consteval void diagnose_bwd_weight_algorithm_signature() using Alg = BwdXdlV3Algorithm; static_assert(Alg::is_valid(), Alg::message()); } + else if constexpr (max_matches == two_stage_xdl_matches) { + using Alg = BwdTwoStageXdlAlgorithm; + static_assert(Alg::is_valid(), Alg::message()); + } else { // This should never happen static_assert(false, "Internal Error: No matching algorithm variant found for diagnostics."); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp new file mode 100644 index 00000000000..b9852127e86 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp @@ -0,0 +1,106 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle instance +// of a grouped backward-weight convolution kernel. +template + requires ConvDirectionIsBackwardWeight +struct ConvBwdWeightTwoStageXdlFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::BwdWeightConvTensorDataTypes; + using Ops = internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto A_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + static constexpr auto BLOCK_GEMM = internal::SetBlockGemm(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. 
+ static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); + static_assert(InputVectorTransferLimits, "Invalid B block transfer config"); + static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); + static_assert(AccessOrderLimits3D, "Invalid A thread cluster access order"); + static_assert(AccessOrderLimits3D, "Invalid B thread cluster access order"); + static_assert(AccessOrderLimits3D, "Invalid A source access order"); + static_assert(AccessOrderLimits3D, "Invalid B source access order"); + + // The forward convolution kernel class instance. + using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + XDL_PARAMS.m_per_xdl, + XDL_PARAMS.n_per_xdl, + XDL_PARAMS.m_xdl_per_wave, + XDL_PARAMS.n_xdl_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + BLOCK_GEMM.scheduler, + BLOCK_GEMM.pipeline_version, + ALGORITHM.num_conv_groups_to_merge, + typename Types::InComputeType, + typename Types::WeiComputeType, + 
ALGORITHM.max_transpose_transfer_src_scalar_per_vector, + ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 01c0fb9c56b..1812f1a0ff6 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -49,12 +49,6 @@ #pragma once -// Disable pragma message warnings for factory selection diagnostics -#ifdef __clang__ - #pragma clang diagnostic push - #pragma clang diagnostic ignored "-W#pragma-messages" -#endif - #include "ck_tile/builder/conv_signature_concepts.hpp" #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/types.hpp" @@ -71,6 +65,7 @@ #include "ck_tile/builder/factory/conv_tile_factory.hpp" #include "ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp" #include "ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp" namespace ck_tile::builder::factory { @@ -103,34 +98,28 @@ constexpr auto make_conv_instance() // CK Tile supports common factory for each direction if constexpr(TileAlgorithm::is_valid()) { - #pragma message("[CK Builder] Using ConvTileFactory...") return typename ConvTileFactory::Instance{}; } else if constexpr(ConvDirectionIsForward) { if constexpr(FwdXdlV3Algorithm::is_valid()) { - #pragma message("[CK Builder] Using ConvFwdXdlV3Factory...") return typename ConvFwdXdlV3Factory::Instance{}; } else if constexpr(FwdXdlAlgorithm::is_valid()) { - #pragma message("[CK Builder] Using ConvFwdXdlFactory...") return typename ConvFwdXdlFactory::Instance{}; } else if constexpr(FwdWmmaAlgorithm::is_valid()) { - #pragma message("[CK Builder] Using ConvFwdWmmaFactory...") return typename ConvFwdWmmaFactory::Instance{}; } 
else if constexpr(FwdDlAlgorithm::is_valid()) { - #pragma message("[CK Builder] Using ConvFwdDlFactory...") return typename ConvFwdDlFactory::Instance{}; } else if constexpr(LargeTensorAlgorithm::is_valid()) { - #pragma message("[CK Builder] Using ConvFwdLargeTensorFactory...") return typename ConvFwdLargeTensorFactory::Instance{}; } else @@ -148,14 +137,16 @@ constexpr auto make_conv_instance() { if constexpr (BwdXdlAlgorithm::is_valid()) { - #pragma message("[CK Builder] Using ConvBwdWeightXdlFactory...") return typename ConvBwdWeightXdlFactory::Instance{}; } else if constexpr (BwdXdlV3Algorithm::is_valid()) { - #pragma message("[CK Builder] Using ConvBwdWeightXdlV3Factory...") return typename ConvBwdWeightXdlV3Factory::Instance{}; } + else if constexpr (BwdTwoStageXdlAlgorithm::is_valid()) + { + return typename ConvBwdWeightTwoStageXdlFactory::Instance{}; + } else { diagnose_bwd_weight_algorithm_signature(); @@ -171,8 +162,3 @@ constexpr auto make_conv_instance() } } // namespace ck_tile::builder::factory - -// Re-enable pragma message warnings -#ifdef __clang__ - #pragma clang diagnostic pop -#endif diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index ade9484640f..c44f3368aea 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -232,7 +232,8 @@ enum class PipelineScheduler enum class ConvAlgorithmSpecialization { - LARGE_TENSOR + LARGE_TENSOR, + TWO_STAGE }; // toString methods for enum classes diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 6a54edab9ae..a3d08e82edb 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -151,6 +151,7 @@ target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility) add_ck_builder_test(test_ckb_build_bwd_weight_instances 
conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp + conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp ) diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp new file mode 100644 index 00000000000..9a8b9573fa6 --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; + +constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::BF16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::GNHWC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle{} + .with_thread_block(cku::ThreadBlock_64_32x32x32) + .with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave) + .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) + .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) + .with_block_gemm(cku::BlockGemmDesc_v2_intrawave) + .with_num_conv_groups_to_merge(2) + .with_transpose_params(2, 4); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdWeight_2DBf16_TwoStage_CShuffle, Create) +{ + const auto expected_transfer_parameters = 
to_string(ALGORITHM); + cku::run_test({"DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle", + expected_transfer_parameters, + "Default", + "GNHWC,GKYXC,GNHWK", + "PassThrough,PassThrough,PassThrough", + "Intrawave,v2", // pipeline versions + "bf16,bf16,2,4>"}); // compute types and transpose params +} diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp index ad11eba6938..892f1d35ef1 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp @@ -23,7 +23,8 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CSh .with_thread_block(cku::ThreadBlock_256_128x128x8) .with_gemm_config(cku::BwdGemmParams_Xdl_4x4_per_wave) .with_transfer(cku::BwdTransfer_4x64x1) - .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT); + .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) + .with_transpose_params(2, 2); using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; @@ -35,5 +36,6 @@ TEST(BwdWeight_2DFp16_CShuffle_GNHWC, Create) expected_transfer_parameters, "Default", "GNHWC,GKYXC,GNHWK", - "PassThrough,PassThrough,PassThrough"}); + "PassThrough,PassThrough,PassThrough", + "fp16,fp16,2,2>"}); // check compute types and transpose params } diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index b045d185e2a..d003440935f 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -243,6 +243,11 @@ struct TransposeParams_ size_t max_transpose_transfer_dst_scalar_per_vector{1}; }; +struct GemmBatchOptions_ +{ + size_t num_conv_groups_to_merge{1}; +}; + struct BlockGemm_ { BlockGemm block_gemm; @@ -280,6 +285,11 @@ struct DlTransfer_ DlTransferABC transfer; }; 
+struct TwoStageSpecialization_ +{ + static constexpr ConvAlgorithmSpecialization specialization = ConvAlgorithmSpecialization::TWO_STAGE; +}; + // Specialization wrapper for large tensor support template struct LargeTensorWrapper @@ -433,8 +443,8 @@ struct ConvAlgorithmTemplate : Components... return result; } - constexpr auto with_transpose_params(bool max_src_scalar_per_vector, - bool max_dst_scalar_per_vector) const + constexpr auto with_transpose_params(size_t max_src_scalar_per_vector, + size_t max_dst_scalar_per_vector) const { static_assert(std::is_base_of_v); auto result = *this; @@ -443,6 +453,14 @@ struct ConvAlgorithmTemplate : Components... return result; } + constexpr auto with_num_conv_groups_to_merge(size_t num_groups_to_merge) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.num_conv_groups_to_merge = num_groups_to_merge; + return result; + } + template constexpr auto with_block_gemm(const BG& bg) const { @@ -555,6 +573,9 @@ using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, TransposeParams_>; +using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle = + ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, BlockGemm_, TransposeParams_, GemmBatchOptions_, TwoStageSpecialization_>; + using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 = ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, BlockGemm_>; diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index cf13f39391a..8f530600acc 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -397,4 +397,14 @@ inline std::string to_string +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << 
to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + return oss.str(); +} + } // namespace ck_tile::builder::test From 3b0777f629a8e3a87b816fc324a1b75aae57a7aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Wed, 31 Dec 2025 07:38:52 -0500 Subject: [PATCH 34/81] Conv bwd weight DL factory. --- .../builder/conv_algorithm_concepts.hpp | 36 +++-- .../builder/conv_algorithm_diagnostics.hpp | 116 ++++++++++++---- .../builder/factory/conv_algorithms.hpp | 55 +++++++- .../factory/conv_bwd_weight_dl_factory.hpp | 131 ++++++++++++++++++ .../builder/factory/conv_dispatcher.hpp | 5 + .../builder/factory/conv_fwd_dl_factory.hpp | 6 +- experimental/builder/test/CMakeLists.txt | 1 + .../conv/ck/test_ckb_conv_bwd_weight_dl.cpp | 39 ++++++ .../test/impl/conv_algorithm_types.hpp | 29 ++-- .../test/utils/ckb_conv_test_configs.hpp | 22 ++- .../test/utils/conv_algorithm_type_utils.hpp | 26 ++-- 11 files changed, 375 insertions(+), 91 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp create mode 100644 experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_dl.cpp diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index d554f92422d..cbded4f8b05 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -366,17 +366,23 @@ concept DlThreadClusterDescriptor = requires(T t) { }; // Concept for DL block transfer -template +template concept DlBlockTransferDescriptor = requires(T t) { - { t.thread_slice_lengths } -> std::convertible_to>; - { t.thread_cluster_lengths } -> std::convertible_to>; - { t.thread_cluster_arrange_order } -> std::convertible_to>; - { t.src_access_order } -> std::convertible_to>; - { t.src_vector_tensor_lengths } -> std::convertible_to>; - { 
t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to>; - { t.dst_vector_tensor_lengths } -> std::convertible_to>; + { t.thread_slice_lengths } -> std::convertible_to>; + { t.thread_cluster_lengths } -> std::convertible_to>; + { t.thread_cluster_arrange_order } -> std::convertible_to>; + { t.src_access_order } -> std::convertible_to>; + { t.src_vector_tensor_lengths } -> std::convertible_to>; + { t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to>; + { t.dst_vector_tensor_lengths } -> std::convertible_to>; }; +template +concept DlBlockTransferDescriptor4D = DlBlockTransferDescriptor; + +template +concept DlBlockTransferDescriptor5D = DlBlockTransferDescriptor; + // Concept for DL epilogue template concept DlEpilogueDescriptor = requires(T t) { @@ -399,15 +405,21 @@ concept SpecifiesDlThreadCluster = requires { // Concept to check if algorithm specifies DL block transfer template -concept SpecifiesDlBlockTransfer = requires { - { T::transfer.a.block_transfer } -> DlBlockTransferDescriptor; - { T::transfer.b.block_transfer } -> DlBlockTransferDescriptor; +concept SpecifiesDlFwdBlockTransfer = requires { + { T::transfer.a } -> DlBlockTransferDescriptor4D; + { T::transfer.b } -> DlBlockTransferDescriptor4D; +}; + +template +concept SpecifiesDlBwdBlockTransfer = requires { + { T::transfer.a } -> DlBlockTransferDescriptor5D; + { T::transfer.b } -> DlBlockTransferDescriptor5D; }; // Concept to check if algorithm specifies DL C thread transfer template concept SpecifiesDlEpilogue = requires { - { T::transfer.c.epilogue } -> DlEpilogueDescriptor; + { T::transfer.c } -> DlEpilogueDescriptor; }; } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp index 74973552247..0ca4c340572 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp +++ 
b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp @@ -767,6 +767,18 @@ consteval auto detailed_diagnostic_SpecifiesTransposeTransfer() -> std::string { return msg; } +template +consteval auto detailed_diagnostic_SpecifiesGemmBatchOptions() -> std::string { + if constexpr (requires { T::num_conv_groups_to_merge; }) { + using NumGroupsType = decltype(T::num_conv_groups_to_merge); + constexpr bool convertible = std::convertible_to; + return " → T::num_conv_groups_to_merge: " + std::string(CHECK_MARK(convertible)) + + (convertible ? "" : std::string(detail::get_type_info())) + "\n"; + } else { + return " → T::num_conv_groups_to_merge: [✗] (missing member)\n"; + } +} + template consteval auto detailed_diagnostic_SpecifiesGridwiseWmmaGemm() -> std::string { std::string msg; @@ -943,7 +955,7 @@ consteval auto detailed_diagnostic_SpecifiesDlThreadCluster() -> std::string { } template -consteval auto detailed_diagnostic_SpecifiesDlBlockTransfer() -> std::string { +consteval auto detailed_diagnostic_SpecifiesDlFwdBlockTransfer() -> std::string { std::string msg; constexpr bool has_transfer = requires { T::transfer; }; msg += " → T::transfer member: " + std::string(CHECK_MARK(has_transfer)) + "\n"; @@ -952,13 +964,12 @@ consteval auto detailed_diagnostic_SpecifiesDlBlockTransfer() -> std::string { return msg; } - constexpr bool has_a = requires { T::transfer.a; }; - constexpr bool has_b = requires { T::transfer.b; }; + constexpr bool has_a = requires { { T::transfer.a } -> DlBlockTransferDescriptor4D; }; + constexpr bool has_b = requires { { T::transfer.b } -> DlBlockTransferDescriptor4D; }; msg += " → T::transfer.a: " + std::string(CHECK_MARK(has_a)) + "\n"; - msg += " → T::transfer.b: " + std::string(CHECK_MARK(has_b)) + "\n"; - if constexpr (has_a && requires { T::transfer.a.block_transfer; }) { - using ABT = decltype(T::transfer.a.block_transfer); + if constexpr (has_a) { + using ABT = decltype(T::transfer.a); constexpr bool has_thread_slice = 
requires(ABT t) { { t.thread_slice_lengths } -> std::convertible_to>; }; constexpr bool has_thread_cluster = requires(ABT t) { { t.thread_cluster_lengths } -> std::convertible_to>; }; constexpr bool has_cluster_arrange = requires(ABT t) { { t.thread_cluster_arrange_order } -> std::convertible_to>; }; @@ -967,22 +978,69 @@ consteval auto detailed_diagnostic_SpecifiesDlBlockTransfer() -> std::string { constexpr bool has_src_contiguous = requires(ABT t) { { t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to>; }; constexpr bool has_dst_vector = requires(ABT t) { { t.dst_vector_tensor_lengths } -> std::convertible_to>; }; - msg += " → transfer.a.block_transfer.thread_slice_lengths: " + std::string(CHECK_MARK(has_thread_slice)) + (has_thread_slice ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.block_transfer.thread_cluster_lengths: " + std::string(CHECK_MARK(has_thread_cluster)) + (has_thread_cluster ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.block_transfer.thread_cluster_arrange_order: " + std::string(CHECK_MARK(has_cluster_arrange)) + (has_cluster_arrange ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.block_transfer.src_access_order: " + std::string(CHECK_MARK(has_src_access)) + (has_src_access ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.block_transfer.src_vector_tensor_lengths: " + std::string(CHECK_MARK(has_src_vector)) + (has_src_vector ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.block_transfer.src_vector_tensor_contiguous_dim_order: " + std::string(CHECK_MARK(has_src_contiguous)) + (has_src_contiguous ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.block_transfer.dst_vector_tensor_lengths: " + std::string(CHECK_MARK(has_dst_vector)) + (has_dst_vector ? 
"\n" : " (missing or wrong type)\n"); - } else if constexpr (has_a) { - msg += " → T::transfer.a.block_transfer: [✗] (missing)\n"; + msg += " → transfer.a.thread_slice_lengths (4D): " + std::string(CHECK_MARK(has_thread_slice)) + (has_thread_slice ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.thread_cluster_lengths (4D): " + std::string(CHECK_MARK(has_thread_cluster)) + (has_thread_cluster ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.thread_cluster_arrange_order (4D): " + std::string(CHECK_MARK(has_cluster_arrange)) + (has_cluster_arrange ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.src_access_order (4D): " + std::string(CHECK_MARK(has_src_access)) + (has_src_access ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.src_vector_tensor_lengths (4D): " + std::string(CHECK_MARK(has_src_vector)) + (has_src_vector ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.src_vector_tensor_contiguous_dim_order (4D): " + std::string(CHECK_MARK(has_src_contiguous)) + (has_src_contiguous ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.dst_vector_tensor_lengths (4D): " + std::string(CHECK_MARK(has_dst_vector)) + (has_dst_vector ? 
"\n" : " (missing or wrong type)\n"); + } else { + msg += " → T::transfer.a (4D): [✗] (missing or wrong type)\n"; } - // Similar checks for transfer.b - if constexpr (has_b && requires { T::transfer.b.block_transfer; }) { - msg += " → T::transfer.b.block_transfer: [✓] (similar fields as transfer.a)\n"; - } else if constexpr (has_b) { - msg += " → T::transfer.b.block_transfer: [✗] (missing)\n"; + msg += " → T::transfer.b: " + std::string(CHECK_MARK(has_b)) + "\n"; + + if constexpr (has_b) { + msg += " → T::transfer.b (4D): [✓] (similar fields as transfer.a)\n"; + } else { + msg += " → T::transfer.b (4D): [✗] (missing or wrong type)\n"; + } + + return msg; +} + +template +consteval auto detailed_diagnostic_SpecifiesDlBwdBlockTransfer() -> std::string { + std::string msg; + constexpr bool has_transfer = requires { T::transfer; }; + msg += " → T::transfer member: " + std::string(CHECK_MARK(has_transfer)) + "\n"; + + if constexpr (!has_transfer) { + return msg; + } + + constexpr bool has_a = requires { { T::transfer.a } -> DlBlockTransferDescriptor5D; }; + constexpr bool has_b = requires { { T::transfer.b } -> DlBlockTransferDescriptor5D; }; + msg += " → T::transfer.a: " + std::string(CHECK_MARK(has_a)) + "\n"; + + if constexpr (has_a) { + using ABT = decltype(T::transfer.a); + constexpr bool has_thread_slice = requires(ABT t) { { t.thread_slice_lengths } -> std::convertible_to>; }; + constexpr bool has_thread_cluster = requires(ABT t) { { t.thread_cluster_lengths } -> std::convertible_to>; }; + constexpr bool has_cluster_arrange = requires(ABT t) { { t.thread_cluster_arrange_order } -> std::convertible_to>; }; + constexpr bool has_src_access = requires(ABT t) { { t.src_access_order } -> std::convertible_to>; }; + constexpr bool has_src_vector = requires(ABT t) { { t.src_vector_tensor_lengths } -> std::convertible_to>; }; + constexpr bool has_src_contiguous = requires(ABT t) { { t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to>; }; + constexpr bool 
has_dst_vector = requires(ABT t) { { t.dst_vector_tensor_lengths } -> std::convertible_to>; }; + + msg += " → transfer.a.thread_slice_lengths (5D): " + std::string(CHECK_MARK(has_thread_slice)) + (has_thread_slice ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.thread_cluster_lengths (5D): " + std::string(CHECK_MARK(has_thread_cluster)) + (has_thread_cluster ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.thread_cluster_arrange_order (5D): " + std::string(CHECK_MARK(has_cluster_arrange)) + (has_cluster_arrange ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.src_access_order (5D): " + std::string(CHECK_MARK(has_src_access)) + (has_src_access ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.src_vector_tensor_lengths (5D): " + std::string(CHECK_MARK(has_src_vector)) + (has_src_vector ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.src_vector_tensor_contiguous_dim_order (5D): " + std::string(CHECK_MARK(has_src_contiguous)) + (has_src_contiguous ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.dst_vector_tensor_lengths (5D): " + std::string(CHECK_MARK(has_dst_vector)) + (has_dst_vector ? 
"\n" : " (missing or wrong type)\n"); + } else { + msg += " → T::transfer.a (5D): [✗] (missing or wrong type)\n"; + } + + msg += " → T::transfer.b: " + std::string(CHECK_MARK(has_b)) + "\n"; + + if constexpr (has_b) { + msg += " → T::transfer.b (5D): [✓] (similar fields as transfer.a)\n"; + } else { + msg += " → T::transfer.b (5D): [✗] (missing or wrong type)\n"; } return msg; @@ -999,17 +1057,17 @@ consteval auto detailed_diagnostic_SpecifiesDlEpilogue() -> std::string { constexpr bool has_c = requires { T::transfer.c; }; msg += " → T::transfer.c: " + std::string(CHECK_MARK(has_c)) + "\n"; - if constexpr (has_c && requires { T::transfer.c.epilogue; }) { - using E = decltype(T::transfer.c.epilogue); - constexpr bool has_src_dst_access = requires(E t) { { t.src_dst_access_order } -> std::convertible_to>; }; - constexpr bool has_src_dst_vector_dim = requires(E t) { { t.src_dst_vector_dim } -> std::convertible_to; }; - constexpr bool has_dst_scalar = requires(E t) { { t.dst_scalar_per_vector } -> std::convertible_to; }; + if constexpr (has_c && requires { T::transfer.c.src_dst_access_order; }) { + using C = decltype(T::transfer.c); + constexpr bool has_src_dst_access = requires(C t) { { t.src_dst_access_order } -> std::convertible_to>; }; + constexpr bool has_src_dst_vector_dim = requires(C t) { { t.src_dst_vector_dim } -> std::convertible_to; }; + constexpr bool has_dst_scalar = requires(C t) { { t.dst_scalar_per_vector } -> std::convertible_to; }; - msg += " → transfer.c.epilogue.src_dst_access_order: " + std::string(CHECK_MARK(has_src_dst_access)) + (has_src_dst_access ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.c.epilogue.src_dst_vector_dim: " + std::string(CHECK_MARK(has_src_dst_vector_dim)) + (has_src_dst_vector_dim ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.c.epilogue.dst_scalar_per_vector: " + std::string(CHECK_MARK(has_dst_scalar)) + (has_dst_scalar ? 
"\n" : " (missing or wrong type)\n"); + msg += " → transfer.c.src_dst_access_order: " + std::string(CHECK_MARK(has_src_dst_access)) + (has_src_dst_access ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.c.src_dst_vector_dim: " + std::string(CHECK_MARK(has_src_dst_vector_dim)) + (has_src_dst_vector_dim ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.c.dst_scalar_per_vector: " + std::string(CHECK_MARK(has_dst_scalar)) + (has_dst_scalar ? "\n" : " (missing or wrong type)\n"); } else if constexpr (has_c) { - msg += " → T::transfer.c.epilogue: [✗] (missing)\n"; + msg += " → T::transfer.c (DlEpilogue): [✗] (missing required fields)\n"; } return msg; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index 312e746f8a7..1ad877e878d 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -158,7 +158,7 @@ struct FwdDlAlgorithm { CHECK_CONCEPT(T, SpecifiesGemmSpecialization) CHECK_CONCEPT(T, SpecifiesDlThreadConfig) CHECK_CONCEPT(T, SpecifiesDlThreadCluster) - CHECK_CONCEPT(T, SpecifiesDlBlockTransfer) + CHECK_CONCEPT(T, SpecifiesDlFwdBlockTransfer) CHECK_CONCEPT(T, SpecifiesDlEpilogue) static constexpr bool c1 = c_ConvAlgorithmDescriptor; @@ -167,7 +167,7 @@ struct FwdDlAlgorithm { static constexpr bool c4 = c_SpecifiesGemmSpecialization; static constexpr bool c5 = c_SpecifiesDlThreadConfig; static constexpr bool c6 = c_SpecifiesDlThreadCluster; - static constexpr bool c7 = c_SpecifiesDlBlockTransfer; + static constexpr bool c7 = c_SpecifiesDlFwdBlockTransfer; static constexpr bool c8 = c_SpecifiesDlEpilogue; static consteval bool is_valid() { @@ -183,7 +183,7 @@ struct FwdDlAlgorithm { DIAGNOSTIC_LINE(SpecifiesGemmSpecialization) + DIAGNOSTIC_LINE(SpecifiesDlThreadConfig) + DIAGNOSTIC_LINE(SpecifiesDlThreadCluster) + - 
DIAGNOSTIC_LINE(SpecifiesDlBlockTransfer) + + DIAGNOSTIC_LINE(SpecifiesDlFwdBlockTransfer) + DIAGNOSTIC_LINE(SpecifiesDlEpilogue); } }; @@ -373,6 +373,41 @@ struct BwdTwoStageXdlAlgorithm { } }; +template +struct BwdDlAlgorithm { + CHECK_CONCEPT(T, ConvAlgorithmDescriptor) + CHECK_CONCEPT(T, SpecifiesThreadBlock) + CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization) + CHECK_CONCEPT(T, SpecifiesDlThreadConfig) + CHECK_CONCEPT(T, SpecifiesDlThreadCluster) + CHECK_CONCEPT(T, SpecifiesDlBwdBlockTransfer) + CHECK_CONCEPT(T, SpecifiesDlEpilogue) + + static constexpr bool c1 = c_ConvAlgorithmDescriptor; + static constexpr bool c2 = c_SpecifiesThreadBlock; + static constexpr bool c3 = c_SpecifiesBwdWeightConvSpecialization; + static constexpr bool c4 = c_SpecifiesDlThreadConfig; + static constexpr bool c5 = c_SpecifiesDlThreadCluster; + static constexpr bool c6 = c_SpecifiesDlBwdBlockTransfer; + static constexpr bool c7 = c_SpecifiesDlEpilogue; + + static consteval bool is_valid() { + return c1 && c2 && c3 && c4 && c5 && c6 && c7; + } + + static consteval auto message() -> std::string { + return std::string("\n=== Backward DL Algorithm Diagnostic (closest match) ===\n" + "Concepts for BwdDl Algorithm:\n") + + DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + + DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) + + DIAGNOSTIC_LINE(SpecifiesDlThreadConfig) + + DIAGNOSTIC_LINE(SpecifiesDlThreadCluster) + + DIAGNOSTIC_LINE(SpecifiesDlBwdBlockTransfer) + + DIAGNOSTIC_LINE(SpecifiesDlEpilogue); + } +}; + template consteval int count_matches_fwd_xdl_v3() { using Alg = FwdXdlV3Algorithm; @@ -415,6 +450,12 @@ consteval int count_matches_bwd_two_stage_xdl() { return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11 + Alg::c12; } +template +consteval int count_matches_bwd_dl() { + using Alg = BwdDlAlgorithm; + return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7; 
+} + template consteval int count_matches_large_tensor() { using Alg = LargeTensorAlgorithm; @@ -477,9 +518,11 @@ consteval void diagnose_bwd_weight_algorithm_signature() constexpr int xdl_matches = count_matches_bwd_xdl(); constexpr int xdl_v3_matches = count_matches_fwd_xdl_v3(); constexpr int two_stage_xdl_matches = count_matches_bwd_two_stage_xdl(); + constexpr int dl_matches = count_matches_bwd_dl(); constexpr int max1 = xdl_v3_matches > xdl_matches ? xdl_v3_matches : xdl_matches; - constexpr int max_matches = max1 > two_stage_xdl_matches ? max1 : two_stage_xdl_matches; + constexpr int max_2 = max1 > two_stage_xdl_matches ? max1 : two_stage_xdl_matches; + constexpr int max_matches = max_2 > dl_matches ? max_2 : dl_matches; if constexpr (max_matches == xdl_matches) { using Alg = BwdXdlAlgorithm; @@ -493,6 +536,10 @@ consteval void diagnose_bwd_weight_algorithm_signature() using Alg = BwdTwoStageXdlAlgorithm; static_assert(Alg::is_valid(), Alg::message()); } + else if constexpr (max_matches == dl_matches) { + using Alg = BwdDlAlgorithm; + static_assert(Alg::is_valid(), Alg::message()); + } else { // This should never happen static_assert(false, "Internal Error: No matching algorithm variant found for diagnostics."); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp new file mode 100644 index 00000000000..0ddeae2eeda --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp @@ -0,0 +1,131 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp"
+#include "ck_tile/builder/conv_signature_concepts.hpp"
+#include "ck_tile/builder/conv_algorithm_concepts.hpp"
+#include "ck_tile/builder/builder_utils.hpp"
+#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
+#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
+#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
+//#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
+#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
+
+namespace ck_tile::builder::factory {
+
+// Factory for DeviceGroupedConvBwdWeight_Dl instance
+// of a grouped backward-weight convolution kernel.
+template
+    requires ConvDirectionIsBackwardWeight
+struct ConvBwdWeightDlFactory
+{
+    static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
+    using Layouts = internal::ConvTensorLayouts;
+    using Types = internal::BwdWeightConvTensorDataTypes;
+    using Ops = internal::ElementwiseOps;
+    using AlgorithmType = decltype(ALGORITHM);
+
+    static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization();
+
+    static constexpr auto BLOCK = internal::SetThreadBlockInfo();
+
+    // DL-specific parameters from algorithm descriptor
+    static constexpr auto DL_THREAD_CFG = ALGORITHM.thread_config;
+    static constexpr ck::index_t K0PerBlock = DL_THREAD_CFG.k0_per_block;
+    static constexpr ck::index_t K1 = DL_THREAD_CFG.k1;
+    static constexpr ck::index_t M1PerThread = DL_THREAD_CFG.m1_per_thread;
+    static constexpr ck::index_t N1PerThread = DL_THREAD_CFG.n1_per_thread;
+    static constexpr ck::index_t KPerThread = DL_THREAD_CFG.k_per_thread;
+
+    // Thread cluster from descriptor
+    static constexpr auto DL_CLUSTER = ALGORITHM.thread_cluster;
+    using M1N1ThreadClusterM1Xs = to_sequence_v;
+    using M1N1ThreadClusterN1Xs = to_sequence_v;
+
+    // A Block Transfer from descriptor - 
K0_M0_M1_K1 tensor format
+    static constexpr auto DL_A_TRANSFER = ALGORITHM.transfer.a;
+    using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 =
+        to_sequence_v;
+    using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 =
+        to_sequence_v;
+    using ABlockTransferThreadClusterArrangeOrder =
+        to_sequence_v;
+    using ABlockTransferSrcAccessOrder = to_sequence_v;
+    using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 =
+        to_sequence_v;
+    using ABlockTransferSrcVectorTensorContiguousDimOrder =
+        to_sequence_v;
+    using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 =
+        to_sequence_v;
+
+    // B Block Transfer from descriptor - K0_N0_N1_K1 tensor format
+    static constexpr auto DL_B_TRANSFER = ALGORITHM.transfer.b;
+    using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 =
+        to_sequence_v;
+    using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 =
+        to_sequence_v;
+    using BBlockTransferThreadClusterArrangeOrder =
+        to_sequence_v;
+    using BBlockTransferSrcAccessOrder = to_sequence_v;
+    using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 =
+        to_sequence_v;
+    using BBlockTransferSrcVectorTensorContiguousDimOrder =
+        to_sequence_v;
+    using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 =
+        to_sequence_v;
+
+    // C Thread Transfer from descriptor
+    static constexpr auto DL_C_TRANSFER = ALGORITHM.transfer.c;
+    using CThreadTransferSrcDstAccessOrder = to_sequence_v;
+    static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim;
+    static constexpr ck::index_t CThreadTransferDstScalarPerVector =
+        DL_C_TRANSFER.dst_scalar_per_vector;
+
+    // The DL backward-weight convolution kernel class instance
+    using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Dl<
+        SPATIAL_DIM,
+        typename Layouts::InLayout,
+        typename Layouts::WeiLayout,
+        typename Layouts::OutLayout,
+        typename Types::InDataType,
+        typename Types::WeiDataType,
+        typename Types::OutDataType,
+        typename Types::AccDataType,
+        typename Ops::InElementwiseOp,
+        typename 
Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + K0PerBlock, + K1, + M1PerThread, + N1PerThread, + KPerThread, + M1N1ThreadClusterM1Xs, + M1N1ThreadClusterN1Xs, + ABlockTransferThreadSliceLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, + ABlockTransferSrcVectorTensorContiguousDimOrder, + ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, + BBlockTransferThreadSliceLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, + BBlockTransferSrcVectorTensorContiguousDimOrder, + BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 1812f1a0ff6..7e7aa275c72 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -66,6 +66,7 @@ #include "ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp" #include "ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp" #include "ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp" namespace ck_tile::builder::factory { @@ -147,6 +148,10 @@ constexpr auto make_conv_instance() { return typename ConvBwdWeightTwoStageXdlFactory::Instance{}; } + else if constexpr (BwdDlAlgorithm::is_valid()) + { + return typename ConvBwdWeightDlFactory::Instance{}; + } 
else { diagnose_bwd_weight_algorithm_signature(); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp index 42c59dfaec2..07f0976b74a 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp @@ -48,7 +48,7 @@ struct ConvFwdDlFactory using M1N1ThreadClusterN1Xs = to_sequence_v; // A Block Transfer from descriptor - K0_M0_M1_K1 tensor format - static constexpr auto DL_A_TRANSFER = ALGORITHM.transfer.a.block_transfer; + static constexpr auto DL_A_TRANSFER = ALGORITHM.transfer.a; using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 = to_sequence_v; using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 = @@ -64,7 +64,7 @@ struct ConvFwdDlFactory to_sequence_v; // B Block Transfer from descriptor - K0_N0_N1_K1 tensor format - static constexpr auto DL_B_TRANSFER = ALGORITHM.transfer.b.block_transfer; + static constexpr auto DL_B_TRANSFER = ALGORITHM.transfer.b; using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 = to_sequence_v; using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 = @@ -80,7 +80,7 @@ struct ConvFwdDlFactory to_sequence_v; // C Thread Transfer from descriptor - static constexpr auto DL_C_TRANSFER = ALGORITHM.transfer.c.epilogue; + static constexpr auto DL_C_TRANSFER = ALGORITHM.transfer.c; using CThreadTransferSrcDstAccessOrder = to_sequence_v; static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim; static constexpr ck::index_t CThreadTransferDstScalarPerVector = diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index a3d08e82edb..b042dd6b610 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -153,6 +153,7 @@ add_ck_builder_test(test_ckb_build_bwd_weight_instances 
conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp + conv/ck/test_ckb_conv_bwd_weight_dl.cpp conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp ) target_link_libraries(test_ckb_build_bwd_weight_instances PRIVATE utility) diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_dl.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_dl.cpp new file mode 100644 index 00000000000..79bdcc8ef18 --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_dl.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; + +constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::BF16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::GNHWC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl{} + .with_thread_block(cku::ThreadBlock_256_128x128x16) + .with_bwd_specialization(cku::ConvSpecialization::DEFAULT) + .with_dl_thread_config(cku::DlThreadConfig_16x2x4x4x1) + .with_dl_thread_cluster(cku::DlThreadCluster_8x2) + .with_dl_transfer(cku::DlTransfer5D); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdWeight_2DBf16_DL, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + cku::run_test({"DeviceGroupedConvBwdWeight_Dl", + expected_transfer_parameters, + 
"Default", + "GNHWC,GKYXC,GNHWK", + "PassThrough,PassThrough,PassThrough"}); +} diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index d003440935f..c45d89ed038 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -181,7 +181,7 @@ struct DlBlockTransfer std::array src_vector_tensor_contiguous_dim_order; std::array dst_vector_tensor_lengths; }; -static_assert(ckb::DlBlockTransferDescriptor); +static_assert(ckb::DlBlockTransferDescriptor4D); struct DlEpilogue { @@ -263,26 +263,16 @@ struct DlThreadCluster_ DlThreadCluster thread_cluster; }; -struct DlBlockTransferAB +struct DlTransfer { - DlBlockTransfer block_transfer; -}; - -struct DlBlockTransferC -{ - DlEpilogue epilogue; -}; - -struct DlTransferABC -{ - DlBlockTransferAB a; - DlBlockTransferAB b; - DlBlockTransferC c; + DlBlockTransfer a; + DlBlockTransfer b; + DlEpilogue c; }; struct DlTransfer_ { - DlTransferABC transfer; + DlTransfer transfer; }; struct TwoStageSpecialization_ @@ -579,4 +569,11 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle = using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 = ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, BlockGemm_>; +using ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl = + ConvAlgorithmTemplate; + } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index 7b5807ef233..3deeeca80a2 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -17,7 +17,8 @@ constexpr DlThreadConfig DlThreadConfig_16x2x4x4x1{ constexpr DlThreadCluster DlThreadCluster_8x2{.m1_xs = {8, 2}, .n1_xs = {8, 2}}; -constexpr DlBlockTransfer DlBlockTransferAB{.thread_slice_lengths = {8, 1, 1, 2}, +constexpr 
DlBlockTransfer DlBlockTransfer_8x1x1x2 + {.thread_slice_lengths = {8, 1, 1, 2}, .thread_cluster_lengths = {2, 1, 128, 1}, .thread_cluster_arrange_order = {1, 2, 0, 3}, .src_access_order = {1, 2, 0, 3}, @@ -25,19 +26,12 @@ constexpr DlBlockTransfer DlBlockTransferAB{.thread_slice_lengths = {8, .src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3}, .dst_vector_tensor_lengths = {1, 1, 1, 2}}; -constexpr DlTransferABC DlFwdTransfer{.a = - { - .block_transfer = DlBlockTransferAB, - }, - .b = - { - .block_transfer = DlBlockTransferAB, - }, - .c = { - .epilogue = {.src_dst_access_order = {0, 1, 2, 3, 4, 5}, - .src_dst_vector_dim = 5, - .dst_scalar_per_vector = 4}, - }}; +constexpr DlTransfer DlTransfer5D {.a = DlBlockTransfer_8x1x1x2, + .b = DlBlockTransfer_8x1x1x2, + .c = { + .src_dst_access_order = {0, 1, 2, 3, 4, 5}, + .src_dst_vector_dim = 5, + .dst_scalar_per_vector = 4}}; constexpr Transfer<> Transfer_4x64x1{ .a = diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index 8f530600acc..80c8f067d99 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -226,19 +226,7 @@ inline std::string to_string(DlEpilogue t) } template <> -inline std::string to_string(DlBlockTransferAB t) -{ - return to_string(t.block_transfer); -} - -template <> -inline std::string to_string(DlBlockTransferC t) -{ - return to_string(t.epilogue); -} - -template <> -inline std::string to_string(DlTransferABC t) +inline std::string to_string(DlTransfer t) { std::ostringstream oss; oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c); @@ -407,4 +395,16 @@ inline std::string to_string +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," + << to_string(static_cast(t)) << "," + << to_string(static_cast(t)) << "," 
+ << to_string(static_cast(t)); + return oss.str(); +} + } // namespace ck_tile::builder::test From e1b4acd4311a36c35939e2785d35a0b9c2b146f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Wed, 31 Dec 2025 08:29:11 -0500 Subject: [PATCH 35/81] Final implementation for bwd weight DL factory. --- .../conv/ck/test_ckb_conv_bwd_weight_dl.cpp | 3 +- .../conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp | 6 ++-- .../test/impl/conv_algorithm_types.hpp | 35 ++++++++++-------- .../test/utils/ckb_conv_test_configs.hpp | 27 +++++++++++--- .../test/utils/conv_algorithm_type_utils.hpp | 36 ++++++++++++++++--- 5 files changed, 80 insertions(+), 27 deletions(-) diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_dl.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_dl.cpp index 79bdcc8ef18..99cadb6d20d 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_dl.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_dl.cpp @@ -21,7 +21,7 @@ constexpr auto SIGNATURE = constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl{} .with_thread_block(cku::ThreadBlock_256_128x128x16) .with_bwd_specialization(cku::ConvSpecialization::DEFAULT) - .with_dl_thread_config(cku::DlThreadConfig_16x2x4x4x1) + .with_dl_thread_config(cku::DlThreadConfig_16x1x4x4x1) .with_dl_thread_cluster(cku::DlThreadCluster_8x2) .with_dl_transfer(cku::DlTransfer5D); @@ -31,6 +31,7 @@ using Instance = Builder::Instance; TEST(BwdWeight_2DBf16_DL, Create) { const auto expected_transfer_parameters = to_string(ALGORITHM); + std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl; cku::run_test({"DeviceGroupedConvBwdWeight_Dl", expected_transfer_parameters, "Default", diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp index b740ac8704b..79d4827feea 100644 --- 
a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp @@ -29,11 +29,12 @@ TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Ins .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_dl_thread_config(DlThreadConfig_16x2x4x4x1) .with_dl_thread_cluster(DlThreadCluster_8x2) - .with_dl_transfer(DlFwdTransfer); + .with_dl_transfer(DlTransfer4D); using Builder = ConvBuilder; const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); + std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl; run_test({"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK", expected_transfer_parameters, "Default", @@ -64,11 +65,12 @@ TEST(FwdConvInstances, GemmSpecialization::MNKPadding) .with_dl_thread_config(DlThreadConfig_16x2x4x4x1) .with_dl_thread_cluster(DlThreadCluster_8x2) - .with_dl_transfer(DlFwdTransfer); + .with_dl_transfer(DlTransfer4D); using Builder = ConvBuilder; const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); + std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl; run_test({"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK", expected_transfer_parameters, "Filter1x1Pad0", diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index c45d89ed038..a4dc20f03e2 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -171,17 +171,19 @@ struct DlThreadCluster }; static_assert(ckb::DlThreadClusterDescriptor); +template struct DlBlockTransfer { - std::array thread_slice_lengths; - std::array thread_cluster_lengths; - std::array thread_cluster_arrange_order; - std::array src_access_order; - std::array src_vector_tensor_lengths; - std::array src_vector_tensor_contiguous_dim_order; - 
std::array dst_vector_tensor_lengths; + std::array thread_slice_lengths; + std::array thread_cluster_lengths; + std::array thread_cluster_arrange_order; + std::array src_access_order; + std::array src_vector_tensor_lengths; + std::array src_vector_tensor_contiguous_dim_order; + std::array dst_vector_tensor_lengths; }; -static_assert(ckb::DlBlockTransferDescriptor4D); +static_assert(ckb::DlBlockTransferDescriptor4D>); +static_assert(ckb::DlBlockTransferDescriptor5D>); struct DlEpilogue { @@ -263,16 +265,18 @@ struct DlThreadCluster_ DlThreadCluster thread_cluster; }; +template struct DlTransfer { - DlBlockTransfer a; - DlBlockTransfer b; + DlBlockTransfer a; + DlBlockTransfer b; DlEpilogue c; }; +template struct DlTransfer_ { - DlTransfer transfer; + DlTransfer transfer; }; struct TwoStageSpecialization_ @@ -396,7 +400,7 @@ struct ConvAlgorithmTemplate : Components... template constexpr auto with_transfer(const T& t) const { - static_assert(std::is_base_of_v, ConvAlgorithmTemplate> || + static_assert(std::is_base_of_v, ConvAlgorithmTemplate> || std::is_base_of_v, ConvAlgorithmTemplate>); auto result = *this; result.transfer = t; @@ -481,7 +485,8 @@ struct ConvAlgorithmTemplate : Components... 
template constexpr auto with_dl_transfer(const T& t) const { - static_assert(std::is_base_of_v); + static_assert(std::is_base_of_v, ConvAlgorithmTemplate> || + std::is_base_of_v, ConvAlgorithmTemplate>); auto result = *this; result.transfer = t; return result; @@ -549,7 +554,7 @@ using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = ConvSpecializationFwd_, DlThreadConfig_, DlThreadCluster_, - DlTransfer_>; + DlTransfer_<>>; using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = LargeTensorWrapper; @@ -573,7 +578,7 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl = ConvAlgorithmTemplate, ConvSpecializationBwdWeight_>; } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index 3deeeca80a2..38f0675da75 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -15,9 +15,12 @@ using namespace test; constexpr DlThreadConfig DlThreadConfig_16x2x4x4x1{ .k0_per_block = 16, .k1 = 2, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1}; +constexpr DlThreadConfig DlThreadConfig_16x1x4x4x1{ + .k0_per_block = 16, .k1 = 1, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1}; + constexpr DlThreadCluster DlThreadCluster_8x2{.m1_xs = {8, 2}, .n1_xs = {8, 2}}; -constexpr DlBlockTransfer DlBlockTransfer_8x1x1x2 +constexpr DlBlockTransfer<4> DlBlockTransfer_8x1x1x2 {.thread_slice_lengths = {8, 1, 1, 2}, .thread_cluster_lengths = {2, 1, 128, 1}, .thread_cluster_arrange_order = {1, 2, 0, 3}, @@ -26,13 +29,29 @@ constexpr DlBlockTransfer DlBlockTransfer_8x1x1x2 .src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3}, .dst_vector_tensor_lengths = {1, 1, 1, 2}}; -constexpr DlTransfer DlTransfer5D {.a = DlBlockTransfer_8x1x1x2, - .b = DlBlockTransfer_8x1x1x2, - .c = { +constexpr DlTransfer<4> DlTransfer4D {.a = DlBlockTransfer_8x1x1x2, + .b = 
DlBlockTransfer_8x1x1x2, + .c = { .src_dst_access_order = {0, 1, 2, 3, 4, 5}, .src_dst_vector_dim = 5, .dst_scalar_per_vector = 4}}; +constexpr DlBlockTransfer<5> DlBlockTransfer_1x8x1x1x1 + {.thread_slice_lengths = {1, 8, 1, 1, 1}, + .thread_cluster_lengths = {1, 2, 1, 128, 1}, + .thread_cluster_arrange_order = {0, 2, 3, 1, 4}, + .src_access_order = {0, 2, 3, 1, 4}, + .src_vector_tensor_lengths = {1, 1, 1, 1, 1}, + .src_vector_tensor_contiguous_dim_order = {0, 2, 3, 1, 4}, + .dst_vector_tensor_lengths = {1, 1, 1, 1, 1}}; + +constexpr DlTransfer<5> DlTransfer5D {.a = DlBlockTransfer_1x8x1x1x1, + .b = DlBlockTransfer_1x8x1x1x1, + .c = { + .src_dst_access_order = {0, 1, 2, 3, 4, 5}, + .src_dst_vector_dim = 5, + .dst_scalar_per_vector = 1}}; + constexpr Transfer<> Transfer_4x64x1{ .a = { diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index 80c8f067d99..ee22b8392fc 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -205,7 +205,19 @@ inline std::string to_string(DlThreadCluster t) } template <> -inline std::string to_string(DlBlockTransfer t) +inline std::string to_string>(DlBlockTransfer<4> t) +{ + std::ostringstream oss; + oss << array_to_seq(t.thread_slice_lengths) << "," << array_to_seq(t.thread_cluster_lengths) + << "," << array_to_seq(t.thread_cluster_arrange_order) << "," + << array_to_seq(t.src_access_order) << "," << array_to_seq(t.src_vector_tensor_lengths) + << "," << array_to_seq(t.src_vector_tensor_contiguous_dim_order) << "," + << array_to_seq(t.dst_vector_tensor_lengths); + return oss.str(); +} + +template <> +inline std::string to_string>(DlBlockTransfer<5> t) { std::ostringstream oss; oss << array_to_seq(t.thread_slice_lengths) << "," << array_to_seq(t.thread_cluster_lengths) @@ -226,7 +238,15 @@ inline std::string to_string(DlEpilogue t) } template <> -inline 
std::string to_string(DlTransfer t) +inline std::string to_string>(DlTransfer<4> t) +{ + std::ostringstream oss; + oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c); + return oss.str(); +} + +template <> +inline std::string to_string>(DlTransfer<5> t) { std::ostringstream oss; oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c); @@ -309,7 +329,13 @@ inline std::string to_string(DlThreadCluster_ t) } template <> -inline std::string to_string(DlTransfer_ t) +inline std::string to_string>(DlTransfer_<4> t) +{ + return to_string(t.transfer); +} + +template <> +inline std::string to_string>(DlTransfer_<5> t) { return to_string(t.transfer); } @@ -354,7 +380,7 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) << "," << to_string(static_cast(t)) << "," - << to_string(static_cast(t)); + << to_string(static_cast>(t)); return oss.str(); } @@ -403,7 +429,7 @@ inline std::string to_string( oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) << "," << to_string(static_cast(t)) << "," - << to_string(static_cast(t)); + << to_string(static_cast>(t)); return oss.str(); } From 83be9c740c0f6f1238e78420fefe68488fa60d88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Wed, 31 Dec 2025 09:08:06 -0500 Subject: [PATCH 36/81] Add test for creating DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle instance. 
--- .../builder/conv_algorithm_concepts.hpp | 15 +++++ .../builder/conv_algorithm_diagnostics.hpp | 57 +++++++++++++++++++ .../builder/factory/conv_algorithms.hpp | 11 ++-- .../builder/include/ck_tile/builder/types.hpp | 3 +- experimental/builder/test/CMakeLists.txt | 1 + ...b_conv_bwd_weight_multi_d_xdl_cshuffle.cpp | 41 +++++++++++++ .../test/impl/conv_algorithm_types.hpp | 14 +++-- 7 files changed, 132 insertions(+), 10 deletions(-) create mode 100644 experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_xdl_cshuffle.cpp diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index cbded4f8b05..b3c3a9a59f4 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -327,6 +327,12 @@ concept SpecifiesTwoStageSupport = requires { requires T::specialization == ConvAlgorithmSpecialization::TWO_STAGE; }; +template +concept SpecifiesMultipleDSupport = requires { + { T::specialization } -> std::convertible_to; + requires T::specialization == ConvAlgorithmSpecialization::MULTIPLE_D; +}; + template concept SpecifiesGenericInstance = !requires { { T::specialization }; @@ -338,6 +344,15 @@ concept SpecifiesTransposeTransfer = requires { { T::max_transpose_transfer_dst_scalar_per_vector } -> SizeType; }; +template +concept HasTransposeTransfer = requires { + { T::max_transpose_transfer_src_scalar_per_vector }; + { T::max_transpose_transfer_dst_scalar_per_vector }; +}; + +template +concept TransposeTransferWellDefinedIfProvided = + !HasTransposeTransfer || SpecifiesTransposeTransfer; template concept SpecifiesGemmBatchOptions = requires { diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp index 0ca4c340572..d979b778091 100644 --- 
a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp @@ -732,6 +732,26 @@ consteval auto detailed_diagnostic_SpecifiesTwoStageSupport() -> std::string { return msg; } +template +consteval auto detailed_diagnostic_SpecifiesMultipleDSupport() -> std::string { + std::string msg; + if constexpr (requires { T::specialization; }) { + using SpecType = decltype(T::specialization); + constexpr bool convertible = std::convertible_to; + msg += " → T::specialization: " + std::string(CHECK_MARK(convertible)) + + (convertible ? "" : std::string(detail::get_type_info())) + "\n"; + + if constexpr (convertible) { + constexpr bool is_multiple_d = (T::specialization == ConvAlgorithmSpecialization::MULTIPLE_D); + msg += " → specialization == MULTIPLE_D: " + std::string(CHECK_MARK(is_multiple_d)) + "\n"; + } + } else { + msg += " → T::specialization: [✗] (missing member)\n"; + } + + return msg; +} + template consteval auto detailed_diagnostic_SpecifiesGenericInstance() -> std::string { std::string msg; @@ -767,6 +787,43 @@ consteval auto detailed_diagnostic_SpecifiesTransposeTransfer() -> std::string { return msg; } +template +consteval auto detailed_diagnostic_TransposeTransferWellDefinedIfProvided() -> std::string { + std::string msg; + + constexpr bool has_src = requires { T::max_transpose_transfer_src_scalar_per_vector; }; + constexpr bool has_dst = requires { T::max_transpose_transfer_dst_scalar_per_vector; }; + constexpr bool has_transpose_transfer = has_src || has_dst; + + if constexpr (!has_transpose_transfer) { + msg += " → Transpose transfer fields not provided: [✓] (optional, not required)\n"; + } else { + msg += " → Transpose transfer fields provided, checking if well-defined:\n"; + + if constexpr (has_src) { + using SrcType = decltype(T::max_transpose_transfer_src_scalar_per_vector); + constexpr bool src_convertible = std::convertible_to; + msg += " → 
T::max_transpose_transfer_src_scalar_per_vector: " + + std::string(CHECK_MARK(src_convertible)) + + (src_convertible ? "" : std::string(detail::get_type_info())) + "\n"; + } else { + msg += " → T::max_transpose_transfer_src_scalar_per_vector: [✗] (missing, but dst is provided)\n"; + } + + if constexpr (has_dst) { + using DstType = decltype(T::max_transpose_transfer_dst_scalar_per_vector); + constexpr bool dst_convertible = std::convertible_to; + msg += " → T::max_transpose_transfer_dst_scalar_per_vector: " + + std::string(CHECK_MARK(dst_convertible)) + + (dst_convertible ? "" : std::string(detail::get_type_info())) + "\n"; + } else { + msg += " → T::max_transpose_transfer_dst_scalar_per_vector: [✗] (missing, but src is provided)\n"; + } + } + + return msg; +} + template consteval auto detailed_diagnostic_SpecifiesGemmBatchOptions() -> std::string { if constexpr (requires { T::num_conv_groups_to_merge; }) { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index 1ad877e878d..cd1d0d50706 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -248,7 +248,8 @@ struct BwdXdlAlgorithm { CHECK_CONCEPT(T, SpecifiesSourceAccessOrder) CHECK_CONCEPT(T, SpecifiesGridwiseBwdXdlGemm) CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization) - CHECK_CONCEPT(T, SpecifiesTransposeTransfer) + CHECK_CONCEPT(T, TransposeTransferWellDefinedIfProvided) + CHECK_CONCEPT(T, SpecifiesGenericInstance) static constexpr bool c1 = c_ConvAlgorithmDescriptor; static constexpr bool c2 = c_SpecifiesThreadBlock; @@ -258,10 +259,11 @@ struct BwdXdlAlgorithm { static constexpr bool c6 = c_SpecifiesSourceAccessOrder; static constexpr bool c7 = c_SpecifiesGridwiseBwdXdlGemm; static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; - static constexpr bool c9 = 
c_SpecifiesTransposeTransfer; + static constexpr bool c9 = c_TransposeTransferWellDefinedIfProvided; + static constexpr bool c10 = c_SpecifiesGenericInstance; static consteval bool is_valid() { - return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9; + return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10; } static consteval auto message() -> std::string { @@ -275,7 +277,8 @@ struct BwdXdlAlgorithm { DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + DIAGNOSTIC_LINE(SpecifiesGridwiseBwdXdlGemm) + DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) + - DIAGNOSTIC_LINE(SpecifiesTransposeTransfer); + DIAGNOSTIC_LINE(TransposeTransferWellDefinedIfProvided) + + DIAGNOSTIC_LINE(SpecifiesGenericInstance); } }; diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index c44f3368aea..296caf21b40 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -233,7 +233,8 @@ enum class PipelineScheduler enum class ConvAlgorithmSpecialization { LARGE_TENSOR, - TWO_STAGE + TWO_STAGE, + MULTIPLE_D }; // toString methods for enum classes diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index b042dd6b610..206c2f95dab 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -152,6 +152,7 @@ target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility) add_ck_builder_test(test_ckb_build_bwd_weight_instances conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp + conv/ck/test_ckb_conv_bwd_weight_multi_d_xdl_cshuffle.cpp conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp conv/ck/test_ckb_conv_bwd_weight_dl.cpp conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_xdl_cshuffle.cpp 
b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_xdl_cshuffle.cpp new file mode 100644 index 00000000000..206fc8beb96 --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_xdl_cshuffle.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; + +constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::FP16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::GNHWC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle{} + .with_thread_block(cku::ThreadBlock_256_128x128x8) + .with_gemm_config(cku::BwdGemmParams_Xdl_4x4_per_wave) + .with_transfer(cku::BwdTransfer_4x64x1) + .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdWeight_2DFp16_MultiD_CShuffle_GNHWC, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl; + cku::run_test({"DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle", + expected_transfer_parameters, + "Default", + "GNHWC,GKYXC,GNHWK", + "PassThrough,PassThrough,PassThrough", + "fp16,fp16>"}); // check compute types +} diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp 
b/experimental/builder/test/impl/conv_algorithm_types.hpp index a4dc20f03e2..d7bf60dc1d3 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -284,6 +284,11 @@ struct TwoStageSpecialization_ static constexpr ConvAlgorithmSpecialization specialization = ConvAlgorithmSpecialization::TWO_STAGE; }; +struct MultipleDSpecialization_ +{ + static constexpr ConvAlgorithmSpecialization specialization = ConvAlgorithmSpecialization::MULTIPLE_D; +}; + // Specialization wrapper for large tensor support template struct LargeTensorWrapper @@ -575,10 +580,9 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 = ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, BlockGemm_>; using ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl = - ConvAlgorithmTemplate, - ConvSpecializationBwdWeight_>; + ConvAlgorithmTemplate, ConvSpecializationBwdWeight_>; + +using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle = + ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, MultipleDSpecialization_>; } // namespace ck_tile::builder::test From fba80401d1b458f58b9bcc2e22b1c5f848e40a66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Wed, 31 Dec 2025 09:32:58 -0500 Subject: [PATCH 37/81] Add factory for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle --- .../builder/factory/conv_algorithms.hpp | 57 +++++++++- .../conv_bwd_weight_multi_d_xdl_factory.hpp | 102 ++++++++++++++++++ .../builder/factory/conv_dispatcher.hpp | 5 + .../factory/helpers/ck/conv_tensor_layout.hpp | 4 +- .../factory/helpers/ck/conv_tensor_type.hpp | 2 + .../test/utils/conv_algorithm_type_utils.hpp | 10 ++ ...onv_bwd_weight_multiple_d_xdl_cshuffle.hpp | 6 +- 7 files changed, 180 insertions(+), 6 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp 
b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index cd1d0d50706..3f237ae7449 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -282,6 +282,47 @@ struct BwdXdlAlgorithm { } }; +template +struct BwdMultiDXdlAlgorithm { + CHECK_CONCEPT(T, ConvAlgorithmDescriptor) + CHECK_CONCEPT(T, SpecifiesThreadBlock) + CHECK_CONCEPT(T, SpecifiesBlockTransfer4D) + CHECK_CONCEPT(T, SpecifiesLdsTransfer) + CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder) + CHECK_CONCEPT(T, SpecifiesSourceAccessOrder) + CHECK_CONCEPT(T, SpecifiesGridwiseBwdXdlGemm) + CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization) + CHECK_CONCEPT(T, SpecifiesMultipleDSupport) + + static constexpr bool c1 = c_ConvAlgorithmDescriptor; + static constexpr bool c2 = c_SpecifiesThreadBlock; + static constexpr bool c3 = c_SpecifiesBlockTransfer4D; + static constexpr bool c4 = c_SpecifiesLdsTransfer; + static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder; + static constexpr bool c6 = c_SpecifiesSourceAccessOrder; + static constexpr bool c7 = c_SpecifiesGridwiseBwdXdlGemm; + static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; + static constexpr bool c9 = c_SpecifiesMultipleDSupport; + + static consteval bool is_valid() { + return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9; + } + + static consteval auto message() -> std::string { + return std::string("\n=== Backward XDL Algorithm Diagnostic (closest match) ===\n" + "Concepts for BwdXdl Algorithm:\n") + + DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + + DIAGNOSTIC_LINE(SpecifiesBlockTransfer4D) + + DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + + DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + + DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + + DIAGNOSTIC_LINE(SpecifiesGridwiseBwdXdlGemm) + + DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) + + 
DIAGNOSTIC_LINE(SpecifiesMultipleDSupport); + } +}; + template struct BwdXdlV3Algorithm { CHECK_CONCEPT(T, ConvAlgorithmDescriptor) @@ -441,6 +482,12 @@ consteval int count_matches_bwd_xdl() { return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9; } +template +consteval int count_matches_bwd_multi_d_xdl() { + using Alg = BwdMultiDXdlAlgorithm; + return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9; +} + template consteval int count_matches_bwd_xdl_v3() { using Alg = BwdXdlV3Algorithm; @@ -522,10 +569,12 @@ consteval void diagnose_bwd_weight_algorithm_signature() constexpr int xdl_v3_matches = count_matches_fwd_xdl_v3(); constexpr int two_stage_xdl_matches = count_matches_bwd_two_stage_xdl(); constexpr int dl_matches = count_matches_bwd_dl(); + constexpr int multi_d_xdl_matches = count_matches_bwd_multi_d_xdl(); constexpr int max1 = xdl_v3_matches > xdl_matches ? xdl_v3_matches : xdl_matches; - constexpr int max_2 = max1 > two_stage_xdl_matches ? max1 : two_stage_xdl_matches; - constexpr int max_matches = max_2 > dl_matches ? max_2 : dl_matches; + constexpr int max2 = max1 > two_stage_xdl_matches ? max1 : two_stage_xdl_matches; + constexpr int max3 = max2 > dl_matches ? max2 : dl_matches; + constexpr int max_matches = max3 > multi_d_xdl_matches ? 
max3 : multi_d_xdl_matches; if constexpr (max_matches == xdl_matches) { using Alg = BwdXdlAlgorithm; @@ -543,6 +592,10 @@ consteval void diagnose_bwd_weight_algorithm_signature() using Alg = BwdDlAlgorithm; static_assert(Alg::is_valid(), Alg::message()); } + else if constexpr (max_matches == multi_d_xdl_matches) { + using Alg = BwdMultiDXdlAlgorithm; + static_assert(Alg::is_valid(), Alg::message()); + } else { // This should never happen static_assert(false, "Internal Error: No matching algorithm variant found for diagnostics."); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp new file mode 100644 index 00000000000..f4787ddc4ab --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp @@ -0,0 +1,102 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" +
+namespace ck_tile::builder::factory {
+
+// Factory for a DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle instance
+// of a grouped backward weight convolution kernel.
+template + requires ConvDirectionIsBackwardWeight +struct ConvBwdWeightMultiDXdlFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::BwdWeightConvTensorDataTypes; + using Ops = internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto A_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. + static_assert(InputVectorTransferLimits); + static_assert(InputVectorTransferLimits); + static_assert(OutputVectorTransferLimits); + static_assert(AccessOrderLimits4D); + static_assert(AccessOrderLimits4D); + static_assert(AccessOrderLimits4D); + static_assert(AccessOrderLimits4D); + + // The backward weight convolution kernel class instance.
+ using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Layouts::DsLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Types::DsDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + XDL_PARAMS.m_per_xdl, + XDL_PARAMS.n_per_xdl, + XDL_PARAMS.m_xdl_per_wave, + XDL_PARAMS.n_xdl_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + typename Types::InComputeType, + typename Types::WeiComputeType>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 7e7aa275c72..19f8dde34f5 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -67,6 +67,7 @@ #include "ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp" #include "ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp" #include "ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp" 
+#include "ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp" namespace ck_tile::builder::factory { @@ -152,6 +153,10 @@ constexpr auto make_conv_instance() { return typename ConvBwdWeightDlFactory::Instance{}; } + else if constexpr (BwdMultiDXdlAlgorithm::is_valid()) + { + return typename ConvBwdWeightMultiDXdlFactory::Instance{}; + } else { diagnose_bwd_weight_algorithm_signature(); diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp index 566524f3a07..fd6df5f09ab 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp @@ -236,12 +236,14 @@ struct ConvTensorLayouts using ALayout = std::conditional_t; using BLayout = std::conditional_t; using ELayout = std::conditional_t; - using DsLayout = std::conditional_t; // Backward weight convolution layouts using InLayout = std::conditional_t; using WeiLayout = std::conditional_t; using OutLayout = std::conditional_t; + + // Applicable for all directions + using DsLayout = AuxLayout; }; } // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp index d6b0e067009..2ab1c40882c 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp @@ -195,6 +195,8 @@ struct BwdWeightConvTensorDataTypes using AccDataType = typename decltype(GetTensorAccumulationType())::type; + // Data types for the auxiliary tensors (e.g., bias). 
+ using DsDataType = typename decltype(GetAuxiliaryTensorDataTypes())::type; }; } // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index ee22b8392fc..e48faefc9f8 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -433,4 +433,14 @@ inline std::string to_string( return oss.str(); } +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + return oss.str(); +} + } // namespace ck_tile::builder::test diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 759318afe92..5716f10db2e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -52,7 +52,7 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batched_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid, + kernel_batched_gemm_xdlops_bwd_weight_multiple_d(const FloatA* __restrict__ p_a_grid, const FloatB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, const AElementwiseOperation a_element_op, @@ -568,7 +568,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle int max_occupancy = 0; hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( &max_occupancy, - kernel_batched_gemm_xdlops_bwd_weight< + kernel_batched_gemm_xdlops_bwd_weight_multiple_d< GridwiseGemm, 
ADataType, BDataType, @@ -841,7 +841,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle p_c_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_)); }; - const auto kernel = kernel_batched_gemm_xdlops_bwd_weight< + const auto kernel = kernel_batched_gemm_xdlops_bwd_weight_multiple_d< GridwiseGemm, ADataType, BDataType, From 09e188f2a8d4f74200f170e73f89ac06887d6a49 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Fri, 2 Jan 2026 02:28:05 -0500 Subject: [PATCH 38/81] Treat ref algorithm the same way as real algorithms in the dispatcher. --- .../builder/conv_algorithm_concepts.hpp | 45 +++++++++++-------- .../builder/conv_algorithm_diagnostics.hpp | 20 +++++++++ .../builder/factory/conv_algorithms.hpp | 20 +++++++++ 3 files changed, 67 insertions(+), 18 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index b3c3a9a59f4..120590e2d7d 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -315,24 +315,6 @@ concept SpecifiesLoopScheduler = requires { { T::loop_scheduler } -> std::convertible_to; }; -template -concept SpecifiesLargeTensorSupport = requires { - { T::specialization } -> std::convertible_to; - requires T::specialization == ConvAlgorithmSpecialization::LARGE_TENSOR; -}; - -template -concept SpecifiesTwoStageSupport = requires { - { T::specialization } -> std::convertible_to; - requires T::specialization == ConvAlgorithmSpecialization::TWO_STAGE; -}; - -template -concept SpecifiesMultipleDSupport = requires { - { T::specialization } -> std::convertible_to; - requires T::specialization == ConvAlgorithmSpecialization::MULTIPLE_D; -}; - template concept SpecifiesGenericInstance = !requires { { T::specialization }; @@ -359,6 +341,33 @@ concept SpecifiesGemmBatchOptions = requires { { 
T::num_conv_groups_to_merge } -> SizeType; }; +/******************************************** */ +/* Algorithm specialization concepts */ +/******************************************** */ +template +concept SpecifiesLargeTensorSupport = requires { + { T::specialization } -> std::convertible_to; + requires T::specialization == ConvAlgorithmSpecialization::LARGE_TENSOR; +}; + +template +concept SpecifiesReferenceAlgorithm = requires { + { T::specialization } -> std::convertible_to; + requires T::specialization == ConvAlgorithmSpecialization::REFERENCE; +}; + +template +concept SpecifiesTwoStageSupport = requires { + { T::specialization } -> std::convertible_to; + requires T::specialization == ConvAlgorithmSpecialization::TWO_STAGE; +}; + +template +concept SpecifiesMultipleDSupport = requires { + { T::specialization } -> std::convertible_to; + requires T::specialization == ConvAlgorithmSpecialization::MULTIPLE_D; +}; + /******************************************** */ /* DL-specific descriptors and requirements */ /******************************************** */ diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp index d979b778091..1f7bf6565d0 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp @@ -712,6 +712,26 @@ consteval auto detailed_diagnostic_SpecifiesLargeTensorSupport() -> std::string return msg; } +template +consteval auto detailed_diagnostic_SpecifiesReferenceAlgorithm() -> std::string { + std::string msg; + if constexpr (requires { T::specialization; }) { + using SpecType = decltype(T::specialization); + constexpr bool convertible = std::convertible_to; + msg += " → T::specialization: " + std::string(CHECK_MARK(convertible)) + + (convertible ? 
"" : std::string(detail::get_type_info())) + "\n"; + + if constexpr (convertible) { + constexpr bool is_reference = (T::specialization == ConvAlgorithmSpecialization::REFERENCE); + msg += " → specialization == REFERENCE: " + std::string(CHECK_MARK(is_reference)) + "\n"; + } + } else { + msg += " → T::specialization: [✗] (missing member)\n"; + } + + return msg; +} + template consteval auto detailed_diagnostic_SpecifiesTwoStageSupport() -> std::string { std::string msg; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index 3f237ae7449..15c383fb7a8 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -9,6 +9,26 @@ namespace ck_tile::builder::factory { using namespace ck_tile::builder::diagnostics; +template +struct ReferenceAlgorithm { + CHECK_CONCEPT(T, ConvAlgorithmDescriptor) + CHECK_CONCEPT(T, SpecifiesReferenceAlgorithm) + + static constexpr bool c1 = c_ConvAlgorithmDescriptor; + static constexpr bool c2 = c_SpecifiesReferenceAlgorithm; + + static consteval bool is_valid() { + return c1 && c2; + } + + static consteval auto message() -> std::string { + return std::string("\n=== Reference Algorithm Diagnostic (closest match) ===\n" + "Concepts for Reference Algorithm:\n") + + DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + + DIAGNOSTIC_LINE(SpecifiesReferenceAlgorithm); + } +}; + template struct FwdXdlV3Algorithm { CHECK_CONCEPT(T, ConvAlgorithmDescriptor) From d045923a0dcaba8c9ecf51b4662426bc2e3e307c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Fri, 2 Jan 2026 05:12:19 -0500 Subject: [PATCH 39/81] Refactor large tensor support and WMMA configuration. 
--- .../builder/conv_algorithm_concepts.hpp | 13 +- .../builder/conv_algorithm_diagnostics.hpp | 36 +++-- .../builder/factory/conv_algorithms.hpp | 125 ++++++++++++++---- .../builder/factory/conv_dispatcher.hpp | 14 +- .../factory/conv_fwd_large_tensor_factory.hpp | 20 ++- .../factory/helpers/ck/conv_tuning_params.hpp | 4 +- .../test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp | 5 +- ...est_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp | 6 +- ...test_ckb_conv_fwd_2d_large_tensor_fp16.cpp | 10 +- .../test/impl/conv_algorithm_types.hpp | 36 +++-- .../test/utils/ckb_conv_test_configs.hpp | 21 ++- .../test/utils/conv_algorithm_type_utils.hpp | 9 +- 12 files changed, 196 insertions(+), 103 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 120590e2d7d..326350eb274 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -39,7 +39,7 @@ concept GridwiseXdlGemmDescriptor = requires(T t) { // Concept for parameter that describe block GEMM problem. template -concept BlockGemmDescriptor = requires(T t) { +concept BlockGemmPipelineDescriptor = requires(T t) { { t.pipeline_version } -> std::convertible_to; { t.scheduler } -> std::convertible_to; }; @@ -52,10 +52,8 @@ concept GridwiseWmmaGemmDescriptor = requires(T t) { { t.n_per_wmma } -> SizeType; { t.m_wmma_per_wave } -> SizeType; { t.n_wmma_per_wave } -> SizeType; - { t.pipeline_version } -> std::convertible_to; }; - // Concept for vectorized data transfer for convolution input tensors. template concept BlockTransferDescriptor = requires(T t) { @@ -253,8 +251,13 @@ concept SpecifiesSourceAccessOrder = requires(T t) { // Concept to check if struct specifies block GEMM. 
template concept SpecifiesBlockGemm = requires { - { T::block_gemm.pipeline_version } -> std::convertible_to; - { T::block_gemm.scheduler } -> std::convertible_to; + { T::block_gemm_pipeline } -> BlockGemmPipelineDescriptor; +}; + +template +concept SpecifiedGridwiseGemmPipeline = requires +{ + { T::pipeline_version } -> std::convertible_to; }; // Concept to check if struct specifies block GEMM (CK Tile). diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp index 1f7bf6565d0..df81b20f8ee 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp @@ -593,28 +593,28 @@ template consteval auto detailed_diagnostic_SpecifiesBlockGemm() -> std::string { std::string msg; - if constexpr (!requires { { T::block_gemm } -> BlockGemmDescriptor; }) { - return " → T::block_gemm: [✗] (missing or wrong type)\n"; + if constexpr (!requires { { T::block_gemm_pipeline } -> BlockGemmPipelineDescriptor; }) { + return " → T::block_gemm_pipeline: [✗] (missing or wrong type)\n"; } - msg += " → T::block_gemm member: [✓]\n"; + msg += " → T::block_gemm_pipeline member: [✓]\n"; - if constexpr (requires { T::block_gemm.pipeline_version; }) { - using PipelineType = decltype(T::block_gemm.pipeline_version); + if constexpr (requires { T::block_gemm_pipeline.pipeline_version; }) { + using PipelineType = decltype(T::block_gemm_pipeline.pipeline_version); constexpr bool convertible = std::convertible_to; - msg += " → block_gemm.pipeline_version: " + std::string(CHECK_MARK(convertible)) + + msg += " → block_gemm_pipeline.pipeline_version: " + std::string(CHECK_MARK(convertible)) + (convertible ? 
"" : std::string(detail::get_type_info())) + "\n"; } else { - msg += " → block_gemm.pipeline_version: [✗] (missing member)\n"; + msg += " → block_gemm_pipeline.pipeline_version: [✗] (missing member)\n"; } - if constexpr (requires { T::block_gemm.scheduler; }) { - using SchedulerType = decltype(T::block_gemm.scheduler); + if constexpr (requires { T::block_gemm_pipeline.scheduler; }) { + using SchedulerType = decltype(T::block_gemm_pipeline.scheduler); constexpr bool convertible = std::convertible_to; - msg += " → block_gemm.scheduler: " + std::string(CHECK_MARK(convertible)) + + msg += " → block_gemm_pipeline.scheduler: " + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(detail::get_type_info())) + "\n"; } else { - msg += " → block_gemm.scheduler: [✗] (missing member)\n"; + msg += " → block_gemm_pipeline.scheduler: [✗] (missing member)\n"; } return msg; @@ -872,14 +872,12 @@ consteval auto detailed_diagnostic_SpecifiesGridwiseWmmaGemm() -> std::string { constexpr bool has_n_per_wmma = requires(GG t) { { t.n_per_wmma } -> std::convertible_to; }; constexpr bool has_m_wmma_per_wave = requires(GG t) { { t.m_wmma_per_wave } -> std::convertible_to; }; constexpr bool has_n_wmma_per_wave = requires(GG t) { { t.n_wmma_per_wave } -> std::convertible_to; }; - constexpr bool has_pipeline = requires(GG t) { { t.pipeline_version } -> std::convertible_to; }; msg += " → gridwise_gemm.k1: " + std::string(CHECK_MARK(has_k1)) + (has_k1 ? "\n" : " (missing or wrong type)\n"); msg += " → gridwise_gemm.m_per_wmma: " + std::string(CHECK_MARK(has_m_per_wmma)) + (has_m_per_wmma ? "\n" : " (missing or wrong type)\n"); msg += " → gridwise_gemm.n_per_wmma: " + std::string(CHECK_MARK(has_n_per_wmma)) + (has_n_per_wmma ? "\n" : " (missing or wrong type)\n"); msg += " → gridwise_gemm.m_wmma_per_wave: " + std::string(CHECK_MARK(has_m_wmma_per_wave)) + (has_m_wmma_per_wave ? 
"\n" : " (missing or wrong type)\n"); msg += " → gridwise_gemm.n_wmma_per_wave: " + std::string(CHECK_MARK(has_n_wmma_per_wave)) + (has_n_wmma_per_wave ? "\n" : " (missing or wrong type)\n"); - msg += " → gridwise_gemm.pipeline_version: " + std::string(CHECK_MARK(has_pipeline)) + (has_pipeline ? "\n" : " (missing or wrong type)\n"); return msg; } @@ -1194,4 +1192,16 @@ consteval auto detailed_diagnostic_SpecifiesLdsTransfer() -> std::string { return msg; } +template +consteval auto detailed_diagnostic_SpecifiedGridwiseGemmPipeline() -> std::string { + if constexpr (requires { T::pipeline_version; }) { + using PipelineType = decltype(T::pipeline_version); + constexpr bool convertible = std::convertible_to; + return " → T::pipeline_version: " + std::string(CHECK_MARK(convertible)) + + (convertible ? "" : std::string(detail::get_type_info())) + "\n"; + } else { + return " → T::pipeline_version: [✗] (missing member)\n"; + } +} + } // namespace ck_tile::builder::diagnostics diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index 15c383fb7a8..588e5eb6980 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -136,6 +136,7 @@ struct FwdWmmaAlgorithm { CHECK_CONCEPT(T, SpecifiesGemmSpecialization) CHECK_CONCEPT(T, SpecifiesNumPrefetchStages) CHECK_CONCEPT(T, SpecifiesLoopScheduler) + CHECK_CONCEPT(T, SpecifiedGridwiseGemmPipeline) static constexpr bool c1 = c_ConvAlgorithmDescriptor; static constexpr bool c2 = c_SpecifiesThreadBlock; @@ -148,9 +149,10 @@ struct FwdWmmaAlgorithm { static constexpr bool c9 = c_SpecifiesGemmSpecialization; static constexpr bool c10 = c_SpecifiesNumPrefetchStages; static constexpr bool c11 = c_SpecifiesLoopScheduler; + static constexpr bool c12 = c_SpecifiedGridwiseGemmPipeline; static consteval bool is_valid() { - return 
c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10 && c11; + return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10 && c11 && c12; } static consteval auto message() -> std::string { @@ -166,7 +168,8 @@ struct FwdWmmaAlgorithm { DIAGNOSTIC_LINE(SpecifiesFwdConvSpecialization) + DIAGNOSTIC_LINE(SpecifiesGemmSpecialization) + DIAGNOSTIC_LINE(SpecifiesNumPrefetchStages) + - DIAGNOSTIC_LINE(SpecifiesLoopScheduler); + DIAGNOSTIC_LINE(SpecifiesLoopScheduler) + + DIAGNOSTIC_LINE(SpecifiedGridwiseGemmPipeline); } }; @@ -241,19 +244,19 @@ struct TileAlgorithm { }; template -struct LargeTensorAlgorithm : public FwdXdlAlgorithm +struct LargeTensorAlgorithm : public FwdXdlAlgorithm { - using BaseAlgorithmType = decltype(T::base_algorithm); CHECK_CONCEPT(T, SpecifiesLargeTensorSupport) static constexpr bool c13 = c_SpecifiesLargeTensorSupport; static consteval bool is_valid() { - return FwdXdlAlgorithm::is_valid() && c13; + // Note: Check first if the specialization is set. + return c13 && FwdXdlAlgorithm::is_valid(); } static consteval auto message() -> std::string { - return FwdXdlAlgorithm::message() + + return FwdXdlAlgorithm::message() + DIAGNOSTIC_LINE(SpecifiesLargeTensorSupport); } }; @@ -437,6 +440,53 @@ struct BwdTwoStageXdlAlgorithm { } }; +template +struct BwdWmmaV3Algorithm { + CHECK_CONCEPT(T, ConvAlgorithmDescriptor) + CHECK_CONCEPT(T, SpecifiesThreadBlock) + CHECK_CONCEPT(T, SpecifiesBlockTransfer) + CHECK_CONCEPT(T, SpecifiesLdsTransfer) + CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder) + CHECK_CONCEPT(T, SpecifiesSourceAccessOrder) + CHECK_CONCEPT(T, SpecifiesGridwiseWmmaGemm) + CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization) + CHECK_CONCEPT(T, SpecifiesBlockGemm) + CHECK_CONCEPT(T, TransposeTransferWellDefinedIfProvided) + CHECK_CONCEPT(T, SpecifiesGenericInstance) + + static constexpr bool c1 = c_ConvAlgorithmDescriptor; + static constexpr bool c2 = c_SpecifiesThreadBlock; + static constexpr bool c3 = c_SpecifiesBlockTransfer; + 
static constexpr bool c4 = c_SpecifiesLdsTransfer; + static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder; + static constexpr bool c6 = c_SpecifiesSourceAccessOrder; + static constexpr bool c7 = c_SpecifiesGridwiseWmmaGemm; + static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; + static constexpr bool c9 = c_SpecifiesBlockGemm; + static constexpr bool c10 = c_TransposeTransferWellDefinedIfProvided; + static constexpr bool c11 = c_SpecifiesGenericInstance; + + static consteval bool is_valid() { + return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10 && c11; + } + + static consteval auto message() -> std::string { + return std::string("\n=== Backward WMMA V3 Algorithm Diagnostic (closest match) ===\n" + "Concepts for BwdWmmaV3 Algorithm:\n") + + DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + + DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + + DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + + DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + + DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + + DIAGNOSTIC_LINE(SpecifiesGridwiseWmmaGemm) + + DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) + + DIAGNOSTIC_LINE(SpecifiesBlockGemm) + + DIAGNOSTIC_LINE(TransposeTransferWellDefinedIfProvided) + + DIAGNOSTIC_LINE(SpecifiesGenericInstance); + } +}; + template struct BwdDlAlgorithm { CHECK_CONCEPT(T, ConvAlgorithmDescriptor) @@ -556,29 +606,52 @@ consteval void diagnose_fwd_algorithm_signature() constexpr int max_4 = max_3 > large_tensor_matches ? max_3 : large_tensor_matches; constexpr int max_matches = max_4 > tile_matches ?
max_4 : tile_matches; - // Generate detailed diagnostic for the closest match - if constexpr(max_matches == xdl_v3_matches) { - using Alg = FwdXdlV3Algorithm; - static_assert(Alg::is_valid(), Alg::message()); - } else if constexpr(max_matches == xdl_matches) { - using Alg = FwdXdlAlgorithm; - static_assert(Alg::is_valid(), Alg::message()); - } else if constexpr(max_matches == wmma_matches) { + // Check whether we have XDL or WMMA algorithm + if constexpr (SpecifiesGridwiseFwdXdlGemm) + { + if constexpr(max_matches == xdl_v3_matches) { + using Alg = FwdXdlV3Algorithm; + static_assert(Alg::is_valid(), Alg::message()); + } else if constexpr(max_matches == xdl_matches) { + using Alg = FwdXdlAlgorithm; + static_assert(Alg::is_valid(), Alg::message()); + } else if constexpr (max_matches == large_tensor_matches) { + using Alg = LargeTensorAlgorithm; + static_assert(Alg::is_valid(), Alg::message()); + } + } + else if constexpr (SpecifiesGridwiseWmmaGemm) + { using Alg = FwdWmmaAlgorithm; static_assert(Alg::is_valid(), Alg::message()); - } else if constexpr(max_matches == dl_matches) { - using Alg = FwdDlAlgorithm; - static_assert(Alg::is_valid(), Alg::message()); - } else if constexpr (max_matches == large_tensor_matches) { - using Alg = LargeTensorAlgorithm; - static_assert(Alg::is_valid(), Alg::message()); - } else if constexpr (max_matches == tile_matches) { - using Alg = TileAlgorithm; - static_assert(Alg::is_valid(), Alg::message()); } - else { - // This should never happen - static_assert(false, "Internal Error: No matching algorithm variant found for diagnostics."); + else + { + // If we cannot match either WMMA or XDL, try all algorithms for diagnostics + // and see which is the closest match.
+ if constexpr(max_matches == xdl_v3_matches) { + using Alg = FwdXdlV3Algorithm; + static_assert(Alg::is_valid(), Alg::message()); + } else if constexpr(max_matches == xdl_matches) { + using Alg = FwdXdlAlgorithm; + static_assert(Alg::is_valid(), Alg::message()); + } else if constexpr(max_matches == wmma_matches) { + using Alg = FwdWmmaAlgorithm; + static_assert(Alg::is_valid(), Alg::message()); + } else if constexpr(max_matches == dl_matches) { + using Alg = FwdDlAlgorithm; + static_assert(Alg::is_valid(), Alg::message()); + } else if constexpr (max_matches == large_tensor_matches) { + using Alg = LargeTensorAlgorithm; + static_assert(Alg::is_valid(), Alg::message()); + } else if constexpr (max_matches == tile_matches) { + using Alg = TileAlgorithm; + static_assert(Alg::is_valid(), Alg::message()); + } + else { + // This should never happen + static_assert(false, "Internal Error: No matching algorithm variant found for diagnostics."); + } } } diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 6ac9b7f83c9..e5789ec05f2 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -95,13 +95,6 @@ namespace ck_tile::builder::factory { // // TODO: Make this dispatch logic much more robust and clear for users. 
-// Reference algorithm (simplest implementation for validation) -template -concept IsReferenceAlgorithm = ConvAlgorithmDescriptor && requires { - { T::specialization } -> std::convertible_to; - requires T::specialization == ConvAlgorithmSpecialization::REFERENCE; -}; - template @@ -110,7 +103,7 @@ constexpr auto make_conv_instance() using AlgoType = std::remove_const_t; // Reference algorithm supports all directions - if constexpr(IsReferenceAlgorithm) + if constexpr(ReferenceAlgorithm::is_valid()) { return typename ReferenceFactory::Instance{}; } @@ -178,6 +171,11 @@ constexpr auto make_conv_instance() { return typename ConvBwdWeightMultiDXdlFactory::Instance{}; } + else if constexpr (BwdWmmaV3Algorithm::is_valid()) + { + static_assert(false, + "Backward weight convolution: WMMA V3 algorithm not yet implemented."); + } else { diagnose_bwd_weight_algorithm_signature(); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp index 456c567aa06..77a930d1ce1 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp @@ -31,24 +31,22 @@ struct ConvFwdLargeTensorFactory using Ops = internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); - static constexpr auto BASE_ALGORITHM = ALGORITHM.base_algorithm; - static constexpr auto FWD_CONV_SPECIALIZATION = - internal::SetFwdConvSpecialization(); - static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization(); + internal::SetFwdConvSpecialization(); + static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization(); static constexpr internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION, .gemm_spec = GEMM_SPECIALIZATION}; - static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); - static 
constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = BASE_ALGORITHM.gridwise_gemm; + static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; static constexpr auto A_BLOCK_TRANSFER = - internal::SetFwdConvBlockTransfer(); + internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = - internal::SetFwdConvBlockTransfer(); + internal::SetFwdConvBlockTransfer(); static constexpr auto C_BLOCK_TRANSFER = - internal::SetCBlockTransfer(); + internal::SetCBlockTransfer(); // Check limits for the algorithm parameters. static_assert(InputVectorTransferLimits); @@ -78,7 +76,7 @@ struct ConvFwdLargeTensorFactory typename Ops::CDEElementwiseOp, SPECIALIZATION.conv_spec, SPECIALIZATION.gemm_spec, - BASE_ALGORITHM.num_gemm_k_prefetch_stages, + ALGORITHM.num_gemm_k_prefetch_stages, BLOCK.block_size, BLOCK.per_block.m, BLOCK.per_block.n, diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp index f3cf3bcc3cc..dde8e19f434 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp @@ -38,7 +38,7 @@ struct BlockGemmSpec template consteval BlockGemmSpec SetBlockGemm() { - constexpr auto& BG = ALGORITHM.block_gemm; + constexpr auto& BG = ALGORITHM.block_gemm_pipeline; ck::BlockGemmPipelineScheduler scheduler; ck::BlockGemmPipelineVersion version; @@ -83,7 +83,7 @@ consteval ck::LoopScheduler SetLoopScheduler() template consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion() { - constexpr auto pipeline_version = ALGORITHM.gridwise_gemm.pipeline_version; + 
constexpr auto pipeline_version = ALGORITHM.pipeline_version; using ck_pipeline = ck::PipelineVersion; switch(pipeline_version) { diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp index 22911a1a26c..fc9cf44b7a9 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp @@ -30,10 +30,11 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle{} .with_thread_block(ThreadBlock_128_64x64x64) - .with_gemm_config(FwdGemmParams_Wmma_2x1_per_wave) + .with_gemm_config(GemmParams_Wmma_2x1_per_wave) .with_transfer(Transfer_4x32x1) .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, 0, PipelineScheduler::DEFAULT); + .with_prefetch_config(1, 0, PipelineScheduler::DEFAULT) + .with_gridwise_gemm_pipeline(PipelineVersion::V1); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp index 0d9563e05aa..271788799f3 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp @@ -32,10 +32,10 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} - .with_thread_block(FwdThreadBlock_64_64x32x32) + .with_thread_block(ThreadBlock_64_64x32x32) .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) - .with_transfer(FwdTransfer_4x16x1) - .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_transfer(Transfer_4x16x1) + .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_prefetch_config(1, 1, 
PipelineScheduler::DEFAULT); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp index e7ad9060147..e9e1c5f868c 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp @@ -25,14 +25,13 @@ TEST(FwdConvInstances, .output = {.config = {.layout = GNHWK}}}; constexpr auto FwdConvAlgorithm = - ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{ - .base_algorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} + ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{} .with_thread_block(ThreadBlock_256_256x128x32) .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) .with_transfer(Transfer_4x16x1) .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)}; + .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT); using Builder = ConvBuilder; @@ -62,14 +61,13 @@ TEST( .output = {.config = {.layout = GNHWK}}}; constexpr auto FwdConvAlgorithm = - ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{ - .base_algorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} + ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{} .with_thread_block(ThreadBlock_128_128x128x32) .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) .with_transfer(Transfer_4x16x1) .with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)}; + .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT); using Builder = ConvBuilder; diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 
4140d42ba5d..b562df2608c 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -62,16 +62,15 @@ struct GridwiseWmmaGemm size_t n_per_wmma = 0; size_t m_wmma_per_wave = 0; size_t n_wmma_per_wave = 0; - PipelineVersion pipeline_version; }; static_assert(ckb::GridwiseWmmaGemmDescriptor); -struct BlockGemm +struct BlockGemmPipeline { PipelineVersion pipeline_version; PipelineScheduler scheduler; }; -static_assert(ckb::BlockGemmDescriptor); +static_assert(ckb::BlockGemmPipelineDescriptor); // Describe Aand B block transfer thread cluster lengths. template @@ -252,7 +251,12 @@ struct GemmBatchOptions_ struct BlockGemm_ { - BlockGemm block_gemm; + BlockGemmPipeline block_gemm_pipeline; +}; + +struct GridGemm_ +{ + PipelineVersion pipeline_version; }; struct DlThreadConfig_ @@ -289,13 +293,9 @@ struct MultipleDSpecialization_ static constexpr ConvAlgorithmSpecialization specialization = ConvAlgorithmSpecialization::MULTIPLE_D; }; -// Specialization wrapper for large tensor support -template -struct LargeTensorWrapper +struct LargeTensorSpecialization_ { - BaseAlgorithm base_algorithm; - static constexpr ConvAlgorithmSpecialization specialization = - ConvAlgorithmSpecialization::LARGE_TENSOR; + static constexpr ConvAlgorithmSpecialization specialization = ConvAlgorithmSpecialization::LARGE_TENSOR; }; // Specify thread block dimensions for a GEMM (CK Tile). @@ -465,7 +465,15 @@ struct ConvAlgorithmTemplate : Components... 
{ static_assert(std::is_base_of_v); auto result = *this; - result.block_gemm = bg; + result.block_gemm_pipeline = bg; + return result; + } + + constexpr auto with_gridwise_gemm_pipeline(const PipelineVersion plv) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.pipeline_version = plv; return result; } @@ -552,7 +560,7 @@ using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = ConvAlgorithmTemplate, ConvSpecializationFwd_, BlockGemm_>; using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle = - ConvAlgorithmTemplate, ConvSpecializationFwd_, Prefetch_>; + ConvAlgorithmTemplate, ConvSpecializationFwd_, GridGemm_, Prefetch_>; using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = ConvAlgorithmTemplate>; using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = - LargeTensorWrapper; + ConvAlgorithmTemplate, ConvSpecializationFwd_, Prefetch_, LargeTensorSpecialization_>; using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, MultipleDSpecialization_>; +using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 = + ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, BlockGemm_, TransposeParams_>; } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index 1466155c703..299aa56fac3 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -275,12 +275,11 @@ constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_2x1_per_wave{ .ak1 = 8, .bk1 = 8, .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1}}; -constexpr GridwiseWmmaGemm FwdGemmParams_Wmma_2x1_per_wave{.k1 = 8, - .m_per_wmma = 32, - .n_per_wmma = 32, - .m_wmma_per_wave = 2, - .n_wmma_per_wave = 1, - .pipeline_version = PipelineVersion::V1}; +constexpr 
GridwiseWmmaGemm GemmParams_Wmma_2x1_per_wave{.k1 = 8, + .m_per_wmma = 32, + .n_per_wmma = 32, + .m_wmma_per_wave = 2, + .n_wmma_per_wave = 1}; constexpr ThreadBlock ThreadBlock_256_256x256x32{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; @@ -309,19 +308,19 @@ constexpr ThreadBlock ThreadBlock_128_128x128x32{.block_size = 128, constexpr ThreadBlock ThreadBlock_128_64x64x64{.block_size = 128, .tile_size = {.m = 64, .n = 64, .k = 64}}; -constexpr BlockGemm BlockGemmDesc_v1_intrawave = {.pipeline_version = PipelineVersion::V1, +constexpr BlockGemmPipeline BlockGemmDesc_v1_intrawave = {.pipeline_version = PipelineVersion::V1, .scheduler = PipelineScheduler::INTRAWAVE}; -constexpr BlockGemm BlockGemmDesc_v2_intrawave = {.pipeline_version = PipelineVersion::V2, +constexpr BlockGemmPipeline BlockGemmDesc_v2_intrawave = {.pipeline_version = PipelineVersion::V2, .scheduler = PipelineScheduler::INTRAWAVE}; -constexpr BlockGemm BlockGemmDesc_v3_intrawave = {.pipeline_version = PipelineVersion::V3, +constexpr BlockGemmPipeline BlockGemmDesc_v3_intrawave = {.pipeline_version = PipelineVersion::V3, .scheduler = PipelineScheduler::INTRAWAVE}; -constexpr BlockGemm BlockGemmDesc_v4_intrawave = {.pipeline_version = PipelineVersion::V4, +constexpr BlockGemmPipeline BlockGemmDesc_v4_intrawave = {.pipeline_version = PipelineVersion::V4, .scheduler = PipelineScheduler::INTRAWAVE}; -constexpr BlockGemm BlockGemmDesc_v5_intrawave = {.pipeline_version = PipelineVersion::V5, +constexpr BlockGemmPipeline BlockGemmDesc_v5_intrawave = {.pipeline_version = PipelineVersion::V5, .scheduler = PipelineScheduler::INTRAWAVE}; } // namespace ck_tile::builder::test_utils diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index e48faefc9f8..ba3173cc9a1 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ 
-113,7 +113,7 @@ inline std::string to_string(GridwiseWmmaGemm t) } template <> -inline std::string to_string(BlockGemm t) +inline std::string to_string(BlockGemmPipeline t) { std::ostringstream oss; oss << to_string(t.scheduler) << "," << to_string(t.pipeline_version); @@ -313,7 +313,7 @@ inline std::string to_string(Prefetch_ t) template <> inline std::string to_string(BlockGemm_ t) { - return to_string(t.block_gemm); + return to_string(t.block_gemm_pipeline); } template <> @@ -388,7 +388,10 @@ template <> inline std::string to_string( ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor t) { - return to_string(t.base_algorithm); + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + return oss.str(); } template <> From bc3cba873bc523c24f64328dbe5a0d30798cf539 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Fri, 2 Jan 2026 06:08:29 -0500 Subject: [PATCH 40/81] Add factory and tests for DeviceGroupedConvBwdWeight_Wmma_CShuffleV3. 
--- .../builder/factory/conv_algorithms.hpp | 122 +++++++++++++----- .../factory/conv_bwd_weight_dl_factory.hpp | 3 +- .../conv_bwd_weight_multi_d_xdl_factory.hpp | 4 +- .../conv_bwd_weight_two_stage_xdl_factory.hpp | 4 +- .../conv_bwd_weight_wmma_v3_factory.hpp | 104 +++++++++++++++ .../factory/conv_bwd_weight_xdl_factory.hpp | 4 +- .../conv_bwd_weight_xdl_v3_factory.hpp | 4 +- .../builder/factory/conv_dispatcher.hpp | 4 +- experimental/builder/test/CMakeLists.txt | 1 + ...t_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp | 44 +++++++ .../test/utils/conv_algorithm_type_utils.hpp | 19 +++ 11 files changed, 268 insertions(+), 45 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp create mode 100644 experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index 588e5eb6980..1347b132e8c 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -570,6 +570,12 @@ consteval int count_matches_bwd_two_stage_xdl() { return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11 + Alg::c12; } +template +consteval int count_matches_bwd_wmma_v3() { + using Alg = BwdWmmaV3Algorithm; + return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11; +} + template consteval int count_matches_bwd_dl() { using Alg = BwdDlAlgorithm; @@ -599,26 +605,26 @@ consteval void diagnose_fwd_algorithm_signature() constexpr int large_tensor_matches = count_matches_large_tensor(); constexpr int tile_matches = count_matches_tile(); - // Find maximum matches across all variants - constexpr int max_1 = xdl_v3_matches > xdl_matches ? 
xdl_v3_matches : xdl_matches; - constexpr int max_2 = wmma_matches > dl_matches ? wmma_matches : dl_matches; - constexpr int max_3 = max_1 > max_2 ? max_1 : max_2; - constexpr int max_4 = max_3 > large_tensor_matches ? max_3 : large_tensor_matches; - constexpr int max_matches = max_4 > tile_matches ? max_4 : tile_matches; - // Check whether we have XDL or WMMA algorithm if constexpr (SpecifiesGridwiseFwdXdlGemm) { + constexpr int max_1 = xdl_v3_matches > xdl_matches ? xdl_v3_matches : xdl_matches; + constexpr int max_2 = max_1 > dl_matches ? max_1 : dl_matches; + constexpr int max_matches = large_tensor_matches > max_2 ? large_tensor_matches : max_2; + if constexpr(max_matches == xdl_v3_matches) { using Alg = FwdXdlV3Algorithm; static_assert(Alg::is_valid(), Alg::message()); } else if constexpr(max_matches == xdl_matches) { using Alg = FwdXdlAlgorithm; static_assert(Alg::is_valid(), Alg::message()); + } else if constexpr(max_matches == dl_matches) { + using Alg = FwdDlAlgorithm; + static_assert(Alg::is_valid(), Alg::message()); } else if constexpr (max_matches == large_tensor_matches) { using Alg = LargeTensorAlgorithm; static_assert(Alg::is_valid(), Alg::message()); - } + } } else if constexpr (SpecifiesGridwiseWmmaGemm) { @@ -627,6 +633,13 @@ consteval void diagnose_fwd_algorithm_signature() } else { + // Find maximum matches across all variants + constexpr int max_1 = xdl_v3_matches > xdl_matches ? xdl_v3_matches : xdl_matches; + constexpr int max_2 = wmma_matches > dl_matches ? wmma_matches : dl_matches; + constexpr int max_3 = max_1 > max_2 ? max_1 : max_2; + constexpr int max_4 = max_3 > large_tensor_matches ? max_3 : large_tensor_matches; + constexpr int max_matches = max_4 > tile_matches ? max_4 : tile_matches; + + // If we cannot match with either WMMA or XDL, try all algorithms for diagnostics // and see which is the closest match.
if constexpr(max_matches == xdl_v3_matches) { @@ -663,35 +676,78 @@ consteval void diagnose_bwd_weight_algorithm_signature() constexpr int two_stage_xdl_matches = count_matches_bwd_two_stage_xdl(); constexpr int dl_matches = count_matches_bwd_dl(); constexpr int multi_d_xdl_matches = count_matches_bwd_multi_d_xdl(); + constexpr int wmma_v3_matches = count_matches_bwd_wmma_v3(); - constexpr int max1 = xdl_v3_matches > xdl_matches ? xdl_v3_matches : xdl_matches; - constexpr int max2 = max1 > two_stage_xdl_matches ? max1 : two_stage_xdl_matches; - constexpr int max3 = max2 > dl_matches ? max2 : dl_matches; - constexpr int max_matches = max3 > multi_d_xdl_matches ? max3 : multi_d_xdl_matches; + // Check whether we have XDL or WMMA algorithm + if constexpr (SpecifiesGridwiseBwdXdlGemm) + { + constexpr int max1 = xdl_v3_matches > xdl_matches ? xdl_v3_matches : xdl_matches; + constexpr int max2 = max1 > two_stage_xdl_matches ? max1 : two_stage_xdl_matches; + constexpr int max3 = max2 > dl_matches ? max2 : dl_matches; + constexpr int max_matches = max3 > multi_d_xdl_matches ? 
max3 : multi_d_xdl_matches; - if constexpr (max_matches == xdl_matches) { - using Alg = BwdXdlAlgorithm; - static_assert(Alg::is_valid(), Alg::message()); - } - else if constexpr (max_matches == xdl_v3_matches) { - using Alg = BwdXdlV3Algorithm; - static_assert(Alg::is_valid(), Alg::message()); - } - else if constexpr (max_matches == two_stage_xdl_matches) { - using Alg = BwdTwoStageXdlAlgorithm; - static_assert(Alg::is_valid(), Alg::message()); - } - else if constexpr (max_matches == dl_matches) { - using Alg = BwdDlAlgorithm; - static_assert(Alg::is_valid(), Alg::message()); + if constexpr (max_matches == xdl_matches) { + using Alg = BwdXdlAlgorithm; + static_assert(Alg::is_valid(), Alg::message()); + } + else if constexpr (max_matches == xdl_v3_matches) { + using Alg = BwdXdlV3Algorithm; + static_assert(Alg::is_valid(), Alg::message()); + } + else if constexpr (max_matches == two_stage_xdl_matches) { + using Alg = BwdTwoStageXdlAlgorithm; + static_assert(Alg::is_valid(), Alg::message()); + } + else if constexpr (max_matches == dl_matches) { + using Alg = BwdDlAlgorithm; + static_assert(Alg::is_valid(), Alg::message()); + } + else if constexpr (max_matches == multi_d_xdl_matches) { + using Alg = BwdMultiDXdlAlgorithm; + static_assert(Alg::is_valid(), Alg::message()); + } } - else if constexpr (max_matches == multi_d_xdl_matches) { - using Alg = BwdMultiDXdlAlgorithm; - static_assert(Alg::is_valid(), Alg::message()); + else if constexpr (SpecifiesGridwiseWmmaGemm) + { + constexpr int max_matches = wmma_v3_matches; + if constexpr (max_matches == wmma_v3_matches) { + using Alg = BwdWmmaV3Algorithm; + static_assert(Alg::is_valid(), Alg::message()); + } } - else { - // This should never happen - static_assert(false, "Internal Error: No matching algorithm variant found for diagnostics."); + else + { + // If we cannot match with either WMMA or XDL, try all algorithms for diagnostics + // and see which is the closest match.
+ constexpr int max1 = xdl_v3_matches > xdl_matches ? xdl_v3_matches : xdl_matches; + constexpr int max2 = max1 > two_stage_xdl_matches ? max1 : two_stage_xdl_matches; + constexpr int max3 = max2 > dl_matches ? max2 : dl_matches; + constexpr int max_matches = max3 > multi_d_xdl_matches ? max3 : multi_d_xdl_matches; + + if constexpr (max_matches == xdl_matches) { + using Alg = BwdXdlAlgorithm; + static_assert(Alg::is_valid(), Alg::message()); + } + else if constexpr (max_matches == xdl_v3_matches) { + using Alg = BwdXdlV3Algorithm; + static_assert(Alg::is_valid(), Alg::message()); + } + else if constexpr (max_matches == two_stage_xdl_matches) { + using Alg = BwdTwoStageXdlAlgorithm; + static_assert(Alg::is_valid(), Alg::message()); + } + else if constexpr (max_matches == dl_matches) { + using Alg = BwdDlAlgorithm; + static_assert(Alg::is_valid(), Alg::message()); + } + else if constexpr (max_matches == multi_d_xdl_matches) { + using Alg = BwdMultiDXdlAlgorithm; + static_assert(Alg::is_valid(), Alg::message()); + } + else { + // This should never happen + static_assert(false, "Internal Error: No matching algorithm variant found for diagnostics."); + } } } diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp index 0ddeae2eeda..203280c0630 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp @@ -10,13 +10,12 @@ #include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" #include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" #include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" -//#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" #include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" namespace ck_tile::builder::factory { // Factory for 
DeviceGroupedConvBwdWeight_Dl instance -// of a grouped forward convolution kernel. +// of a grouped bwd weight convolution kernel. template diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp index f4787ddc4ab..2c11f3b4366 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp @@ -17,8 +17,8 @@ namespace ck_tile::builder::factory { -// Factory for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle instance -// of a grouped forward convolution kernel. +// Factory for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle instance +// of a grouped bwd weight convolution kernel. template diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp index b9852127e86..fae8ad7b873 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp @@ -17,8 +17,8 @@ namespace ck_tile::builder::factory { -// Factory for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle instance -// of a grouped forward convolution kernel. +// Factory for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle instance +// of a grouped bwd weight convolution kernel. 
template diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp new file mode 100644 index 00000000000..ab75f5b072c --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp @@ -0,0 +1,104 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 instance +// of a grouped bwd weight convolution kernel. 
+template + requires ConvDirectionIsBackwardWeight +struct ConvBwdWeightWmmaV3Factory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::BwdWeightConvTensorDataTypes; + using Ops = internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto A_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + static constexpr auto BLOCK_GEMM = internal::SetBlockGemm(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. + static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); + static_assert(InputVectorTransferLimits, "Invalid B block transfer config"); + static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); + static_assert(AccessOrderLimits3D, "Invalid A thread cluster access order"); + static_assert(AccessOrderLimits3D, "Invalid B thread cluster access order"); + static_assert(AccessOrderLimits3D, "Invalid A source access order"); + static_assert(AccessOrderLimits3D, "Invalid B source access order"); + + // The backward weight convolution kernel class instance.
+ using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + GRIDWISE_GEMM.m_per_wmma, + GRIDWISE_GEMM.n_per_wmma, + GRIDWISE_GEMM.m_wmma_per_wave, + GRIDWISE_GEMM.n_wmma_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + BLOCK_GEMM.scheduler, + BLOCK_GEMM.pipeline_version, + typename Types::InComputeType, + typename Types::WeiComputeType, + ALGORITHM.max_transpose_transfer_src_scalar_per_vector, + ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp index 8790121ed93..8ffbb495ec2 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp @@ -17,8 +17,8 @@ namespace ck_tile::builder::factory { -// Factory 
for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle instance -// of a grouped forward convolution kernel. +// Factory for DeviceGroupedConvBwdWeight_Xdl_CShuffle instance +// of a grouped bwd weight convolution kernel. template diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp index 14121be940a..20aac55f311 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp @@ -17,8 +17,8 @@ namespace ck_tile::builder::factory { -// Factory for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle instance -// of a grouped forward convolution kernel. +// Factory for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 instance +// of a grouped bwd weight convolution kernel. template diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index e5789ec05f2..6aecd779db3 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -73,6 +73,7 @@ #include "ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp" #include "ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp" #include "ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp" namespace ck_tile::builder::factory { @@ -173,8 +174,7 @@ constexpr auto make_conv_instance() } else if constexpr (BwdWmmaV3Algorithm::is_valid()) { - static_assert(false, - "Backward weight convolution: WMMA V3 algorithm not yet implemented."); + return typename ConvBwdWeightWmmaV3Factory::Instance{}; } else { diff --git a/experimental/builder/test/CMakeLists.txt 
b/experimental/builder/test/CMakeLists.txt index 90c6d20c83b..f72a72f5fba 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -170,6 +170,7 @@ add_ck_builder_test(test_ckb_build_bwd_weight_instances conv/ck/test_ckb_conv_bwd_weight_multi_d_xdl_cshuffle.cpp conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp conv/ck/test_ckb_conv_bwd_weight_dl.cpp + conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp ) target_link_libraries(test_ckb_build_bwd_weight_instances PRIVATE utility) diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp new file mode 100644 index 00000000000..268bdee5bec --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; +using enum ck_tile::builder::TensorLayout; + +constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = 1, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::BF16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = NGCW}}, + .weight = {.config = {.layout = GKXC}}, + .output = {.config = {.layout = NGKW}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3{} + .with_thread_block(cku::ThreadBlock_64_32x32x32) + .with_gemm_config(cku::GemmParams_Wmma_2x1_per_wave) + .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) + 
.with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0) + .with_block_gemm(cku::BlockGemmDesc_v2_intrawave) + .with_transpose_params(4,4); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdWeight_1DBf16_Wmma_CShuffle_V3, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + cku::run_test({"DeviceGroupedConvBwdWeight_Wmma_CShuffleV3", + expected_transfer_parameters, + "Filter1x1Stride1Pad0", + "NGCW,GKXC,NGKW", + "PassThrough,PassThrough,PassThrough", + "Intrawave", + "v2"}); +} diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index ba3173cc9a1..bfd16814daf 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -237,6 +237,15 @@ inline std::string to_string(DlEpilogue t) return oss.str(); } +template <> +inline std::string to_string(TransposeParams_ t) +{ + std::ostringstream oss; + oss << t.max_transpose_transfer_src_scalar_per_vector << "," + << t.max_transpose_transfer_dst_scalar_per_vector; + return oss.str(); +} + template <> inline std::string to_string>(DlTransfer<4> t) { @@ -414,6 +423,16 @@ inline std::string to_string +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + return oss.str(); +} + template <> inline std::string to_string( ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle t) From 2e43e16e4703c84a2d18ede983e82e1b5b925edc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Fri, 2 Jan 2026 07:06:51 -0500 Subject: [PATCH 41/81] Update Readme. 
--- experimental/builder/README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/experimental/builder/README.md b/experimental/builder/README.md index 940ee3e503b..7a93c395c07 100644 --- a/experimental/builder/README.md +++ b/experimental/builder/README.md @@ -45,6 +45,11 @@ cmake .. ``` +Note: When compiling e.g. for the `gfx942` architecture, the WMMA builders are not automatically included in the tests +since `gfx9` architectures do not support WMMA. Hence, to compile also the WMMA builders, add e.g. +`gfx1121` to the list of supported architectures or add flag `-D CK_USE_WMMA=ON`. One still needs +a Navi card to execute the Builder tests that use the GPU. + ## Building and Testing The builder test suite is organized into two main categories: From 89934275f48b3540df577e6187849fe935ff4b84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Fri, 2 Jan 2026 07:07:08 -0500 Subject: [PATCH 42/81] Fix WMMA bwd weight tests. --- experimental/builder/test/CMakeLists.txt | 11 +++++++++-- .../ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp | 6 +++--- .../builder/test/utils/ckb_conv_test_configs.hpp | 6 ++++++ 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index f72a72f5fba..3d6e448d84b 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -164,15 +164,22 @@ add_ck_builder_test(test_ckb_build_fwd_instances ) target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility) -add_ck_builder_test(test_ckb_build_bwd_weight_instances +set(BWD_WEIGHT_TESTS conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp conv/ck/test_ckb_conv_bwd_weight_multi_d_xdl_cshuffle.cpp conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp conv/ck/test_ckb_conv_bwd_weight_dl.cpp - conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp 
conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp +) + +if (CK_USE_WMMA) + list(APPEND BWD_WEIGHT_TESTS + conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp ) +endif() + +add_ck_builder_test(test_ckb_build_bwd_weight_instances ${BWD_WEIGHT_TESTS}) target_link_libraries(test_ckb_build_bwd_weight_instances PRIVATE utility) add_ck_builder_test(test_ckb_build_bwd_data_instances diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp index 268bdee5bec..4a1a60e852f 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp @@ -22,10 +22,10 @@ constexpr auto SIGNATURE = constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3{} .with_thread_block(cku::ThreadBlock_64_32x32x32) - .with_gemm_config(cku::GemmParams_Wmma_2x1_per_wave) + .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) .with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0) - .with_block_gemm(cku::BlockGemmDesc_v2_intrawave) + .with_block_gemm(cku::BlockGemmDesc_v1_intrawave) .with_transpose_params(4,4); using Builder = ckb::ConvBuilder; @@ -40,5 +40,5 @@ TEST(BwdWeight_1DBf16_Wmma_CShuffle_V3, Create) "NGCW,GKXC,NGKW", "PassThrough,PassThrough,PassThrough", "Intrawave", - "v2"}); + "v1"}); } diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index 299aa56fac3..e1f5a34e20c 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -281,6 +281,12 @@ constexpr GridwiseWmmaGemm GemmParams_Wmma_2x1_per_wave{.k1 = 8, .m_wmma_per_wave = 2, .n_wmma_per_wave = 1}; +constexpr GridwiseWmmaGemm 
GemmParams_Wmma_16x16_2x1_per_wave{.k1 = 8, + .m_per_wmma = 16, + .n_per_wmma = 16, + .m_wmma_per_wave = 2, + .n_wmma_per_wave = 1}; + constexpr ThreadBlock ThreadBlock_256_256x256x32{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; From aa10d659b707fa636e2b3a1b80b787e48e6823f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Fri, 2 Jan 2026 07:51:40 -0500 Subject: [PATCH 43/81] Added factory and tests for DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3. --- .../builder/factory/conv_algorithms.hpp | 95 +++++++++++++--- ...v_bwd_weight_two_stage_wmma_v3_factory.hpp | 105 ++++++++++++++++++ .../builder/factory/conv_dispatcher.hpp | 5 + experimental/builder/test/CMakeLists.txt | 1 + ..._bwd_weight_two_stage_wmma_cshuffle_v3.cpp | 47 ++++++++ ...t_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp | 4 +- .../test/impl/conv_algorithm_types.hpp | 4 + .../test/utils/conv_algorithm_type_utils.hpp | 10 ++ 8 files changed, 256 insertions(+), 15 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp create mode 100644 experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index 1347b132e8c..d970de795f2 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -74,7 +74,7 @@ struct FwdXdlV3Algorithm { }; template -struct FwdXdlAlgorithm { +struct FwdXdlAlgorithmBase { CHECK_CONCEPT(T, ConvAlgorithmDescriptor) CHECK_CONCEPT(T, SpecifiesThreadBlock) CHECK_CONCEPT(T, SpecifiesBlockTransfer) @@ -106,8 +106,7 @@ struct FwdXdlAlgorithm { } static consteval auto message() -> std::string { - return std::string("\n=== Forward XDL Algorithm Diagnostic (closest match) ===\n" - 
"Concepts for FwdXdl Algorithm:\n") + + return DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + @@ -123,6 +122,24 @@ struct FwdXdlAlgorithm { } }; +template +struct FwdXdlAlgorithm : public FwdXdlAlgorithmBase{ + CHECK_CONCEPT(T, SpecifiesGenericInstance) + + static constexpr bool c13 = c_SpecifiesGenericInstance; + + static consteval bool is_valid() { + return c13 && FwdXdlAlgorithmBase::is_valid(); + } + + static consteval auto message() -> std::string { + return std::string("\n=== Forward XDL Algorithm Diagnostic (closest match) ===\n" + "Concepts for FwdXdl Algorithm:\n") + + FwdXdlAlgorithmBase::message() + + DIAGNOSTIC_LINE(SpecifiesGenericInstance); + } +}; + template struct FwdWmmaAlgorithm { CHECK_CONCEPT(T, ConvAlgorithmDescriptor) @@ -244,7 +261,7 @@ struct TileAlgorithm { }; template -struct LargeTensorAlgorithm : public FwdXdlAlgorithm +struct LargeTensorAlgorithm : public FwdXdlAlgorithmBase { CHECK_CONCEPT(T, SpecifiesLargeTensorSupport) @@ -256,7 +273,9 @@ struct LargeTensorAlgorithm : public FwdXdlAlgorithm } static consteval auto message() -> std::string { - return FwdXdlAlgorithm::message() + + return std::string("\n=== Forward XDL Large Tensor Algorithm Diagnostic (closest match) ===\n" + "Concepts for FwdLargeTensorXdl Algorithm:\n") + + FwdXdlAlgorithmBase::message() + DIAGNOSTIC_LINE(SpecifiesLargeTensorSupport); } }; @@ -441,7 +460,7 @@ struct BwdTwoStageXdlAlgorithm { }; template -struct BwdWmmaV3Algorithm { +struct BwdWmmaV3AlgorithmBase { CHECK_CONCEPT(T, ConvAlgorithmDescriptor) CHECK_CONCEPT(T, SpecifiesThreadBlock) CHECK_CONCEPT(T, SpecifiesBlockTransfer) @@ -452,7 +471,6 @@ struct BwdWmmaV3Algorithm { CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization) CHECK_CONCEPT(T, SpecifiesBlockGemm) CHECK_CONCEPT(T, TransposeTransferWellDefinedIfProvided) - CHECK_CONCEPT(T, SpecifiesGenericInstance) static constexpr bool c1 = c_ConvAlgorithmDescriptor; static 
constexpr bool c2 = c_SpecifiesThreadBlock; @@ -464,15 +482,13 @@ struct BwdWmmaV3Algorithm { static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; static constexpr bool c9 = c_SpecifiesBlockGemm; static constexpr bool c10 = c_TransposeTransferWellDefinedIfProvided; - static constexpr bool c11 = c_SpecifiesGenericInstance; static consteval bool is_valid() { - return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10 & c11; + return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10; } static consteval auto message() -> std::string { - return std::string("\n=== Backward WMMA V3 Algorithm Diagnostic (closest match) ===\n" - "Concepts for BwdXdlV3 Algorithm:\n") + + return DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + @@ -482,11 +498,51 @@ struct BwdWmmaV3Algorithm { DIAGNOSTIC_LINE(SpecifiesGridwiseWmmaGemm) + DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) + DIAGNOSTIC_LINE(SpecifiesBlockGemm) + - DIAGNOSTIC_LINE(TransposeTransferWellDefinedIfProvided) + + DIAGNOSTIC_LINE(TransposeTransferWellDefinedIfProvided); + } +}; + +template +struct BwdWmmaV3Algorithm : public BwdWmmaV3AlgorithmBase +{ + CHECK_CONCEPT(T, SpecifiesGenericInstance) + + static constexpr bool c11 = c_SpecifiesGenericInstance; + + static consteval bool is_valid() { + return c11 && BwdWmmaV3AlgorithmBase::is_valid(); + } + + static consteval auto message() -> std::string { + return std::string("\n=== Backward WMMA V3 Algorithm Diagnostic (closest match) ===\n" + "Concepts for BwdWmmaV3 Algorithm:\n") + + BwdWmmaV3AlgorithmBase::message() + DIAGNOSTIC_LINE(SpecifiesGenericInstance); } }; +template +struct BwdTwoStageWmmaV3Algorithm : public BwdWmmaV3AlgorithmBase +{ + CHECK_CONCEPT(T, SpecifiesTwoStageSupport) + CHECK_CONCEPT(T, SpecifiesGemmBatchOptions) + + static constexpr bool c11 = c_SpecifiesTwoStageSupport; + static constexpr bool c12 = c_SpecifiesGemmBatchOptions; + + static 
consteval bool is_valid() { + return c11 && c12 && BwdWmmaV3AlgorithmBase::is_valid(); + } + + static consteval auto message() -> std::string { + return std::string("\n=== Backward Two Stage WMMA V3 Algorithm Diagnostic (closest match) ===\n" + "Concepts for BwdTwoStageWmmaV3 Algorithm:\n") + + BwdWmmaV3AlgorithmBase::message() + + DIAGNOSTIC_LINE(SpecifiesGemmBatchOptions) + + DIAGNOSTIC_LINE(SpecifiesTwoStageSupport); + } +}; + template struct BwdDlAlgorithm { CHECK_CONCEPT(T, ConvAlgorithmDescriptor) @@ -531,7 +587,7 @@ consteval int count_matches_fwd_xdl_v3() { template consteval int count_matches_fwd_xdl() { using Alg = FwdXdlAlgorithm; - return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11 + Alg::c12; + return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11 + Alg::c12 + Alg::c13; } template @@ -576,6 +632,12 @@ consteval int count_matches_bwd_wmma_v3() { return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11; } +template +consteval int count_matches_bwd_two_stage_wmma_v3() { + using Alg = BwdTwoStageWmmaV3Algorithm; + return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11 + Alg::c12; +} + template consteval int count_matches_bwd_dl() { using Alg = BwdDlAlgorithm; @@ -677,6 +739,7 @@ consteval void diagnose_bwd_weight_algorithm_signature() constexpr int dl_matches = count_matches_bwd_dl(); constexpr int multi_d_xdl_matches = count_matches_bwd_multi_d_xdl(); constexpr int wmma_v3_matches = count_matches_bwd_wmma_v3(); + constexpr int two_stage_wmma_v3_matches = count_matches_bwd_two_stage_wmma_v3(); // Check whether we have XDL or WMMA algorithm if constexpr (SpecifiesGridwiseBwdXdlGemm) @@ -709,11 +772,15 @@ consteval void diagnose_bwd_weight_algorithm_signature() } else if constexpr (SpecifiesGridwiseWmmaGemm) { - 
constexpr int max_matches = wmma_v3_matches; + constexpr int max_matches = wmma_v3_matches > two_stage_wmma_v3_matches ? wmma_v3_matches : two_stage_wmma_v3_matches; if constexpr (max_matches == wmma_v3_matches) { using Alg = BwdWmmaV3Algorithm; static_assert(Alg::is_valid(), Alg::message()); } + else if constexpr (max_matches == two_stage_wmma_v3_matches) { + using Alg = BwdTwoStageWmmaV3Algorithm; + static_assert(Alg::is_valid(), Alg::message()); + } } else { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp new file mode 100644 index 00000000000..635b8222937 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp @@ -0,0 +1,105 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffle_V3 instance +// of a grouped bwd weight convolution kernel. 
+template + requires ConvDirectionIsBackwardWeight +struct ConvBwdWeightTwoStageWmmaV3Factory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::BwdWeightConvTensorDataTypes; + using Ops = internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto A_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + static constexpr auto BLOCK_GEMM = internal::SetBlockGemm(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. + static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); + static_assert(InputVectorTransferLimits, "Invalid B block transfer config"); + static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); + static_assert(AccessOrderLimits3D, "Invalid A thread cluster access order"); + static_assert(AccessOrderLimits3D, "Invalid B thread cluster access order"); + static_assert(AccessOrderLimits3D, "Invalid A source access order"); + static_assert(AccessOrderLimits3D, "Invalid B source access order"); + + // The backward weight convolution kernel class instance.
+ using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + GRIDWISE_GEMM.m_per_wmma, + GRIDWISE_GEMM.n_per_wmma, + GRIDWISE_GEMM.m_wmma_per_wave, + GRIDWISE_GEMM.n_wmma_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + BLOCK_GEMM.scheduler, + BLOCK_GEMM.pipeline_version, + ALGORITHM.num_conv_groups_to_merge, + typename Types::InComputeType, + typename Types::WeiComputeType, + ALGORITHM.max_transpose_transfer_src_scalar_per_vector, + ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 6aecd779db3..557f35f861d 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -74,6 +74,7 @@ #include 
"ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp" #include "ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp" #include "ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp" namespace ck_tile::builder::factory { @@ -176,6 +177,10 @@ constexpr auto make_conv_instance() { return typename ConvBwdWeightWmmaV3Factory::Instance{}; } + else if constexpr (BwdTwoStageWmmaV3Algorithm::is_valid()) + { + return typename ConvBwdWeightTwoStageWmmaV3Factory::Instance{}; + } else { diagnose_bwd_weight_algorithm_signature(); diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 3d6e448d84b..aedcbb81105 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -176,6 +176,7 @@ set(BWD_WEIGHT_TESTS if (CK_USE_WMMA) list(APPEND BWD_WEIGHT_TESTS conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp + conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp ) endif() diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp new file mode 100644 index 00000000000..e581da49510 --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp @@ -0,0 +1,47 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; +using enum ck_tile::builder::TensorLayout; + +constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::FP16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = NGCHW}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = NGKHW}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3{} + .with_thread_block(cku::ThreadBlock_64_32x32x32) + .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) + .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) + .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) + .with_block_gemm(cku::BlockGemmDesc_v1_intrawave) + .with_num_conv_groups_to_merge(2) + .with_transpose_params(2,2); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdWeight_2DFp16_TwoStage_Wmma_CShuffle_V3, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl; + cku::run_test({"DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3", + expected_transfer_parameters, + "Default", + "NGCHW,GKYXC,NGKHW", + "PassThrough,PassThrough,PassThrough", + "Intrawave", + "v1", + "fp16,fp16,2,2>"}); // Check compute types and transpose params. 
+} diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp index 4a1a60e852f..f58ec11129e 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp @@ -34,11 +34,13 @@ using Instance = Builder::Instance; TEST(BwdWeight_1DBf16_Wmma_CShuffle_V3, Create) { const auto expected_transfer_parameters = to_string(ALGORITHM); + std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl; cku::run_test({"DeviceGroupedConvBwdWeight_Wmma_CShuffleV3", expected_transfer_parameters, "Filter1x1Stride1Pad0", "NGCW,GKXC,NGKW", "PassThrough,PassThrough,PassThrough", "Intrawave", - "v1"}); + "v1", + "bf16,bf16,4,4>"}); } diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index b562df2608c..404e14fced2 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -604,4 +604,8 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle = using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 = ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, BlockGemm_, TransposeParams_>; + +using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 = + ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, BlockGemm_, TransposeParams_, GemmBatchOptions_, TwoStageSpecialization_>; + } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index bfd16814daf..161d6be268d 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -433,6 +433,16 @@ inline 
std::string to_string +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + return oss.str(); +} + template <> inline std::string to_string( ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle t) From 1759db72501b27baeae686c4d4f2fb81a5920425 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Fri, 2 Jan 2026 08:52:38 -0500 Subject: [PATCH 44/81] Factory and tests for DeviceGroupedConvBwdWeight_Wmma_CShuffle. --- .../builder/conv_signature_concepts.hpp | 9 ++ .../builder/factory/conv_algorithms.hpp | 104 +++++++++++++----- .../factory/conv_bwd_weight_wmma_factory.hpp | 104 ++++++++++++++++++ .../builder/factory/conv_dispatcher.hpp | 5 + experimental/builder/test/CMakeLists.txt | 1 + ...test_ckb_conv_bwd_weight_wmma_cshuffle.cpp | 44 ++++++++ .../test/impl/conv_algorithm_types.hpp | 10 +- .../test/utils/conv_algorithm_type_utils.hpp | 12 +- 8 files changed, 256 insertions(+), 33 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp create mode 100644 experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp index 39e081ec8d5..f3f64b2f5bb 100644 --- a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp @@ -221,4 +221,13 @@ concept ValidConvWeightLayoutForSpatialDim = (SpatialDim == 1 && ConvWeightLayout1D) || (SpatialDim == 2 && ConvWeightLayout2D) || (SpatialDim == 3 && ConvWeightLayout3D); +// Constraint for 3D conv signature. 
+template +concept Is3D = requires { + requires Sig.spatial_dim == 3; + requires ConvInputLayout3D; + requires ConvOutputLayout3D; + requires ConvWeightLayout3D; +}; + } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index d970de795f2..cf31ee64c03 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -366,7 +366,7 @@ struct BwdMultiDXdlAlgorithm { }; template -struct BwdXdlV3Algorithm { +struct BwdXdlV3AlgorithmBase { CHECK_CONCEPT(T, ConvAlgorithmDescriptor) CHECK_CONCEPT(T, SpecifiesThreadBlock) CHECK_CONCEPT(T, SpecifiesBlockTransfer) @@ -376,7 +376,6 @@ struct BwdXdlV3Algorithm { CHECK_CONCEPT(T, SpecifiesGridwiseBwdXdlGemm) CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization) CHECK_CONCEPT(T, SpecifiesBlockGemm) - CHECK_CONCEPT(T, SpecifiesGenericInstance) static constexpr bool c1 = c_ConvAlgorithmDescriptor; static constexpr bool c2 = c_SpecifiesThreadBlock; @@ -387,15 +386,13 @@ struct BwdXdlV3Algorithm { static constexpr bool c7 = c_SpecifiesGridwiseBwdXdlGemm; static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; static constexpr bool c9 = c_SpecifiesBlockGemm; - static constexpr bool c10 = c_SpecifiesGenericInstance; static consteval bool is_valid() { - return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10; + return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9; } static consteval auto message() -> std::string { - return std::string("\n=== Backward XDL V3 Algorithm Diagnostic (closest match) ===\n" - "Concepts for BwdXdlV3 Algorithm:\n") + + return DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + @@ -404,25 +401,66 @@ struct BwdXdlV3Algorithm { DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + 
DIAGNOSTIC_LINE(SpecifiesGridwiseBwdXdlGemm) + DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) + - DIAGNOSTIC_LINE(SpecifiesBlockGemm) + + DIAGNOSTIC_LINE(SpecifiesBlockGemm); + } +}; + +template +struct BwdXdlV3Algorithm : public BwdXdlV3AlgorithmBase{ + CHECK_CONCEPT(T, SpecifiesGenericInstance) + + static constexpr bool c10 = c_SpecifiesGenericInstance; + + static consteval bool is_valid() { + return c10 && BwdXdlV3AlgorithmBase::is_valid(); + } + + static consteval auto message() -> std::string { + return std::string("\n=== Backward XDL V3 Algorithm Diagnostic (closest match) ===\n" + "Concepts for BwdXdlV3 Algorithm:\n") + + BwdXdlV3AlgorithmBase::message() + DIAGNOSTIC_LINE(SpecifiesGenericInstance); } }; template -struct BwdTwoStageXdlAlgorithm { +struct BwdTwoStageXdlAlgorithm : public BwdXdlV3AlgorithmBase{ + CHECK_CONCEPT(T, SpecifiesTransposeTransfer) + CHECK_CONCEPT(T, SpecifiesGemmBatchOptions) + CHECK_CONCEPT(T, SpecifiesTwoStageSupport) + + static constexpr bool c10 = c_SpecifiesTransposeTransfer; + static constexpr bool c11 = c_SpecifiesGemmBatchOptions; + static constexpr bool c12 = c_SpecifiesTwoStageSupport; + + static consteval bool is_valid() { + return c10 && c11 && c12 && BwdXdlV3AlgorithmBase::is_valid(); + } + + static consteval auto message() -> std::string { + return std::string("\n=== Backward two stage XDL Algorithm Diagnostic (closest match) ===\n" + "Concepts for BwdXdlV3 Algorithm:\n") + + BwdXdlV3AlgorithmBase::message() + + DIAGNOSTIC_LINE(SpecifiesTransposeTransfer) + + DIAGNOSTIC_LINE(SpecifiesGemmBatchOptions) + + DIAGNOSTIC_LINE(SpecifiesTwoStageSupport); + } +}; + +template +struct BwdWmmaAlgorithm { CHECK_CONCEPT(T, ConvAlgorithmDescriptor) CHECK_CONCEPT(T, SpecifiesThreadBlock) CHECK_CONCEPT(T, SpecifiesBlockTransfer) CHECK_CONCEPT(T, SpecifiesLdsTransfer) CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder) CHECK_CONCEPT(T, SpecifiesSourceAccessOrder) - CHECK_CONCEPT(T, SpecifiesGridwiseBwdXdlGemm) + CHECK_CONCEPT(T, 
SpecifiesGridwiseWmmaGemm) CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization) - CHECK_CONCEPT(T, SpecifiesBlockGemm) - CHECK_CONCEPT(T, SpecifiesTransposeTransfer) - CHECK_CONCEPT(T, SpecifiesGemmBatchOptions) - CHECK_CONCEPT(T, SpecifiesTwoStageSupport) + CHECK_CONCEPT(T, SpecifiesNumPrefetchStages) + CHECK_CONCEPT(T, SpecifiesLoopScheduler) + CHECK_CONCEPT(T, SpecifiedGridwiseGemmPipeline) + CHECK_CONCEPT(T, SpecifiesGenericInstance) static constexpr bool c1 = c_ConvAlgorithmDescriptor; static constexpr bool c2 = c_SpecifiesThreadBlock; @@ -430,32 +468,32 @@ struct BwdTwoStageXdlAlgorithm { static constexpr bool c4 = c_SpecifiesLdsTransfer; static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder; static constexpr bool c6 = c_SpecifiesSourceAccessOrder; - static constexpr bool c7 = c_SpecifiesGridwiseBwdXdlGemm; + static constexpr bool c7 = c_SpecifiesGridwiseWmmaGemm; static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; - static constexpr bool c9 = c_SpecifiesBlockGemm; - static constexpr bool c10 = c_SpecifiesTransposeTransfer; - static constexpr bool c11 = c_SpecifiesGemmBatchOptions; - static constexpr bool c12 = c_SpecifiesTwoStageSupport; + static constexpr bool c9 = c_SpecifiesNumPrefetchStages; + static constexpr bool c10 = c_SpecifiesLoopScheduler; + static constexpr bool c11 = c_SpecifiedGridwiseGemmPipeline; + static constexpr bool c12 = c_SpecifiesGenericInstance; static consteval bool is_valid() { - return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && & c10 && c11 && c12; + return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10 && c11 && c12; } static consteval auto message() -> std::string { - return std::string("\n=== Backward two stage XDL Algorithm Diagnostic (closest match) ===\n" - "Concepts for BwdXdlV3 Algorithm:\n") + + return std::string("\n=== Backward WMMA Algorithm Diagnostic (closest match) ===\n" + "Concepts for BwdWmma Algorithm:\n") + DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + 
DIAGNOSTIC_LINE(SpecifiesThreadBlock) + DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + - DIAGNOSTIC_LINE(SpecifiesGridwiseBwdXdlGemm) + + DIAGNOSTIC_LINE(SpecifiesGridwiseWmmaGemm) + DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) + - DIAGNOSTIC_LINE(SpecifiesBlockGemm) + - DIAGNOSTIC_LINE(SpecifiesTransposeTransfer) + - DIAGNOSTIC_LINE(SpecifiesGemmBatchOptions) + - DIAGNOSTIC_LINE(SpecifiesTwoStageSupport); + DIAGNOSTIC_LINE(SpecifiesNumPrefetchStages) + + DIAGNOSTIC_LINE(SpecifiesLoopScheduler) + + DIAGNOSTIC_LINE(SpecifiedGridwiseGemmPipeline) + + DIAGNOSTIC_LINE(SpecifiesGenericInstance); } }; @@ -626,6 +664,12 @@ consteval int count_matches_bwd_two_stage_xdl() { return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11 + Alg::c12; } +template +consteval int count_matches_bwd_wmma() { + using Alg = BwdWmmaAlgorithm; + return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11 + Alg::c12; +} + template consteval int count_matches_bwd_wmma_v3() { using Alg = BwdWmmaV3Algorithm; @@ -740,6 +784,7 @@ consteval void diagnose_bwd_weight_algorithm_signature() constexpr int multi_d_xdl_matches = count_matches_bwd_multi_d_xdl(); constexpr int wmma_v3_matches = count_matches_bwd_wmma_v3(); constexpr int two_stage_wmma_v3_matches = count_matches_bwd_two_stage_wmma_v3(); + constexpr int wmma_matches = count_matches_bwd_wmma(); // Check whether we have XDL or WMMA algorithm if constexpr (SpecifiesGridwiseBwdXdlGemm) @@ -772,7 +817,8 @@ consteval void diagnose_bwd_weight_algorithm_signature() } else if constexpr (SpecifiesGridwiseWmmaGemm) { - constexpr int max_matches = wmma_v3_matches > two_stage_wmma_v3_matches ? 
wmma_v3_matches : two_stage_wmma_v3_matches; + constexpr int max_1 = wmma_v3_matches > two_stage_wmma_v3_matches ? wmma_v3_matches : two_stage_wmma_v3_matches; + constexpr int max_matches = max_1 > wmma_matches ? max_1 : wmma_matches; if constexpr (max_matches == wmma_v3_matches) { using Alg = BwdWmmaV3Algorithm; static_assert(Alg::is_valid(), Alg::message()); @@ -781,6 +827,10 @@ consteval void diagnose_bwd_weight_algorithm_signature() using Alg = BwdTwoStageWmmaV3Algorithm; static_assert(Alg::is_valid(), Alg::message()); } + else if constexpr (max_matches == wmma_matches) { + using Alg = BwdWmmaAlgorithm; + static_assert(Alg::is_valid(), Alg::message()); + } } else { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp new file mode 100644 index 00000000000..b8c3bc8f9b2 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp @@ -0,0 +1,104 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvBwdWeight_Wmma_CShuffle instance +// of a grouped bwd weight convolution kernel. +template + requires ConvDirectionIsBackwardWeight && Is3D +struct ConvBwdWeightWmmaFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::BwdWeightConvTensorDataTypes; + using Ops = internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION = + internal::SetGridwiseGemmPipelineVersion(); + static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); + + static constexpr auto A_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + + // Check limits for the algorithm parameters. 
+ // TODO: Add more limits checks as needed. + static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); + static_assert(InputVectorTransferLimits, "Invalid B block transfer config"); + static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); + static_assert(AccessOrderLimits4D, "Invalid A thread cluster access order"); + static_assert(AccessOrderLimits4D, "Invalid B thread cluster access order"); + static_assert(AccessOrderLimits4D, "Invalid A source access order"); + static_assert(AccessOrderLimits4D, "Invalid B source access order"); + + // The forward convolution kernel class instance. + using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + GRIDWISE_GEMM.m_per_wmma, + GRIDWISE_GEMM.n_per_wmma, + GRIDWISE_GEMM.m_wmma_per_wave, + GRIDWISE_GEMM.n_wmma_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + ALGORITHM.num_gemm_k_prefetch_stages, + LOOP_SCHEDULER, + GRIDWISE_GEMM_PIPELINE_VERSION>; +}; + +} // namespace 
ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 557f35f861d..4e1a4766fd8 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -75,6 +75,7 @@ #include "ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp" #include "ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp" #include "ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp" namespace ck_tile::builder::factory { @@ -181,6 +182,10 @@ constexpr auto make_conv_instance() { return typename ConvBwdWeightTwoStageWmmaV3Factory::Instance{}; } + else if constexpr (BwdWmmaAlgorithm::is_valid()) + { + return typename ConvBwdWeightWmmaFactory::Instance{}; + } else { diagnose_bwd_weight_algorithm_signature(); diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index aedcbb81105..179b73b784f 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -176,6 +176,7 @@ set(BWD_WEIGHT_TESTS if (CK_USE_WMMA) list(APPEND BWD_WEIGHT_TESTS conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp + conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp ) endif() diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp new file mode 100644 index 00000000000..47e07d07220 --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; +using enum ck_tile::builder::TensorLayout; + +constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = 3, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::BF16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = NGCDHW}}, + .weight = {.config = {.layout = GKZYXC}}, + .output = {.config = {.layout = NGKDHW}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle{} + .with_thread_block(cku::ThreadBlock_64_32x32x32) + .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) + .with_transfer(cku::BwdTransfer_4x64x1) + .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) + .with_prefetch_config(1, ckb::PipelineScheduler::DEFAULT) + .with_gridwise_gemm_pipeline(ckb::PipelineVersion::V1); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdWeight_3DBf16_Wmma_CShuffle, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl; + cku::run_test({"DeviceGroupedConvBwdWeight_Wmma_CShuffle", + expected_transfer_parameters, + "Default", + "NGCDHW,GKZYXC,NGKDHW", + "PassThrough,PassThrough,PassThrough", + "v1"}); +} diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 404e14fced2..f53bdaa6ab3 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -234,7 +234,6 @@ struct ConvSpecializationBwdWeight_ struct Prefetch_ { size_t 
num_gemm_k_prefetch_stages; - size_t num_groups_to_merge; PipelineScheduler loop_scheduler; }; @@ -430,14 +429,11 @@ struct ConvAlgorithmTemplate : Components... return result; } - constexpr auto with_prefetch_config(size_t k_prefetch_stages, - size_t groups_to_merge, - PipelineScheduler scheduler) const + constexpr auto with_prefetch_config(size_t k_prefetch_stages, PipelineScheduler scheduler) const { static_assert(std::is_base_of_v); auto result = *this; result.num_gemm_k_prefetch_stages = k_prefetch_stages; - result.num_groups_to_merge = groups_to_merge; result.loop_scheduler = scheduler; return result; } @@ -608,4 +604,8 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 = using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 = ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, BlockGemm_, TransposeParams_, GemmBatchOptions_, TwoStageSpecialization_>; +using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle = + ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, GridGemm_, Prefetch_>; + + } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index 161d6be268d..56c9e1755f1 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -314,7 +314,7 @@ template <> inline std::string to_string(Prefetch_ t) { std::ostringstream oss; - oss << t.num_gemm_k_prefetch_stages << "," << t.num_groups_to_merge << "," + oss << t.num_gemm_k_prefetch_stages << "," << to_string(t.loop_scheduler); return oss.str(); } @@ -423,6 +423,16 @@ inline std::string to_string +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + return oss.str(); +} + template <> 
inline std::string to_string( ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 t) From c3a9044bad8b6028fa2b016644631fea57da2219 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Fri, 2 Jan 2026 09:25:38 -0500 Subject: [PATCH 45/81] Dispatching for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle. --- .../builder/factory/conv_algorithms.hpp | 145 +++++++++++------- .../builder/factory/conv_dispatcher.hpp | 5 + experimental/builder/test/CMakeLists.txt | 1 + ..._conv_bwd_weight_multi_d_wmma_cshuffle.cpp | 42 +++++ .../test/impl/conv_algorithm_types.hpp | 7 +- .../test/utils/conv_algorithm_type_utils.hpp | 10 ++ 6 files changed, 156 insertions(+), 54 deletions(-) create mode 100644 experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle.cpp diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index cf31ee64c03..9b120bae9a3 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -281,7 +281,7 @@ struct LargeTensorAlgorithm : public FwdXdlAlgorithmBase }; template -struct BwdXdlAlgorithm { +struct BwdXdlAlgorithmBase { CHECK_CONCEPT(T, ConvAlgorithmDescriptor) CHECK_CONCEPT(T, SpecifiesThreadBlock) CHECK_CONCEPT(T, SpecifiesBlockTransfer4D) @@ -290,8 +290,6 @@ struct BwdXdlAlgorithm { CHECK_CONCEPT(T, SpecifiesSourceAccessOrder) CHECK_CONCEPT(T, SpecifiesGridwiseBwdXdlGemm) CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization) - CHECK_CONCEPT(T, TransposeTransferWellDefinedIfProvided) - CHECK_CONCEPT(T, SpecifiesGenericInstance) static constexpr bool c1 = c_ConvAlgorithmDescriptor; static constexpr bool c2 = c_SpecifiesThreadBlock; @@ -301,16 +299,13 @@ struct BwdXdlAlgorithm { static constexpr bool c6 = c_SpecifiesSourceAccessOrder; static constexpr bool c7 = c_SpecifiesGridwiseBwdXdlGemm; static 
constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; - static constexpr bool c9 = c_TransposeTransferWellDefinedIfProvided; - static constexpr bool c10 = c_SpecifiesGenericInstance; static consteval bool is_valid() { - return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10; + return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8; } static consteval auto message() -> std::string { - return std::string("\n=== Backward XDL Algorithm Diagnostic (closest match) ===\n" - "Concepts for BwdXdl Algorithm:\n") + + return DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + DIAGNOSTIC_LINE(SpecifiesBlockTransfer4D) + @@ -318,49 +313,45 @@ struct BwdXdlAlgorithm { DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + DIAGNOSTIC_LINE(SpecifiesGridwiseBwdXdlGemm) + - DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) + - DIAGNOSTIC_LINE(TransposeTransferWellDefinedIfProvided) + + DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization); + } +}; + +template +struct BwdXdlAlgorithm : public BwdXdlAlgorithmBase{ + CHECK_CONCEPT(T, SpecifiesTransposeTransfer) + CHECK_CONCEPT(T, SpecifiesGenericInstance) + + static constexpr bool c9 = c_SpecifiesTransposeTransfer; + static constexpr bool c10 = c_SpecifiesGenericInstance; + + static consteval bool is_valid() { + return c9 && c10 && BwdXdlAlgorithmBase::is_valid(); + } + + static consteval auto message() -> std::string { + return std::string("\n=== Backward XDL Algorithm Diagnostic (closest match) ===\n" + "Concepts for BwdXdl Algorithm:\n") + + BwdXdlAlgorithmBase::message() + + DIAGNOSTIC_LINE(SpecifiesTransposeTransfer) + DIAGNOSTIC_LINE(SpecifiesGenericInstance); } }; template -struct BwdMultiDXdlAlgorithm { - CHECK_CONCEPT(T, ConvAlgorithmDescriptor) - CHECK_CONCEPT(T, SpecifiesThreadBlock) - CHECK_CONCEPT(T, SpecifiesBlockTransfer4D) - CHECK_CONCEPT(T, SpecifiesLdsTransfer) - CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder) - 
CHECK_CONCEPT(T, SpecifiesSourceAccessOrder) - CHECK_CONCEPT(T, SpecifiesGridwiseBwdXdlGemm) - CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization) +struct BwdMultiDXdlAlgorithm : public BwdXdlAlgorithmBase{ CHECK_CONCEPT(T, SpecifiesMultipleDSupport) - static constexpr bool c1 = c_ConvAlgorithmDescriptor; - static constexpr bool c2 = c_SpecifiesThreadBlock; - static constexpr bool c3 = c_SpecifiesBlockTransfer4D; - static constexpr bool c4 = c_SpecifiesLdsTransfer; - static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder; - static constexpr bool c6 = c_SpecifiesSourceAccessOrder; - static constexpr bool c7 = c_SpecifiesGridwiseBwdXdlGemm; - static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; static constexpr bool c9 = c_SpecifiesMultipleDSupport; static consteval bool is_valid() { - return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9; + return c9 && BwdXdlAlgorithmBase::is_valid(); } static consteval auto message() -> std::string { return std::string("\n=== Backward XDL Algorithm Diagnostic (closest match) ===\n" "Concepts for BwdXdl Algorithm:\n") + - DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + - DIAGNOSTIC_LINE(SpecifiesThreadBlock) + - DIAGNOSTIC_LINE(SpecifiesBlockTransfer4D) + - DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + - DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + - DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + - DIAGNOSTIC_LINE(SpecifiesGridwiseBwdXdlGemm) + - DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) + + BwdXdlAlgorithmBase::message() + DIAGNOSTIC_LINE(SpecifiesMultipleDSupport); } }; @@ -448,7 +439,7 @@ struct BwdTwoStageXdlAlgorithm : public BwdXdlV3AlgorithmBase{ }; template -struct BwdWmmaAlgorithm { +struct BwdWmmaAlgorithmBase { CHECK_CONCEPT(T, ConvAlgorithmDescriptor) CHECK_CONCEPT(T, SpecifiesThreadBlock) CHECK_CONCEPT(T, SpecifiesBlockTransfer) @@ -457,10 +448,6 @@ struct BwdWmmaAlgorithm { CHECK_CONCEPT(T, SpecifiesSourceAccessOrder) CHECK_CONCEPT(T, SpecifiesGridwiseWmmaGemm) CHECK_CONCEPT(T, 
SpecifiesBwdWeightConvSpecialization) - CHECK_CONCEPT(T, SpecifiesNumPrefetchStages) - CHECK_CONCEPT(T, SpecifiesLoopScheduler) - CHECK_CONCEPT(T, SpecifiedGridwiseGemmPipeline) - CHECK_CONCEPT(T, SpecifiesGenericInstance) static constexpr bool c1 = c_ConvAlgorithmDescriptor; static constexpr bool c2 = c_SpecifiesThreadBlock; @@ -470,18 +457,13 @@ struct BwdWmmaAlgorithm { static constexpr bool c6 = c_SpecifiesSourceAccessOrder; static constexpr bool c7 = c_SpecifiesGridwiseWmmaGemm; static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; - static constexpr bool c9 = c_SpecifiesNumPrefetchStages; - static constexpr bool c10 = c_SpecifiesLoopScheduler; - static constexpr bool c11 = c_SpecifiedGridwiseGemmPipeline; - static constexpr bool c12 = c_SpecifiesGenericInstance; static consteval bool is_valid() { - return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10 && c11 && c12; + return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8; } static consteval auto message() -> std::string { - return std::string("\n=== Backward WMMA Algorithm Diagnostic (closest match) ===\n" - "Concepts for BwdWmma Algorithm:\n") + + return DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + @@ -489,7 +471,30 @@ struct BwdWmmaAlgorithm { DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + DIAGNOSTIC_LINE(SpecifiesGridwiseWmmaGemm) + - DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) + + DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization); + } +}; + +template +struct BwdWmmaAlgorithm : public BwdWmmaAlgorithmBase { + CHECK_CONCEPT(T, SpecifiesNumPrefetchStages) + CHECK_CONCEPT(T, SpecifiesLoopScheduler) + CHECK_CONCEPT(T, SpecifiedGridwiseGemmPipeline) + CHECK_CONCEPT(T, SpecifiesGenericInstance) + + static constexpr bool c9 = c_SpecifiesNumPrefetchStages; + static constexpr bool c10 = c_SpecifiesLoopScheduler; + static constexpr bool c11 = 
c_SpecifiedGridwiseGemmPipeline; + static constexpr bool c12 = c_SpecifiesGenericInstance; + + static consteval bool is_valid() { + return c9 && c10 && c11 && c12 && BwdWmmaAlgorithmBase::is_valid(); + } + + static consteval auto message() -> std::string { + return std::string("\n=== Backward WMMA Algorithm Diagnostic (closest match) ===\n" + "Concepts for BwdWmma Algorithm:\n") + + BwdWmmaAlgorithmBase::message() + DIAGNOSTIC_LINE(SpecifiesNumPrefetchStages) + DIAGNOSTIC_LINE(SpecifiesLoopScheduler) + DIAGNOSTIC_LINE(SpecifiedGridwiseGemmPipeline) + @@ -497,6 +502,27 @@ struct BwdWmmaAlgorithm { } }; +template +struct BwdMultiDWmmaAlgorithm : public BwdWmmaAlgorithmBase { + CHECK_CONCEPT(T, SpecifiesBlockGemm) + CHECK_CONCEPT(T, SpecifiesMultipleDSupport) + + static constexpr bool c9 = c_SpecifiesBlockGemm; + static constexpr bool c10 = c_SpecifiesMultipleDSupport; + + static consteval bool is_valid() { + return c9 && c10 && BwdWmmaAlgorithmBase::is_valid(); + } + + static consteval auto message() -> std::string { + return std::string("\n=== Backward WMMA Algorithm Diagnostic (closest match) ===\n" + "Concepts for BwdMultiDWmma Algorithm:\n") + + BwdWmmaAlgorithmBase::message() + + DIAGNOSTIC_LINE(SpecifiesBlockGemm) + + DIAGNOSTIC_LINE(SpecifiesMultipleDSupport); + } +}; + template struct BwdWmmaV3AlgorithmBase { CHECK_CONCEPT(T, ConvAlgorithmDescriptor) @@ -508,7 +534,7 @@ struct BwdWmmaV3AlgorithmBase { CHECK_CONCEPT(T, SpecifiesGridwiseWmmaGemm) CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization) CHECK_CONCEPT(T, SpecifiesBlockGemm) - CHECK_CONCEPT(T, TransposeTransferWellDefinedIfProvided) + CHECK_CONCEPT(T, SpecifiesTransposeTransfer) static constexpr bool c1 = c_ConvAlgorithmDescriptor; static constexpr bool c2 = c_SpecifiesThreadBlock; @@ -519,7 +545,7 @@ struct BwdWmmaV3AlgorithmBase { static constexpr bool c7 = c_SpecifiesGridwiseWmmaGemm; static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; static constexpr bool c9 = 
c_SpecifiesBlockGemm; - static constexpr bool c10 = c_TransposeTransferWellDefinedIfProvided; + static constexpr bool c10 = c_SpecifiesTransposeTransfer; static consteval bool is_valid() { return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10; @@ -536,7 +562,7 @@ struct BwdWmmaV3AlgorithmBase { DIAGNOSTIC_LINE(SpecifiesGridwiseWmmaGemm) + DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) + DIAGNOSTIC_LINE(SpecifiesBlockGemm) + - DIAGNOSTIC_LINE(TransposeTransferWellDefinedIfProvided); + DIAGNOSTIC_LINE(SpecifiesTransposeTransfer); } }; @@ -670,6 +696,12 @@ consteval int count_matches_bwd_wmma() { return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11 + Alg::c12; } +template +consteval int count_matches_bwd_multi_d_wmma() { + using Alg = BwdMultiDWmmaAlgorithm; + return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11 + Alg::c12; +} + template consteval int count_matches_bwd_wmma_v3() { using Alg = BwdWmmaV3Algorithm; @@ -785,6 +817,7 @@ consteval void diagnose_bwd_weight_algorithm_signature() constexpr int wmma_v3_matches = count_matches_bwd_wmma_v3(); constexpr int two_stage_wmma_v3_matches = count_matches_bwd_two_stage_wmma_v3(); constexpr int wmma_matches = count_matches_bwd_wmma(); + constexpr int multi_d_wmma_matches = count_matches_bwd_multi_d_wmma(); // Check whether we have XDL or WMMA algorithm if constexpr (SpecifiesGridwiseBwdXdlGemm) @@ -818,7 +851,9 @@ consteval void diagnose_bwd_weight_algorithm_signature() else if constexpr (SpecifiesGridwiseWmmaGemm) { constexpr int max_1 = wmma_v3_matches > two_stage_wmma_v3_matches ? wmma_v3_matches : two_stage_wmma_v3_matches; - constexpr int max_matches = max_1 > wmma_matches ? max_1 : wmma_matches; + constexpr int max_2 = max_1 > wmma_matches ? max_1 : wmma_matches; + constexpr int max_matches = multi_d_wmma_matches > max_2 ? 
multi_d_wmma_matches : max_2; + if constexpr (max_matches == wmma_v3_matches) { using Alg = BwdWmmaV3Algorithm; static_assert(Alg::is_valid(), Alg::message()); @@ -831,6 +866,10 @@ consteval void diagnose_bwd_weight_algorithm_signature() using Alg = BwdWmmaAlgorithm; static_assert(Alg::is_valid(), Alg::message()); } + else if constexpr (max_matches == multi_d_wmma_matches) { + using Alg = BwdMultiDWmmaAlgorithm; + static_assert(Alg::is_valid(), Alg::message()); + } } else { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 4e1a4766fd8..533e7752720 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -186,6 +186,11 @@ constexpr auto make_conv_instance() { return typename ConvBwdWeightWmmaFactory::Instance{}; } + else if constexpr (BwdMultiDWmmaAlgorithm::is_valid()) + { + static_assert(false, + "Backward weight convolution with multi-D WMMA algorithm is not yet supported."); + } else { diagnose_bwd_weight_algorithm_signature(); diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 179b73b784f..b2f5970d2e7 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -178,6 +178,7 @@ if (CK_USE_WMMA) conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp + conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle.cpp ) endif() diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle.cpp new file mode 100644 index 00000000000..e050ffad4e7 --- /dev/null +++ 
b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle.cpp @@ -0,0 +1,42 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; + +constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::FP16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::GNHWC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle{} + .with_thread_block(cku::ThreadBlock_256_128x128x8) + .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) + .with_transfer(cku::BwdTransfer_4x64x1) + .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) + .with_block_gemm(cku::BlockGemmDesc_v1_intrawave); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdWeight_2DFp16_MultiD_Wmma_Shuffle_GNHWC, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl; + cku::run_test({"DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle", + expected_transfer_parameters, + "Default", + "GNHWC,GKYXC,GNHWK", + "PassThrough,PassThrough,PassThrough", + "fp16,fp16>"}); // check compute types +} diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index f53bdaa6ab3..b126b5af8d8 100644 --- 
a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -547,7 +547,7 @@ struct ConvAlgorithmTemplate : Components... } }; -// Algorithm types +// Fwd algorithm types using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = ConvAlgorithmTemplate, ConvSpecializationFwd_, Prefetch_>; @@ -568,6 +568,7 @@ using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = ConvAlgorithmTemplate, ConvSpecializationFwd_, Prefetch_, LargeTensorSpecialization_>; +// CK Tile algorithm using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, TransposeParams_>; @@ -607,5 +609,8 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 = using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle = ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, GridGemm_, Prefetch_>; +using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle = + ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, BlockGemm_, MultipleDSpecialization_>; + } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index 56c9e1755f1..ad887b2491c 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -443,6 +443,16 @@ inline std::string to_string +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + return oss.str(); +} + template <> inline std::string to_string( ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 t) From 4eea42c5b7173fb6ddc2574b2aa06b8ccbf9652b Mon Sep 17 
00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Fri, 2 Jan 2026 09:52:52 -0500 Subject: [PATCH 46/81] Add factory for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 --- .../builder/factory/conv_algorithms.hpp | 60 +++++----- ...onv_bwd_weight_multi_d_wmma_v3_factory.hpp | 106 ++++++++++++++++++ .../builder/factory/conv_dispatcher.hpp | 6 +- ..._conv_bwd_weight_multi_d_wmma_cshuffle.cpp | 4 +- .../test/impl/conv_algorithm_types.hpp | 2 +- .../test/utils/conv_algorithm_type_utils.hpp | 2 +- ...bwd_weight_multiple_d_wmma_cshuffle_v3.hpp | 10 +- 7 files changed, 148 insertions(+), 42 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index 9b120bae9a3..800403b16da 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -502,27 +502,6 @@ struct BwdWmmaAlgorithm : public BwdWmmaAlgorithmBase { } }; -template -struct BwdMultiDWmmaAlgorithm : public BwdWmmaAlgorithmBase { - CHECK_CONCEPT(T, SpecifiesBlockGemm) - CHECK_CONCEPT(T, SpecifiesMultipleDSupport) - - static constexpr bool c9 = c_SpecifiesBlockGemm; - static constexpr bool c10 = c_SpecifiesMultipleDSupport; - - static consteval bool is_valid() { - return c9 && c10 && BwdWmmaAlgorithmBase::is_valid(); - } - - static consteval auto message() -> std::string { - return std::string("\n=== Backward WMMA Algorithm Diagnostic (closest match) ===\n" - "Concepts for BwdMultiDWmma Algorithm:\n") + - BwdWmmaAlgorithmBase::message() + - DIAGNOSTIC_LINE(SpecifiesBlockGemm) + - DIAGNOSTIC_LINE(SpecifiesMultipleDSupport); - } -}; - template struct BwdWmmaV3AlgorithmBase { CHECK_CONCEPT(T, ConvAlgorithmDescriptor) @@ -534,7 +513,6 @@ struct BwdWmmaV3AlgorithmBase { 
CHECK_CONCEPT(T, SpecifiesGridwiseWmmaGemm) CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization) CHECK_CONCEPT(T, SpecifiesBlockGemm) - CHECK_CONCEPT(T, SpecifiesTransposeTransfer) static constexpr bool c1 = c_ConvAlgorithmDescriptor; static constexpr bool c2 = c_SpecifiesThreadBlock; @@ -545,10 +523,9 @@ struct BwdWmmaV3AlgorithmBase { static constexpr bool c7 = c_SpecifiesGridwiseWmmaGemm; static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; static constexpr bool c9 = c_SpecifiesBlockGemm; - static constexpr bool c10 = c_SpecifiesTransposeTransfer; static consteval bool is_valid() { - return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10; + return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9; } static consteval auto message() -> std::string { @@ -561,26 +538,46 @@ struct BwdWmmaV3AlgorithmBase { DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + DIAGNOSTIC_LINE(SpecifiesGridwiseWmmaGemm) + DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) + - DIAGNOSTIC_LINE(SpecifiesBlockGemm) + - DIAGNOSTIC_LINE(SpecifiesTransposeTransfer); + DIAGNOSTIC_LINE(SpecifiesBlockGemm); + } +}; + +template +struct BwdMultiDWmmaV3Algorithm : public BwdWmmaV3AlgorithmBase { + CHECK_CONCEPT(T, SpecifiesMultipleDSupport) + + static constexpr bool c10 = c_SpecifiesMultipleDSupport; + + static consteval bool is_valid() { + return c10 && BwdWmmaV3AlgorithmBase::is_valid(); + } + + static consteval auto message() -> std::string { + return std::string("\n=== Backward WMMA Algorithm Diagnostic (closest match) ===\n" + "Concepts for BwdMultiDWmma Algorithm:\n") + + BwdWmmaV3AlgorithmBase::message() + + DIAGNOSTIC_LINE(SpecifiesMultipleDSupport); + } }; template struct BwdWmmaV3Algorithm : public BwdWmmaV3AlgorithmBase { + CHECK_CONCEPT(T, SpecifiesTransposeTransfer) CHECK_CONCEPT(T, SpecifiesGenericInstance) + static constexpr bool c10 = c_SpecifiesTransposeTransfer; static constexpr bool c11 = c_SpecifiesGenericInstance; static consteval bool is_valid() { - return c11
&& BwdWmmaV3AlgorithmBase::is_valid(); + return c10 && c11 && BwdWmmaV3AlgorithmBase::is_valid(); } static consteval auto message() -> std::string { return std::string("\n=== Backward WMMA V3 Algorithm Diagnostic (closest match) ===\n" "Concepts for BwdWmmaV3 Algorithm:\n") + BwdWmmaV3AlgorithmBase::message() + + DIAGNOSTIC_LINE(SpecifiesTransposeTransfer) + DIAGNOSTIC_LINE(SpecifiesGenericInstance); } }; @@ -588,20 +585,23 @@ struct BwdWmmaV3Algorithm : public BwdWmmaV3AlgorithmBase template struct BwdTwoStageWmmaV3Algorithm : public BwdWmmaV3AlgorithmBase { + CHECK_CONCEPT(T, SpecifiesTransposeTransfer) CHECK_CONCEPT(T, SpecifiesTwoStageSupport) CHECK_CONCEPT(T, SpecifiesGemmBatchOptions) + static constexpr bool c10 = c_SpecifiesTransposeTransfer; static constexpr bool c11 = c_SpecifiesTwoStageSupport; static constexpr bool c12 = c_SpecifiesGemmBatchOptions; static consteval bool is_valid() { - return c11 && c12 && BwdWmmaV3AlgorithmBase::is_valid(); + return c10 && c11 && c12 && BwdWmmaV3AlgorithmBase::is_valid(); } static consteval auto message() -> std::string { return std::string("\n=== Backward Two Stage WMMA V3 Algorithm Diagnostic (closest match) ===\n" "Concepts for BwdTwoStageWmmaV3 Algorithm:\n") + BwdWmmaV3AlgorithmBase::message() + + DIAGNOSTIC_LINE(SpecifiesTransposeTransfer) + DIAGNOSTIC_LINE(SpecifiesGemmBatchOptions) + DIAGNOSTIC_LINE(SpecifiesTwoStageSupport); } @@ -698,7 +698,7 @@ consteval int count_matches_bwd_wmma() { template consteval int count_matches_bwd_multi_d_wmma() { - using Alg = BwdMultiDWmmaAlgorithm; + using Alg = BwdMultiDWmmaV3Algorithm; return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11 + Alg::c12; } @@ -867,7 +867,7 @@ consteval void diagnose_bwd_weight_algorithm_signature() static_assert(Alg::is_valid(), Alg::message()); } else if constexpr (max_matches == multi_d_wmma_matches) { - using Alg = BwdMultiDWmmaAlgorithm; + using Alg = BwdMultiDWmmaV3Algorithm; 
static_assert(Alg::is_valid(), Alg::message()); } } diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp new file mode 100644 index 00000000000..2ba13005370 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp @@ -0,0 +1,106 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 instance +// of a grouped bwd weight convolution kernel. 
+template + requires ConvDirectionIsBackwardWeight && Is3D +struct ConvBwdWeightMultiDWmmaV3Factory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::BwdWeightConvTensorDataTypes; + using Ops = internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION = + internal::SetGridwiseGemmPipelineVersion(); + static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); + + static constexpr auto A_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + static constexpr auto BLOCK_GEMM = internal::SetBlockGemm(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. + static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); + static_assert(InputVectorTransferLimits, "Invalid B block transfer config"); + static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); + static_assert(AccessOrderLimits4D, "Invalid A thread cluster access order"); + static_assert(AccessOrderLimits4D, "Invalid B thread cluster access order"); + static_assert(AccessOrderLimits4D, "Invalid A source access order"); + static_assert(AccessOrderLimits4D, "Invalid B source access order"); + + // The forward convolution kernel class instance. 
+ using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Layouts::DsLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Types::DsDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + GRIDWISE_GEMM.m_per_wmma, + GRIDWISE_GEMM.n_per_wmma, + GRIDWISE_GEMM.m_wmma_per_wave, + GRIDWISE_GEMM.n_wmma_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + BLOCK_GEMM.scheduler, + BLOCK_GEMM.pipeline_version>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 533e7752720..9bca0177350 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -76,6 +76,7 @@ #include "ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp" #include "ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp" #include 
"ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp" namespace ck_tile::builder::factory { @@ -186,10 +187,9 @@ constexpr auto make_conv_instance() { return typename ConvBwdWeightWmmaFactory::Instance{}; } - else if constexpr (BwdMultiDWmmaAlgorithm::is_valid()) + else if constexpr (BwdMultiDWmmaV3Algorithm::is_valid()) { - static_assert(false, - "Backward weight convolution with multi-D WMMA algorithm is not yet supported."); + return typename ConvBwdWeightMultiDWmmaV3Factory::Instance{}; } else { diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle.cpp index e050ffad4e7..e2bcd4a9269 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle.cpp @@ -20,9 +20,9 @@ constexpr auto SIGNATURE = .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle{} - .with_thread_block(cku::ThreadBlock_256_128x128x8) + .with_thread_block(cku::ThreadBlock_64_32x32x32) .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) - .with_transfer(cku::BwdTransfer_4x64x1) + .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) .with_block_gemm(cku::BlockGemmDesc_v1_intrawave); diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index b126b5af8d8..eb03ecfab2a 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -610,7 +610,7 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle = ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, 
GridGemm_, Prefetch_>; using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle = - ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, BlockGemm_, MultipleDSpecialization_>; + ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, BlockGemm_, MultipleDSpecialization_>; } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index ad887b2491c..ee4bd1f5978 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -449,7 +449,7 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + << "," << to_string(static_cast>(t)); return oss.str(); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index a3b8be8bf88..dbce8e8ccf6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -50,7 +50,7 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3( + kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d( typename GridwiseGemm::Argument karg, const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, @@ -861,7 +861,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 { if(gemm_arg.KBatch > 1) { - const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< + const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< GridwiseGemm, remove_reference_t, 
remove_reference_t, @@ -875,7 +875,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 } else { - const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< + const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< GridwiseGemm, remove_reference_t, remove_reference_t, @@ -900,7 +900,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 { if(gemm_arg.KBatch > 1) { - const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< + const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< GridwiseGemm, remove_reference_t, remove_reference_t, @@ -914,7 +914,7 @@ } else { - const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< + const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< GridwiseGemm, remove_reference_t, remove_reference_t, From 1dcea1825f162561365699abedce22b77b4bf7b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Mon, 5 Jan 2026 03:09:22 -0500 Subject: [PATCH 47/81] Fix DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 factory and compute types for input and output tensor in bwd weight convs.
--- .../ck_tile/builder/conv_algorithm_concepts.hpp | 2 +- .../builder/conv_algorithm_diagnostics.hpp | 2 +- .../ck_tile/builder/factory/conv_algorithms.hpp | 12 ++++++------ .../conv_bwd_weight_multi_d_wmma_v3_factory.hpp | 16 +++++++--------- .../conv_bwd_weight_multi_d_xdl_factory.hpp | 4 ++-- ...conv_bwd_weight_two_stage_wmma_v3_factory.hpp | 2 +- .../conv_bwd_weight_two_stage_xdl_factory.hpp | 2 +- .../factory/conv_bwd_weight_wmma_v3_factory.hpp | 2 +- .../factory/conv_bwd_weight_xdl_factory.hpp | 2 +- .../factory/conv_bwd_weight_xdl_v3_factory.hpp | 4 ++-- experimental/builder/test/CMakeLists.txt | 2 +- ...conv_bwd_weight_multi_d_wmma_cshuffle_v3.cpp} | 16 ++++++++-------- .../builder/test/impl/conv_algorithm_types.hpp | 2 +- .../test/utils/conv_algorithm_type_utils.hpp | 4 ++-- 14 files changed, 35 insertions(+), 37 deletions(-) rename experimental/builder/test/conv/ck/{test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle.cpp => test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle_v3.cpp} (85%) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 326350eb274..511c0c2b2db 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -255,7 +255,7 @@ concept SpecifiesBlockGemm = requires { }; template -concept SpecifiedGridwiseGemmPipeline = requires +concept SpecifiesGridwiseGemmPipeline = requires { { T::pipeline_version } -> std::convertible_to; }; diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp index df81b20f8ee..047eb7e4bdc 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp @@ -1193,7 +1193,7 @@ consteval auto 
detailed_diagnostic_SpecifiesLdsTransfer() -> std::string { } template -consteval auto detailed_diagnostic_SpecifiedGridwiseGemmPipeline() -> std::string { +consteval auto detailed_diagnostic_SpecifiesGridwiseGemmPipeline() -> std::string { if constexpr (requires { T::pipeline_version; }) { using PipelineType = decltype(T::pipeline_version); constexpr bool convertible = std::convertible_to; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index 800403b16da..891ea8f730c 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -153,7 +153,7 @@ struct FwdWmmaAlgorithm { CHECK_CONCEPT(T, SpecifiesGemmSpecialization) CHECK_CONCEPT(T, SpecifiesNumPrefetchStages) CHECK_CONCEPT(T, SpecifiesLoopScheduler) - CHECK_CONCEPT(T, SpecifiedGridwiseGemmPipeline) + CHECK_CONCEPT(T, SpecifiesGridwiseGemmPipeline) static constexpr bool c1 = c_ConvAlgorithmDescriptor; static constexpr bool c2 = c_SpecifiesThreadBlock; @@ -166,7 +166,7 @@ struct FwdWmmaAlgorithm { static constexpr bool c9 = c_SpecifiesGemmSpecialization; static constexpr bool c10 = c_SpecifiesNumPrefetchStages; static constexpr bool c11 = c_SpecifiesLoopScheduler; - static constexpr bool c12 = c_SpecifiedGridwiseGemmPipeline; + static constexpr bool c12 = c_SpecifiesGridwiseGemmPipeline; static consteval bool is_valid() { return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10 && c11 && c12; @@ -186,7 +186,7 @@ struct FwdWmmaAlgorithm { DIAGNOSTIC_LINE(SpecifiesGemmSpecialization) + DIAGNOSTIC_LINE(SpecifiesNumPrefetchStages) + DIAGNOSTIC_LINE(SpecifiesLoopScheduler) + - DIAGNOSTIC_LINE(SpecifiedGridwiseGemmPipeline); + DIAGNOSTIC_LINE(SpecifiesGridwiseGemmPipeline); } }; @@ -479,12 +479,12 @@ template struct BwdWmmaAlgorithm : public BwdWmmaAlgorithmBase { CHECK_CONCEPT(T, 
SpecifiesNumPrefetchStages) CHECK_CONCEPT(T, SpecifiesLoopScheduler) - CHECK_CONCEPT(T, SpecifiedGridwiseGemmPipeline) + CHECK_CONCEPT(T, SpecifiesGridwiseGemmPipeline) CHECK_CONCEPT(T, SpecifiesGenericInstance) static constexpr bool c9 = c_SpecifiesNumPrefetchStages; static constexpr bool c10 = c_SpecifiesLoopScheduler; - static constexpr bool c11 = c_SpecifiedGridwiseGemmPipeline; + static constexpr bool c11 = c_SpecifiesGridwiseGemmPipeline; static constexpr bool c12 = c_SpecifiesGenericInstance; static consteval bool is_valid() { @@ -497,7 +497,7 @@ struct BwdWmmaAlgorithm : public BwdWmmaAlgorithmBase { BwdWmmaAlgorithmBase::message() + DIAGNOSTIC_LINE(SpecifiesNumPrefetchStages) + DIAGNOSTIC_LINE(SpecifiesLoopScheduler) + - DIAGNOSTIC_LINE(SpecifiedGridwiseGemmPipeline) + + DIAGNOSTIC_LINE(SpecifiesGridwiseGemmPipeline) + DIAGNOSTIC_LINE(SpecifiesGenericInstance); } }; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp index 2ba13005370..9d94404a88d 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp @@ -35,10 +35,6 @@ struct ConvBwdWeightMultiDWmmaV3Factory static constexpr auto BLOCK = internal::SetThreadBlockInfo(); static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; - static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION = - internal::SetGridwiseGemmPipelineVersion(); - static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); - static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -51,10 +47,10 @@ struct ConvBwdWeightMultiDWmmaV3Factory static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); static_assert(InputVectorTransferLimits, 
"Invalid B block transfer config"); static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); - static_assert(AccessOrderLimits4D, "Invalid A thread cluster access order"); - static_assert(AccessOrderLimits4D, "Invalid B thread cluster access order"); - static_assert(AccessOrderLimits4D, "Invalid A source access order"); - static_assert(AccessOrderLimits4D, "Invalid B source access order"); + static_assert(AccessOrderLimits3D, "Invalid A thread cluster access order"); + static_assert(AccessOrderLimits3D, "Invalid B thread cluster access order"); + static_assert(AccessOrderLimits3D, "Invalid A source access order"); + static_assert(AccessOrderLimits3D, "Invalid B source access order"); // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< @@ -100,7 +96,9 @@ struct ConvBwdWeightMultiDWmmaV3Factory to_sequence_v, C_BLOCK_TRANSFER.scalar_per_vector, BLOCK_GEMM.scheduler, - BLOCK_GEMM.pipeline_version>; + BLOCK_GEMM.pipeline_version, + typename Types::OutComputeType, + typename Types::InComputeType>; }; } // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp index 2c11f3b4366..dae333d99e1 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp @@ -95,8 +95,8 @@ struct ConvBwdWeightMultiDXdlFactory C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, to_sequence_v, C_BLOCK_TRANSFER.scalar_per_vector, - typename Types::InComputeType, - typename Types::WeiComputeType>; + typename Types::OutComputeType, + typename Types::InComputeType>; }; } // namespace ck_tile::builder::factory diff --git 
a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp index 635b8222937..8288874f0e3 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp @@ -96,8 +96,8 @@ struct ConvBwdWeightTwoStageWmmaV3Factory BLOCK_GEMM.scheduler, BLOCK_GEMM.pipeline_version, ALGORITHM.num_conv_groups_to_merge, + typename Types::OutComputeType, typename Types::InComputeType, - typename Types::WeiComputeType, ALGORITHM.max_transpose_transfer_src_scalar_per_vector, ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>; }; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp index fae8ad7b873..df3e4c01b2e 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp @@ -97,8 +97,8 @@ struct ConvBwdWeightTwoStageXdlFactory BLOCK_GEMM.scheduler, BLOCK_GEMM.pipeline_version, ALGORITHM.num_conv_groups_to_merge, + typename Types::OutComputeType, typename Types::InComputeType, - typename Types::WeiComputeType, ALGORITHM.max_transpose_transfer_src_scalar_per_vector, ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>; }; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp index ab75f5b072c..d1cfbe9e8dc 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp +++ 
b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp @@ -95,8 +95,8 @@ struct ConvBwdWeightWmmaV3Factory C_BLOCK_TRANSFER.scalar_per_vector, BLOCK_GEMM.scheduler, BLOCK_GEMM.pipeline_version, + typename Types::OutComputeType, typename Types::InComputeType, - typename Types::WeiComputeType, ALGORITHM.max_transpose_transfer_src_scalar_per_vector, ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>; }; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp index 8ffbb495ec2..e89d227f821 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp @@ -93,8 +93,8 @@ struct ConvBwdWeightXdlFactory C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, to_sequence_v, C_BLOCK_TRANSFER.scalar_per_vector, + typename Types::OutComputeType, typename Types::InComputeType, - typename Types::WeiComputeType, ALGORITHM.max_transpose_transfer_src_scalar_per_vector, ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>; }; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp index 20aac55f311..e93b1456cd7 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp @@ -96,8 +96,8 @@ struct ConvBwdWeightXdlV3Factory C_BLOCK_TRANSFER.scalar_per_vector, BLOCK_GEMM.scheduler, BLOCK_GEMM.pipeline_version, - typename Types::InComputeType, - typename Types::WeiComputeType>; + typename Types::OutComputeType, + typename Types::InComputeType>; }; } // namespace ck_tile::builder::factory diff --git a/experimental/builder/test/CMakeLists.txt 
b/experimental/builder/test/CMakeLists.txt index b2f5970d2e7..667490151fc 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -178,7 +178,7 @@ if (CK_USE_WMMA) conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp - conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle.cpp + conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle_v3.cpp ) endif() diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle_v3.cpp similarity index 85% rename from experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle_v3.cpp index e2bcd4a9269..404d1dbacdb 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle_v3.cpp @@ -11,15 +11,15 @@ namespace ckt = ck_tile::builder::test; namespace cku = ck_tile::builder::test_utils; constexpr auto SIGNATURE = - ckt::ConvSignature{.spatial_dim = 2, + ckt::ConvSignature{.spatial_dim = 3, .direction = ckb::ConvDirection::BACKWARD_WEIGHT, .data_type = ckb::DataType::FP16, .accumulation_data_type = ckb::DataType::FP32, - .input = {.config = {.layout = ckb::TensorLayout::GNHWC}}, - .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, - .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; + .input = {.config = {.layout = ckb::TensorLayout::GNDHWC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKZYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::GNDHWK}}}; -constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle{} +constexpr auto ALGORITHM = 
cku::ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3{} .with_thread_block(cku::ThreadBlock_64_32x32x32) .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) @@ -29,14 +29,14 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeightMultiple using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; -TEST(BwdWeight_2DFp16_MultiD_Wmma_Shuffle_GNHWC, Create) +TEST(BwdWeight_3DFp16_MultiD_Wmma_ShuffleV3_GNHWC, Create) { const auto expected_transfer_parameters = to_string(ALGORITHM); std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl; - cku::run_test({"DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle", + cku::run_test({"DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3", expected_transfer_parameters, "Default", - "GNHWC,GKYXC,GNHWK", + "GNDHWC,GKZYXC,GNDHWK", "PassThrough,PassThrough,PassThrough", "fp16,fp16>"}); // check compute types } diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index eb03ecfab2a..4cce9475cb6 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -609,7 +609,7 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 = using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle = ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, GridGemm_, Prefetch_>; -using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle = +using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3 = ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, BlockGemm_, MultipleDSpecialization_>; diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index ee4bd1f5978..6ad61797800 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp 
+++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -444,8 +444,8 @@ inline std::string to_string -inline std::string to_string( - ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle t) +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3 t) { std::ostringstream oss; oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) From 829eabed3a56c387087311b06fe0077cb567928d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Mon, 5 Jan 2026 04:42:31 -0500 Subject: [PATCH 48/81] Fix fwd factories after refactoring. --- .../ck_tile/builder/conv_algorithm_concepts.hpp | 2 +- .../ck_tile/builder/conv_algorithm_diagnostics.hpp | 10 +++++----- .../ck_tile/builder/factory/conv_algorithms.hpp | 2 +- .../ck_tile/builder/factory/conv_fwd_xdl_factory.hpp | 2 +- .../builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp | 3 ++- .../builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp | 3 ++- .../ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp | 3 ++- .../builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp | 3 ++- .../conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp | 6 ++++-- .../builder/test/impl/conv_algorithm_types.hpp | 6 +++--- 10 files changed, 23 insertions(+), 17 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 511c0c2b2db..5a5dfd17b8a 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -310,7 +310,7 @@ concept SpecifiesNumPrefetchStages = requires { template concept SpecifiesNumGroupsToMerge = requires { - { T::num_groups_to_merge } -> SizeType; + { T::num_conv_groups_to_merge } -> SizeType; }; template diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp 
b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp index 047eb7e4bdc..46e205d68bf 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp @@ -670,13 +670,13 @@ consteval auto detailed_diagnostic_SpecifiesNumPrefetchStages() -> std::string { template consteval auto detailed_diagnostic_SpecifiesNumGroupsToMerge() -> std::string { - if constexpr (requires { T::num_groups_to_merge; }) { - using NumGroupsType = decltype(T::num_groups_to_merge); + if constexpr (requires { T::num_conv_groups_to_merge; }) { + using NumGroupsType = decltype(T::num_conv_groups_to_merge); constexpr bool convertible = std::convertible_to; - return " → T::num_groups_to_merge: " + std::string(CHECK_MARK(convertible)) + + return " → T::num_conv_groups_to_merge: " + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(detail::get_type_info())) + "\n"; } else { - return " → T::num_groups_to_merge: [✗] (missing member)\n"; + return " → T::num_conv_groups_to_merge: [✗] (missing member)\n"; } } @@ -1193,7 +1193,7 @@ consteval auto detailed_diagnostic_SpecifiesLdsTransfer() -> std::string { } template -consteval auto detailed_diagnostic_SpecifiesGridwiseGemmPipeline() -> std::string { +consteval auto detailed_diagnostic_SpecifieGridwiseGemmPipeline() -> std::string { if constexpr (requires { T::pipeline_version; }) { using PipelineType = decltype(T::pipeline_version); constexpr bool convertible = std::convertible_to; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index 891ea8f730c..eb02aae584d 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -269,7 +269,7 @@ struct LargeTensorAlgorithm : public FwdXdlAlgorithmBase 
static consteval bool is_valid() { // Note: Check first if the specialization is set. - return c13 && FwdXdlAlgorithm::is_valid(); + return c13 && FwdXdlAlgorithmBase::is_valid(); } static consteval auto message() -> std::string { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp index cebf5a0c3a1..5a0084d6da3 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp @@ -106,7 +106,7 @@ struct ConvFwdXdlFactory typename Types::AComputeType, typename Types::BComputeType, LOOP_SCHEDULER, - ALGORITHM.num_groups_to_merge>; + ALGORITHM.num_conv_groups_to_merge>; }; } // namespace ck_tile::builder::factory diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp index e543ce6fa0e..d3ace110c4b 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp @@ -31,7 +31,8 @@ TEST(FwdConvInstances, .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) .with_transfer(Transfer_4x16x1) .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, 2, PipelineScheduler::DEFAULT); + .with_prefetch_config(1, PipelineScheduler::DEFAULT) + .with_num_conv_groups_to_merge(2); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp index fc9cf44b7a9..a8df3c0d98d 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp @@ -33,7 +33,8 @@ TEST(FwdConvInstances, .with_gemm_config(GemmParams_Wmma_2x1_per_wave) .with_transfer(Transfer_4x32x1) 
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, 0, PipelineScheduler::DEFAULT) + .with_prefetch_config(1, PipelineScheduler::DEFAULT) + .with_num_conv_groups_to_merge(2) .with_gridwise_gemm_pipeline(PipelineVersion::V1); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp index 271788799f3..23edef54369 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp @@ -36,7 +36,8 @@ TEST(FwdConvInstances, .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) .with_transfer(Transfer_4x16x1) .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT); + .with_prefetch_config(1, PipelineScheduler::DEFAULT) + .with_num_conv_groups_to_merge(1); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp index c95d48d7125..b117e693fe3 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp @@ -31,7 +31,8 @@ TEST(FwdConvInstances, .with_gemm_config(FwdGemmParams_Xdl_4x2_per_wave) .with_transfer(Transfer_4x64x1_fp8) .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT); + .with_prefetch_config(1, PipelineScheduler::DEFAULT) + .with_num_conv_groups_to_merge(1); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp index e9e1c5f868c..303875b3489 
100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp @@ -31,7 +31,8 @@ TEST(FwdConvInstances, .with_transfer(Transfer_4x16x1) .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT); + .with_prefetch_config(1, PipelineScheduler::DEFAULT) + .with_num_conv_groups_to_merge(1); using Builder = ConvBuilder; @@ -67,7 +68,8 @@ TEST( .with_transfer(Transfer_4x16x1) .with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT); + .with_prefetch_config(1, PipelineScheduler::DEFAULT) + .with_num_conv_groups_to_merge(1); using Builder = ConvBuilder; diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 4cce9475cb6..57ef018ae32 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -550,13 +550,13 @@ struct ConvAlgorithmTemplate : Components... 
// Fwd algorithm types using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = - ConvAlgorithmTemplate, ConvSpecializationFwd_, Prefetch_>; + ConvAlgorithmTemplate, ConvSpecializationFwd_, Prefetch_, GemmBatchOptions_>; using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = ConvAlgorithmTemplate, ConvSpecializationFwd_, BlockGemm_>; using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle = - ConvAlgorithmTemplate, ConvSpecializationFwd_, GridGemm_, Prefetch_>; + ConvAlgorithmTemplate, ConvSpecializationFwd_, GridGemm_, Prefetch_, GemmBatchOptions_>; using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = ConvAlgorithmTemplate>; using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = - ConvAlgorithmTemplate, ConvSpecializationFwd_, Prefetch_, LargeTensorSpecialization_>; + ConvAlgorithmTemplate, ConvSpecializationFwd_, Prefetch_, GemmBatchOptions_, LargeTensorSpecialization_>; // CK Tile algorithm using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate Date: Mon, 5 Jan 2026 04:45:51 -0500 Subject: [PATCH 49/81] clang-format --- .../builder/conv_algorithm_concepts.hpp | 17 +- .../builder/conv_algorithm_diagnostics.hpp | 1790 +++++++++++------ .../builder/factory/conv_algorithms.hpp | 648 +++--- .../factory/conv_bwd_weight_dl_factory.hpp | 11 +- ...onv_bwd_weight_multi_d_wmma_v3_factory.hpp | 120 +- .../conv_bwd_weight_multi_d_xdl_factory.hpp | 17 +- ...v_bwd_weight_two_stage_wmma_v3_factory.hpp | 122 +- .../conv_bwd_weight_two_stage_xdl_factory.hpp | 29 +- .../factory/conv_bwd_weight_wmma_factory.hpp | 27 +- .../conv_bwd_weight_wmma_v3_factory.hpp | 27 +- .../factory/conv_bwd_weight_xdl_factory.hpp | 17 +- .../conv_bwd_weight_xdl_v3_factory.hpp | 29 +- .../builder/factory/conv_dispatcher.hpp | 30 +- .../builder/factory/conv_fwd_dl_factory.hpp | 8 +- .../factory/conv_fwd_large_tensor_factory.hpp | 16 +- .../builder/factory/conv_fwd_v3_factory.hpp | 8 +- 
.../builder/factory/conv_fwd_wmma_factory.hpp | 8 +- .../builder/factory/conv_fwd_xdl_factory.hpp | 8 +- .../helpers/ck/conv_block_transfer.hpp | 45 +- .../helpers/ck/conv_elementwise_op.hpp | 6 +- .../factory/helpers/ck/conv_tensor_layout.hpp | 14 +- .../factory/helpers/ck/conv_tensor_type.hpp | 4 +- .../factory/helpers/ck/conv_tuning_params.hpp | 8 +- .../conv/ck/test_ckb_conv_bwd_weight_dl.cpp | 10 +- ..._bwd_weight_two_stage_wmma_cshuffle_v3.cpp | 17 +- ...conv_bwd_weight_two_stage_xdl_cshuffle.cpp | 2 +- ...test_ckb_conv_bwd_weight_wmma_cshuffle.cpp | 15 +- ...t_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp | 30 +- ...st_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp | 28 +- .../conv/ck/test_ckb_conv_fwd_1d_bf16.cpp | 2 +- .../test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp | 2 +- .../conv/ck/test_ckb_conv_fwd_2d_bf16.cpp | 3 +- .../conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp | 2 +- .../conv/ck/test_ckb_conv_fwd_2d_fp16.cpp | 2 +- .../conv/ck/test_ckb_conv_fwd_2d_fp32.cpp | 2 +- ...test_ckb_conv_fwd_2d_large_tensor_fp16.cpp | 27 +- .../conv/ck/test_ckb_conv_fwd_3d_fp16.cpp | 2 +- .../conv/ck/test_ckb_conv_fwd_3d_fp32.cpp | 2 +- .../test_ckb_conv_bwd_data_2d_fp16_v3.cpp | 15 +- .../test_ckb_conv_bwd_weight_2d_fp16_v3.cpp | 15 +- .../test/impl/conv_algorithm_types.hpp | 143 +- .../test/test_concept_diagnostics_sync.cpp | 87 +- .../builder/test/test_conv_description.cpp | 16 +- .../test/utils/ckb_conv_test_configs.hpp | 114 +- .../test/utils/conv_algorithm_type_utils.hpp | 15 +- ...bwd_weight_multiple_d_wmma_cshuffle_v3.hpp | 84 +- ...onv_bwd_weight_multiple_d_xdl_cshuffle.hpp | 27 +- 47 files changed, 2201 insertions(+), 1470 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 5a5dfd17b8a..2b0f63296be 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ 
b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -16,7 +16,7 @@ namespace ck_tile::builder { /********************************************************************/ // Common concept for size-related fields -template +template concept SizeType = std::unsigned_integral>; // Concept for thread block dimensions for a GEMM problem. @@ -170,7 +170,7 @@ concept SpecifiesTileThreadBlock = requires { // Concept to check if a struct specifies gridwise XDL GEMM info. template -concept GridwiseFwdXdlGemmDescriptor = requires (T t){ +concept GridwiseFwdXdlGemmDescriptor = requires(T t) { { t.ak1 } -> SizeType; { t.bk1 } -> SizeType; { t.xdl_params } -> GridwiseXdlGemmDescriptor; @@ -178,26 +178,26 @@ concept GridwiseFwdXdlGemmDescriptor = requires (T t){ // Concept to check if a struct specifies gridwise XDL GEMM info. template -concept GridwiseBwdXdlGemmDescriptor = requires (T t){ +concept GridwiseBwdXdlGemmDescriptor = requires(T t) { { t.k1 } -> SizeType; { t.xdl_params } -> GridwiseXdlGemmDescriptor; }; // Concept to check if a struct specifies gridwise XDL GEMM info. template -concept SpecifiesGridwiseFwdXdlGemm = requires (T t) { +concept SpecifiesGridwiseFwdXdlGemm = requires(T t) { { t.gridwise_gemm } -> GridwiseFwdXdlGemmDescriptor; }; // Concept to check if a struct specifies gridwise XDL GEMM info. template -concept SpecifiesGridwiseBwdXdlGemm = requires (T t) { +concept SpecifiesGridwiseBwdXdlGemm = requires(T t) { { t.gridwise_gemm } -> GridwiseBwdXdlGemmDescriptor; }; // Concept to check if a struct specifies gridwise WMMA GEMM info. 
template -concept SpecifiesGridwiseWmmaGemm = requires (T t){ +concept SpecifiesGridwiseWmmaGemm = requires(T t) { { t.gridwise_gemm } -> GridwiseWmmaGemmDescriptor; }; @@ -209,7 +209,7 @@ concept SpecifiesBlockTransfer = requires(T t) { { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; }; -// Concept to check if a struct specifies convolution input and output block transfer info +// Concept to check if a struct specifies convolution input and output block transfer info // for 4D thread slices. template concept SpecifiesBlockTransfer4D = requires(T t) { @@ -255,8 +255,7 @@ concept SpecifiesBlockGemm = requires { }; template -concept SpecifiesGridwiseGemmPipeline = requires -{ +concept SpecifiesGridwiseGemmPipeline = requires { { T::pipeline_version } -> std::convertible_to; }; diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp index 46e205d68bf..6db35a9ba0f 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp @@ -10,373 +10,520 @@ namespace ck_tile::builder::diagnostics { #define CHECK_MARK(cond) (cond ? "[✓]" : "[✗]") // Macro to check a concept and generate both the boolean and the string representation -#define CHECK_CONCEPT(Type, Concept) \ - static constexpr bool c_##Concept = Concept; \ +#define CHECK_CONCEPT(Type, Concept) \ + static constexpr bool c_##Concept = Concept; \ static constexpr const char* s_##Concept = #Concept; // Helper to create diagnostic message line -#define DIAGNOSTIC_LINE(Concept) \ +#define DIAGNOSTIC_LINE(Concept) \ " " + std::string(s_##Concept) + ": " + std::string(CHECK_MARK(c_##Concept)) + "\n" + \ - (c_##Concept ? std::string("") : detailed_diagnostic_##Concept()) + (c_##Concept ? 
std::string("") : detailed_diagnostic_##Concept()) namespace detail { // Helper to get type information -template -consteval auto get_type_info() -> const char* { +template +consteval auto get_type_info() -> const char* +{ // Returns a descriptive string about the type - if constexpr (std::is_same_v) { + if constexpr(std::is_same_v) + { return " (type: size_t)"; - } else if constexpr (std::is_same_v) { + } + else if constexpr(std::is_same_v) + { return " (type: int)"; - } else if constexpr (std::is_same_v) { + } + else if constexpr(std::is_same_v) + { return " (type: bool)"; - } else if constexpr (std::is_same_v) { + } + else if constexpr(std::is_same_v) + { return " (type: PipelineVersion)"; - } else if constexpr (std::is_same_v) { + } + else if constexpr(std::is_same_v) + { return " (type: PipelineScheduler)"; - } else if constexpr (std::is_same_v) { + } + else if constexpr(std::is_same_v) + { return " (type: ConvSpecialization)"; - } else if constexpr (std::is_same_v) { + } + else if constexpr(std::is_same_v) + { return " (type: GemmSpecialization)"; - } else if constexpr (std::is_same_v) { + } + else if constexpr(std::is_same_v) + { return " (type: TileConvSpecialization)"; - } else if constexpr (std::is_same_v) { + } + else if constexpr(std::is_same_v) + { return " (type: ConvAlgorithmSpecialization)"; - } else if constexpr (std::is_same_v>) { + } + else if constexpr(std::is_same_v>) + { return " (type: std::array)"; - } else if constexpr (std::is_same_v>) { + } + else if constexpr(std::is_same_v>) + { return " (type: std::array)"; - } else if constexpr (std::is_same_v>) { + } + else if constexpr(std::is_same_v>) + { return " (type: std::array)"; - } else if constexpr (std::is_same_v>) { + } + else if constexpr(std::is_same_v>) + { return " (type: std::array)"; - } else { + } + else + { return " (type: found but unknown)"; } } // ThreadBlockDescriptor diagnostics template -consteval auto diagnose_thread_block_descriptor() -> std::string { - if constexpr 
(!requires { T::thread_block; }) { +consteval auto diagnose_thread_block_descriptor() -> std::string +{ + if constexpr(!requires { T::thread_block; }) + { return " → T::thread_block member: [✗] (missing member)\n"; - } else { + } + else + { using TB = decltype(T::thread_block); std::string msg; - - if constexpr (requires(TB t) { t.block_size; }) { - using BlockSizeType = decltype(std::declval().block_size); + + if constexpr(requires(TB t) { t.block_size; }) + { + using BlockSizeType = decltype(std::declval().block_size); constexpr bool convertible = SizeType; - msg += " → thread_block.block_size: " + std::string(CHECK_MARK(convertible)) + + msg += " → thread_block.block_size: " + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(get_type_info())) + "\n"; - } else { + } + else + { msg += " → thread_block.block_size: [✗] (missing member)\n"; } - - if constexpr (requires(TB t) { t.tile_size.m; }) { - using TileMType = decltype(std::declval().tile_size.m); + + if constexpr(requires(TB t) { t.tile_size.m; }) + { + using TileMType = decltype(std::declval().tile_size.m); constexpr bool convertible = SizeType; - msg += " → thread_block.tile_size.m: " + std::string(CHECK_MARK(convertible)) + + msg += " → thread_block.tile_size.m: " + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(get_type_info())) + "\n"; - } else { + } + else + { msg += " → thread_block.tile_size.m: [✗] (missing member)\n"; } - - if constexpr (requires(TB t) { t.tile_size.n; }) { - using TileNType = decltype(std::declval().tile_size.n); + + if constexpr(requires(TB t) { t.tile_size.n; }) + { + using TileNType = decltype(std::declval().tile_size.n); constexpr bool convertible = SizeType; - msg += " → thread_block.tile_size.n: " + std::string(CHECK_MARK(convertible)) + + msg += " → thread_block.tile_size.n: " + std::string(CHECK_MARK(convertible)) + (convertible ? 
"" : std::string(get_type_info())) + "\n"; - } else { + } + else + { msg += " → thread_block.tile_size.n: [✗] (missing member)\n"; } - - if constexpr (requires(TB t) { t.tile_size.k; }) { - using TileKType = decltype(std::declval().tile_size.k); + + if constexpr(requires(TB t) { t.tile_size.k; }) + { + using TileKType = decltype(std::declval().tile_size.k); constexpr bool convertible = SizeType; - msg += " → thread_block.tile_size.k: " + std::string(CHECK_MARK(convertible)) + + msg += " → thread_block.tile_size.k: " + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(get_type_info())) + "\n"; - } else { + } + else + { msg += " → thread_block.tile_size.k: [✗] (missing member)\n"; } - + return msg; } } // GridwiseXdlGemmDescriptor diagnostics template -consteval auto diagnose_xdl_params() -> std::string { +consteval auto diagnose_xdl_params() -> std::string +{ std::string msg; - - if constexpr (requires(XdlParams t) { t.m_per_xdl; }) { - using MPerXdlType = decltype(std::declval().m_per_xdl); + + if constexpr(requires(XdlParams t) { t.m_per_xdl; }) + { + using MPerXdlType = decltype(std::declval().m_per_xdl); constexpr bool convertible = SizeType; - msg += " → xdl_params.m_per_xdl: " + std::string(CHECK_MARK(convertible)) + + msg += " → xdl_params.m_per_xdl: " + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(get_type_info())) + "\n"; - } else { + } + else + { msg += " → xdl_params.m_per_xdl: [✗] (missing member)\n"; } - - if constexpr (requires(XdlParams t) { t.n_per_xdl; }) { - using NPerXdlType = decltype(std::declval().n_per_xdl); + + if constexpr(requires(XdlParams t) { t.n_per_xdl; }) + { + using NPerXdlType = decltype(std::declval().n_per_xdl); constexpr bool convertible = SizeType; - msg += " → xdl_params.n_per_xdl: " + std::string(CHECK_MARK(convertible)) + + msg += " → xdl_params.n_per_xdl: " + std::string(CHECK_MARK(convertible)) + (convertible ? 
"" : std::string(get_type_info())) + "\n"; - } else { + } + else + { msg += " → xdl_params.n_per_xdl: [✗] (missing member)\n"; } - - if constexpr (requires(XdlParams t) { t.m_xdl_per_wave; }) { - using MXdlPerWaveType = decltype(std::declval().m_xdl_per_wave); + + if constexpr(requires(XdlParams t) { t.m_xdl_per_wave; }) + { + using MXdlPerWaveType = decltype(std::declval().m_xdl_per_wave); constexpr bool convertible = SizeType; - msg += " → xdl_params.m_xdl_per_wave: " + std::string(CHECK_MARK(convertible)) + + msg += " → xdl_params.m_xdl_per_wave: " + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(get_type_info())) + "\n"; - } else { + } + else + { msg += " → xdl_params.m_xdl_per_wave: [✗] (missing member)\n"; } - - if constexpr (requires(XdlParams t) { t.n_xdl_per_wave; }) { - using NXdlPerWaveType = decltype(std::declval().n_xdl_per_wave); + + if constexpr(requires(XdlParams t) { t.n_xdl_per_wave; }) + { + using NXdlPerWaveType = decltype(std::declval().n_xdl_per_wave); constexpr bool convertible = SizeType; - msg += " → xdl_params.n_xdl_per_wave: " + std::string(CHECK_MARK(convertible)) + + msg += " → xdl_params.n_xdl_per_wave: " + std::string(CHECK_MARK(convertible)) + (convertible ? 
"" : std::string(get_type_info())) + "\n"; - } else { + } + else + { msg += " → xdl_params.n_xdl_per_wave: [✗] (missing member)\n"; } - + return msg; } // BlockTransferDescriptor diagnostics template -consteval auto diagnose_block_transfer(const char* prefix) -> std::string { +consteval auto diagnose_block_transfer(const char* prefix) -> std::string +{ std::string msg; - - if constexpr (requires(BT t) { t.k0; }) { - using K0Type = decltype(std::declval().k0); + + if constexpr(requires(BT t) { t.k0; }) + { + using K0Type = decltype(std::declval().k0); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".k0: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + + ".k0: " + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(get_type_info())) + "\n"; - } else { + } + else + { msg += std::string(" → ") + prefix + ".k0: [✗] (missing member)\n"; } - - if constexpr (requires(BT t) { t.m_n; }) { - using MNType = decltype(std::declval().m_n); + + if constexpr(requires(BT t) { t.m_n; }) + { + using MNType = decltype(std::declval().m_n); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".m_n: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + + ".m_n: " + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(get_type_info())) + "\n"; - } else { + } + else + { msg += std::string(" → ") + prefix + ".m_n: [✗] (missing member)\n"; } - - if constexpr (requires(BT t) { t.k1; }) { - using K1Type = decltype(std::declval().k1); + + if constexpr(requires(BT t) { t.k1; }) + { + using K1Type = decltype(std::declval().k1); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".k1: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + + ".k1: " + std::string(CHECK_MARK(convertible)) + (convertible ? 
"" : std::string(get_type_info())) + "\n"; - } else { + } + else + { msg += std::string(" → ") + prefix + ".k1: [✗] (missing member)\n"; } - + return msg; } // BlockTransferDescriptor4D diagnostics (requires k_batch_size) template -consteval auto diagnose_block_transfer_4d(const char* prefix) -> std::string { +consteval auto diagnose_block_transfer_4d(const char* prefix) -> std::string +{ std::string msg; - - if constexpr (requires(BT t) { t.k0; }) { - using K0Type = decltype(std::declval().k0); + + if constexpr(requires(BT t) { t.k0; }) + { + using K0Type = decltype(std::declval().k0); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".k0: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + + ".k0: " + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(get_type_info())) + "\n"; - } else { + } + else + { msg += std::string(" → ") + prefix + ".k0: [✗] (missing member)\n"; } - - if constexpr (requires(BT t) { t.m_n; }) { - using MNType = decltype(std::declval().m_n); + + if constexpr(requires(BT t) { t.m_n; }) + { + using MNType = decltype(std::declval().m_n); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".m_n: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + + ".m_n: " + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(get_type_info())) + "\n"; - } else { + } + else + { msg += std::string(" → ") + prefix + ".m_n: [✗] (missing member)\n"; } - - if constexpr (requires(BT t) { t.k1; }) { - using K1Type = decltype(std::declval().k1); + + if constexpr(requires(BT t) { t.k1; }) + { + using K1Type = decltype(std::declval().k1); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".k1: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + + ".k1: " + std::string(CHECK_MARK(convertible)) + (convertible ? 
"" : std::string(get_type_info())) + "\n"; - } else { + } + else + { msg += std::string(" → ") + prefix + ".k1: [✗] (missing member)\n"; } - + // k_batch_size is required for Bwd descriptor - if constexpr (requires(BT t) { t.k_batch_size; }) { - using KBatchType = decltype(std::declval().k_batch_size); + if constexpr(requires(BT t) { t.k_batch_size; }) + { + using KBatchType = decltype(std::declval().k_batch_size); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".k_batch_size: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + + ".k_batch_size: " + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(get_type_info())) + "\n"; - } else { + } + else + { msg += std::string(" → ") + prefix + ".k_batch_size: [✗] (missing member)\n"; } - + return msg; } // LdsTransferDescriptor diagnostics template -consteval auto diagnose_lds_transfer(const char* prefix) -> std::string { +consteval auto diagnose_lds_transfer(const char* prefix) -> std::string +{ std::string msg; - - if constexpr (requires(LT t) { t.src_vector_dim; }) { - using SrcVectorDimType = decltype(std::declval().src_vector_dim); + + if constexpr(requires(LT t) { t.src_vector_dim; }) + { + using SrcVectorDimType = decltype(std::declval().src_vector_dim); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".src_vector_dim: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + + ".src_vector_dim: " + std::string(CHECK_MARK(convertible)) + (convertible ? 
"" : std::string(get_type_info())) + "\n"; - } else { + } + else + { msg += std::string(" → ") + prefix + ".src_vector_dim: [✗] (missing member)\n"; } - - if constexpr (requires(LT t) { t.src_scalar_per_vector; }) { - using SrcScalarType = decltype(std::declval().src_scalar_per_vector); + + if constexpr(requires(LT t) { t.src_scalar_per_vector; }) + { + using SrcScalarType = decltype(std::declval().src_scalar_per_vector); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".src_scalar_per_vector: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + + ".src_scalar_per_vector: " + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(get_type_info())) + "\n"; - } else { - msg += std::string(" → ") + prefix + ".src_scalar_per_vector: [✗] (missing member)\n"; } - - if constexpr (requires(LT t) { t.lds_dst_scalar_per_vector; }) { - using LdsDstScalarType = decltype(std::declval().lds_dst_scalar_per_vector); + else + { + msg += + std::string(" → ") + prefix + ".src_scalar_per_vector: [✗] (missing member)\n"; + } + + if constexpr(requires(LT t) { t.lds_dst_scalar_per_vector; }) + { + using LdsDstScalarType = decltype(std::declval().lds_dst_scalar_per_vector); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".lds_dst_scalar_per_vector: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + + ".lds_dst_scalar_per_vector: " + std::string(CHECK_MARK(convertible)) + (convertible ? 
"" : std::string(get_type_info())) + "\n"; - } else { - msg += std::string(" → ") + prefix + ".lds_dst_scalar_per_vector: [✗] (missing member)\n"; } - - if constexpr (requires(LT t) { t.is_direct_load; }) { - using IsDirectLoadType = decltype(std::declval().is_direct_load); + else + { + msg += std::string(" → ") + prefix + + ".lds_dst_scalar_per_vector: [✗] (missing member)\n"; + } + + if constexpr(requires(LT t) { t.is_direct_load; }) + { + using IsDirectLoadType = decltype(std::declval().is_direct_load); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".is_direct_load: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + + ".is_direct_load: " + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(get_type_info())) + "\n"; - } else { + } + else + { msg += std::string(" → ") + prefix + ".is_direct_load: [✗] (missing member)\n"; } - - if constexpr (requires(LT t) { t.lds_padding; }) { - using LdsPaddingType = decltype(std::declval().lds_padding); + + if constexpr(requires(LT t) { t.lds_padding; }) + { + using LdsPaddingType = decltype(std::declval().lds_padding); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".lds_padding: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + + ".lds_padding: " + std::string(CHECK_MARK(convertible)) + (convertible ? 
"" : std::string(get_type_info())) + "\n"; - } else { + } + else + { msg += std::string(" → ") + prefix + ".lds_padding: [✗] (missing member)\n"; } - + return msg; } // ThreadClusterDescriptor diagnostics template -consteval auto diagnose_thread_cluster(const char* prefix) -> std::string { +consteval auto diagnose_thread_cluster(const char* prefix) -> std::string +{ std::string msg; - - if constexpr (requires(TC t) { t.m_block; }) { - using MBlockType = decltype(std::declval().m_block); + + if constexpr(requires(TC t) { t.m_block; }) + { + using MBlockType = decltype(std::declval().m_block); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".m_block: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + + ".m_block: " + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(get_type_info())) + "\n"; - } else { + } + else + { msg += std::string(" → ") + prefix + ".m_block: [✗] (missing member)\n"; } - - if constexpr (requires(TC t) { t.m_wave_per_xdl; }) { - using MWaveType = decltype(std::declval().m_wave_per_xdl); + + if constexpr(requires(TC t) { t.m_wave_per_xdl; }) + { + using MWaveType = decltype(std::declval().m_wave_per_xdl); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".m_wave_per_xdl: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + + ".m_wave_per_xdl: " + std::string(CHECK_MARK(convertible)) + (convertible ? 
"" : std::string(get_type_info())) + "\n"; - } else { + } + else + { msg += std::string(" → ") + prefix + ".m_wave_per_xdl: [✗] (missing member)\n"; } - - if constexpr (requires(TC t) { t.n_block; }) { - using NBlockType = decltype(std::declval().n_block); + + if constexpr(requires(TC t) { t.n_block; }) + { + using NBlockType = decltype(std::declval().n_block); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".n_block: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + + ".n_block: " + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(get_type_info())) + "\n"; - } else { + } + else + { msg += std::string(" → ") + prefix + ".n_block: [✗] (missing member)\n"; } - - if constexpr (requires(TC t) { t.n_wave_per_xdl; }) { - using NWaveType = decltype(std::declval().n_wave_per_xdl); + + if constexpr(requires(TC t) { t.n_wave_per_xdl; }) + { + using NWaveType = decltype(std::declval().n_wave_per_xdl); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".n_wave_per_xdl: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + + ".n_wave_per_xdl: " + std::string(CHECK_MARK(convertible)) + (convertible ? 
"" : std::string(get_type_info())) + "\n"; - } else { + } + else + { msg += std::string(" → ") + prefix + ".n_wave_per_xdl: [✗] (missing member)\n"; } - + return msg; } // AccessOrderDescriptor diagnostics template -consteval auto diagnose_access_order(const char* prefix) -> std::string { +consteval auto diagnose_access_order(const char* prefix) -> std::string +{ std::string msg; - - if constexpr (requires(AO t) { t.order; }) { - using OrderType = decltype(std::declval().order); + + if constexpr(requires(AO t) { t.order; }) + { + using OrderType = decltype(std::declval().order); constexpr bool convertible_3 = std::convertible_to>; constexpr bool convertible_4 = std::convertible_to>; - constexpr bool convertible = convertible_3 || convertible_4; - msg += std::string(" → ") + prefix + ".order: " + std::string(CHECK_MARK(convertible)) + + constexpr bool convertible = convertible_3 || convertible_4; + msg += std::string(" → ") + prefix + + ".order: " + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(get_type_info())) + "\n"; - } else { + } + else + { msg += std::string(" → ") + prefix + ".order: [✗] (missing member)\n"; } - + return msg; } // EpilogueDescriptor diagnostics template -consteval auto diagnose_epilogue(const char* prefix) -> std::string { +consteval auto diagnose_epilogue(const char* prefix) -> std::string +{ std::string msg; - - if constexpr (requires(E t) { t.m_xdl_per_wave_per_shuffle; }) { - using MXdlType = decltype(std::declval().m_xdl_per_wave_per_shuffle); + + if constexpr(requires(E t) { t.m_xdl_per_wave_per_shuffle; }) + { + using MXdlType = decltype(std::declval().m_xdl_per_wave_per_shuffle); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".m_xdl_per_wave_per_shuffle: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + + ".m_xdl_per_wave_per_shuffle: " + std::string(CHECK_MARK(convertible)) + (convertible ? 
"" : std::string(get_type_info())) + "\n"; - } else { - msg += std::string(" → ") + prefix + ".m_xdl_per_wave_per_shuffle: [✗] (missing member)\n"; } - - if constexpr (requires(E t) { t.n_per_wave_per_shuffle; }) { - using NPerWaveType = decltype(std::declval().n_per_wave_per_shuffle); + else + { + msg += std::string(" → ") + prefix + + ".m_xdl_per_wave_per_shuffle: [✗] (missing member)\n"; + } + + if constexpr(requires(E t) { t.n_per_wave_per_shuffle; }) + { + using NPerWaveType = decltype(std::declval().n_per_wave_per_shuffle); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".n_per_wave_per_shuffle: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + + ".n_per_wave_per_shuffle: " + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(get_type_info())) + "\n"; - } else { - msg += std::string(" → ") + prefix + ".n_per_wave_per_shuffle: [✗] (missing member)\n"; } - - if constexpr (requires(E t) { t.scalar_per_vector; }) { - using ScalarType = decltype(std::declval().scalar_per_vector); + else + { + msg += std::string(" → ") + prefix + + ".n_per_wave_per_shuffle: [✗] (missing member)\n"; + } + + if constexpr(requires(E t) { t.scalar_per_vector; }) + { + using ScalarType = decltype(std::declval().scalar_per_vector); constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + ".scalar_per_vector: " + std::string(CHECK_MARK(convertible)) + + msg += std::string(" → ") + prefix + + ".scalar_per_vector: " + std::string(CHECK_MARK(convertible)) + (convertible ? 
"" : std::string(get_type_info())) + "\n"; - } else { + } + else + { msg += std::string(" → ") + prefix + ".scalar_per_vector: [✗] (missing member)\n"; } - + return msg; } @@ -384,398 +531,544 @@ consteval auto diagnose_epilogue(const char* prefix) -> std::string { // Detailed diagnostic functions for high-level concepts template -consteval auto detailed_diagnostic_ConvAlgorithmDescriptor() -> std::string { +consteval auto detailed_diagnostic_ConvAlgorithmDescriptor() -> std::string +{ return ""; // Base concept, no sub-requirements to check } template -consteval auto detailed_diagnostic_SpecifiesThreadBlock() -> std::string { - if constexpr (!requires { { T::thread_block } -> ThreadBlockDescriptor; }) { +consteval auto detailed_diagnostic_SpecifiesThreadBlock() -> std::string +{ + if constexpr(!requires { + { T::thread_block } -> ThreadBlockDescriptor; + }) + { return " → T::thread_block member: [✗] (missing or wrong type)\n"; - } else { - return " → T::thread_block member: [✓]\n" + + } + else + { + return " → T::thread_block member: [✓]\n" + detail::diagnose_thread_block_descriptor(); } } template -consteval auto detailed_diagnostic_SpecifiesGridwiseFwdXdlGemm() -> std::string { +consteval auto detailed_diagnostic_SpecifiesGridwiseFwdXdlGemm() -> std::string +{ std::string msg; - - if constexpr (!requires(T t) { { t.gridwise_gemm } -> GridwiseFwdXdlGemmDescriptor; }) { + + if constexpr(!requires(T t) { + { t.gridwise_gemm } -> GridwiseFwdXdlGemmDescriptor; + }) + { return " → T::gridwise_gemm member: [✗] (missing or wrong type)\n"; } - + msg += " → T::gridwise_gemm member: [✓]\n"; using GG = decltype(T::gridwise_gemm); - - if constexpr (requires(GG t) { t.ak1; }) { - using AK1Type = decltype(std::declval().ak1); + + if constexpr(requires(GG t) { t.ak1; }) + { + using AK1Type = decltype(std::declval().ak1); constexpr bool convertible = std::convertible_to; - msg += " → gridwise_gemm.ak1: " + std::string(CHECK_MARK(convertible)) + + msg += " → gridwise_gemm.ak1: " 
+ std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(detail::get_type_info())) + "\n"; - } else { + } + else + { msg += " → gridwise_gemm.ak1: [✗] (missing member)\n"; } - - if constexpr (requires(GG t) { t.bk1; }) { - using BK1Type = decltype(std::declval().bk1); + + if constexpr(requires(GG t) { t.bk1; }) + { + using BK1Type = decltype(std::declval().bk1); constexpr bool convertible = std::convertible_to; - msg += " → gridwise_gemm.bk1: " + std::string(CHECK_MARK(convertible)) + + msg += " → gridwise_gemm.bk1: " + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(detail::get_type_info())) + "\n"; - } else { + } + else + { msg += " → gridwise_gemm.bk1: [✗] (missing member)\n"; } - - if constexpr (requires(GG t) { t.xdl_params; }) { + + if constexpr(requires(GG t) { t.xdl_params; }) + { msg += " → gridwise_gemm.xdl_params member: [✓]\n"; msg += detail::diagnose_xdl_params().xdl_params)>(); - } else { + } + else + { msg += " → gridwise_gemm.xdl_params: [✗] (missing member)\n"; } - + return msg; } template -consteval auto detailed_diagnostic_SpecifiesGridwiseBwdXdlGemm() -> std::string { +consteval auto detailed_diagnostic_SpecifiesGridwiseBwdXdlGemm() -> std::string +{ std::string msg; - - if constexpr (!requires(T t) { { t.gridwise_gemm } -> GridwiseBwdXdlGemmDescriptor; }) { + + if constexpr(!requires(T t) { + { t.gridwise_gemm } -> GridwiseBwdXdlGemmDescriptor; + }) + { return " → T::gridwise_gemm member: [✗] (missing or wrong type)\n"; } - + msg += " → T::gridwise_gemm member: [✓]\n"; using GG = decltype(T::gridwise_gemm); - - if constexpr (requires(GG t) { t.k1; }) { - using K1Type = decltype(std::declval().k1); + + if constexpr(requires(GG t) { t.k1; }) + { + using K1Type = decltype(std::declval().k1); constexpr bool convertible = std::convertible_to; - msg += " → gridwise_gemm.k1: " + std::string(CHECK_MARK(convertible)) + + msg += " → gridwise_gemm.k1: " + std::string(CHECK_MARK(convertible)) + (convertible ? 
"" : std::string(detail::get_type_info())) + "\n"; - } else { + } + else + { msg += " → gridwise_gemm.k1: [✗] (missing member)\n"; } - - if constexpr (requires(GG t) { t.xdl_params; }) { + + if constexpr(requires(GG t) { t.xdl_params; }) + { msg += " → gridwise_gemm.xdl_params member: [✓]\n"; msg += detail::diagnose_xdl_params().xdl_params)>(); - } else { + } + else + { msg += " → gridwise_gemm.xdl_params: [✗] (missing member)\n"; } - + return msg; } template -consteval auto detailed_diagnostic_SpecifiesBlockTransfer() -> std::string { +consteval auto detailed_diagnostic_SpecifiesBlockTransfer() -> std::string +{ std::string msg; - + constexpr bool has_transfer = requires { T::transfer; }; msg += " → T::transfer member: " + std::string(CHECK_MARK(has_transfer)) + "\n"; - - if constexpr (!has_transfer) { + + if constexpr(!has_transfer) + { return msg; } - - constexpr bool has_a = requires { { T::transfer.a.block_transfer } -> BlockTransferDescriptor; }; + + constexpr bool has_a = requires { + { T::transfer.a.block_transfer } -> BlockTransferDescriptor; + }; msg += " → T::transfer.a: " + std::string(CHECK_MARK(has_a)) + "\n"; - if constexpr (!has_a) { + if constexpr(!has_a) + { msg += " → T::transfer.a.block_transfer: [✗] (missing or wrong type)\n"; } - - constexpr bool has_b = requires { { T::transfer.b.block_transfer } -> BlockTransferDescriptor; }; + + constexpr bool has_b = requires { + { T::transfer.b.block_transfer } -> BlockTransferDescriptor; + }; msg += " → T::transfer.b: " + std::string(CHECK_MARK(has_b)) + "\n"; - if constexpr (!has_b) { + if constexpr(!has_b) + { msg += " → T::transfer.b.block_transfer: [✗] (missing or wrong type)\n"; } - - constexpr bool has_c = requires { { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; }; + + constexpr bool has_c = requires { + { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; + }; msg += " → T::transfer.c: " + std::string(CHECK_MARK(has_c)) + "\n"; - if constexpr (!has_c) { + if 
constexpr(!has_c) + { msg += " → T::transfer.c.thread_cluster_dims: [✗] (missing or wrong type)\n"; } - + return msg; } template -consteval auto detailed_diagnostic_SpecifiesBlockTransfer4D() -> std::string { +consteval auto detailed_diagnostic_SpecifiesBlockTransfer4D() -> std::string +{ std::string msg; - + constexpr bool has_transfer = requires { T::transfer; }; msg += " → T::transfer member: " + std::string(CHECK_MARK(has_transfer)) + "\n"; - - if constexpr (!has_transfer) { + + if constexpr(!has_transfer) + { return msg; } - - constexpr bool has_a = requires { { T::transfer.a.block_transfer } -> BlockTransferDescriptor4D; }; + + constexpr bool has_a = requires { + { T::transfer.a.block_transfer } -> BlockTransferDescriptor4D; + }; msg += " → T::transfer.a: " + std::string(CHECK_MARK(has_a)) + "\n"; - if constexpr (!has_a) { + if constexpr(!has_a) + { msg += " → T::transfer.a.block_transfer: [✗] (missing or wrong type)\n"; - } else { - msg += detail::diagnose_block_transfer_4d("transfer.a.block_transfer"); } - - constexpr bool has_b = requires { { T::transfer.b.block_transfer } -> BlockTransferDescriptor4D; }; + else + { + msg += detail::diagnose_block_transfer_4d( + "transfer.a.block_transfer"); + } + + constexpr bool has_b = requires { + { T::transfer.b.block_transfer } -> BlockTransferDescriptor4D; + }; msg += " → T::transfer.b: " + std::string(CHECK_MARK(has_b)) + "\n"; - if constexpr (!has_b) { + if constexpr(!has_b) + { msg += " → T::transfer.b.block_transfer: [✗] (missing or wrong type)\n"; - } else { - msg += detail::diagnose_block_transfer_4d("transfer.b.block_transfer"); } - - constexpr bool has_c = requires { { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; }; + else + { + msg += detail::diagnose_block_transfer_4d( + "transfer.b.block_transfer"); + } + + constexpr bool has_c = requires { + { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; + }; msg += " → T::transfer.c: " + std::string(CHECK_MARK(has_c)) + "\n"; - if 
constexpr (!has_c) { + if constexpr(!has_c) + { msg += " → T::transfer.c.thread_cluster_dims: [✗] (missing or wrong type)\n"; } - + return msg; } template -consteval auto detailed_diagnostic_SpecifiesThreadClusterAccessOrder() -> std::string { +consteval auto detailed_diagnostic_SpecifiesThreadClusterAccessOrder() -> std::string +{ std::string msg; - + constexpr bool has_transfer = requires { T::transfer; }; - if constexpr (!has_transfer) { + if constexpr(!has_transfer) + { return " → T::transfer member: [✗] (missing member)\n"; } - + constexpr bool has_a = requires { T::transfer.a; }; constexpr bool has_b = requires { T::transfer.b; }; - - if constexpr (has_a && requires { T::transfer.a.block_transfer_access_order; }) { - msg += detail::diagnose_access_order("transfer.a.block_transfer_access_order"); - } else if constexpr (has_a) { + + if constexpr(has_a && requires { T::transfer.a.block_transfer_access_order; }) + { + msg += + detail::diagnose_access_order( + "transfer.a.block_transfer_access_order"); + } + else if constexpr(has_a) + { msg += " → T::transfer.a.block_transfer_access_order: [✗] (missing member)\n"; } - - if constexpr (has_b && requires { T::transfer.b.block_transfer_access_order; }) { - msg += detail::diagnose_access_order("transfer.b.block_transfer_access_order"); - } else if constexpr (has_b) { + + if constexpr(has_b && requires { T::transfer.b.block_transfer_access_order; }) + { + msg += + detail::diagnose_access_order( + "transfer.b.block_transfer_access_order"); + } + else if constexpr(has_b) + { msg += " → T::transfer.b.block_transfer_access_order: [✗] (missing member)\n"; } - + return msg; } template -consteval auto detailed_diagnostic_SpecifiesSourceAccessOrder() -> std::string { +consteval auto detailed_diagnostic_SpecifiesSourceAccessOrder() -> std::string +{ std::string msg; - + constexpr bool has_transfer = requires { T::transfer; }; - if constexpr (!has_transfer) { + if constexpr(!has_transfer) + { return " → T::transfer member: [✗] 
(missing member)\n"; } - + constexpr bool has_a = requires { T::transfer.a; }; constexpr bool has_b = requires { T::transfer.b; }; - - if constexpr (has_a && requires { T::transfer.a.src_access_order; }) { - msg += detail::diagnose_access_order("transfer.a.src_access_order"); - } else if constexpr (has_a) { + + if constexpr(has_a && requires { T::transfer.a.src_access_order; }) + { + msg += detail::diagnose_access_order( + "transfer.a.src_access_order"); + } + else if constexpr(has_a) + { msg += " → T::transfer.a.src_access_order: [✗] (missing member)\n"; } - - if constexpr (has_b && requires { T::transfer.b.src_access_order; }) { - msg += detail::diagnose_access_order("transfer.b.src_access_order"); - } else if constexpr (has_b) { + + if constexpr(has_b && requires { T::transfer.b.src_access_order; }) + { + msg += detail::diagnose_access_order( + "transfer.b.src_access_order"); + } + else if constexpr(has_b) + { msg += " → T::transfer.b.src_access_order: [✗] (missing member)\n"; } - + return msg; } template -consteval auto detailed_diagnostic_SpecifiesBlockGemm() -> std::string { +consteval auto detailed_diagnostic_SpecifiesBlockGemm() -> std::string +{ std::string msg; - - if constexpr (!requires { { T::block_gemm_pipeline } -> BlockGemmPipelineDescriptor; }) { + + if constexpr(!requires { + { T::block_gemm_pipeline } -> BlockGemmPipelineDescriptor; + }) + { return " → T::block_gemm_pipeline: [✗] (missing or wrong type)\n"; } - + msg += " → T::block_gemm_pipeline member: [✓]\n"; - - if constexpr (requires { T::block_gemm_pipeline.pipeline_version; }) { - using PipelineType = decltype(T::block_gemm_pipeline.pipeline_version); + + if constexpr(requires { T::block_gemm_pipeline.pipeline_version; }) + { + using PipelineType = decltype(T::block_gemm_pipeline.pipeline_version); constexpr bool convertible = std::convertible_to; - msg += " → block_gemm_pipeline.pipeline_version: " + std::string(CHECK_MARK(convertible)) + + msg += " → block_gemm_pipeline.pipeline_version: 
" + + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(detail::get_type_info())) + "\n"; - } else { + } + else + { msg += " → block_gemm_pipeline.pipeline_version: [✗] (missing member)\n"; } - - if constexpr (requires { T::block_gemm_pipeline.scheduler; }) { - using SchedulerType = decltype(T::block_gemm_pipeline.scheduler); + + if constexpr(requires { T::block_gemm_pipeline.scheduler; }) + { + using SchedulerType = decltype(T::block_gemm_pipeline.scheduler); constexpr bool convertible = std::convertible_to; - msg += " → block_gemm_pipeline.scheduler: " + std::string(CHECK_MARK(convertible)) + + msg += " → block_gemm_pipeline.scheduler: " + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(detail::get_type_info())) + "\n"; - } else { + } + else + { msg += " → block_gemm_pipeline.scheduler: [✗] (missing member)\n"; } - + return msg; } template -consteval auto detailed_diagnostic_SpecifiesFwdConvSpecialization() -> std::string { - if constexpr (requires { T::fwd_specialization; }) { - using FwdSpecType = decltype(T::fwd_specialization); +consteval auto detailed_diagnostic_SpecifiesFwdConvSpecialization() -> std::string +{ + if constexpr(requires { T::fwd_specialization; }) + { + using FwdSpecType = decltype(T::fwd_specialization); constexpr bool convertible = std::convertible_to; - return " → T::fwd_specialization: " + std::string(CHECK_MARK(convertible)) + + return " → T::fwd_specialization: " + std::string(CHECK_MARK(convertible)) + (convertible ? 
"" : std::string(detail::get_type_info())) + "\n"; - } else { + } + else + { return " → T::fwd_specialization: [✗] (missing member)\n"; } } template -consteval auto detailed_diagnostic_SpecifiesBwdWeightConvSpecialization() -> std::string { - if constexpr (requires { T::bwd_weight_specialization; }) { - using BwdSpecType = decltype(T::bwd_weight_specialization); +consteval auto detailed_diagnostic_SpecifiesBwdWeightConvSpecialization() -> std::string +{ + if constexpr(requires { T::bwd_weight_specialization; }) + { + using BwdSpecType = decltype(T::bwd_weight_specialization); constexpr bool convertible = std::convertible_to; - return " → T::bwd_weight_specialization: " + std::string(CHECK_MARK(convertible)) + + return " → T::bwd_weight_specialization: " + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(detail::get_type_info())) + "\n"; - } else { + } + else + { return " → T::bwd_weight_specialization: [✗] (missing member)\n"; } } template -consteval auto detailed_diagnostic_SpecifiesGemmSpecialization() -> std::string { - if constexpr (requires { T::gemm_specialization; }) { - using GemmSpecType = decltype(T::gemm_specialization); +consteval auto detailed_diagnostic_SpecifiesGemmSpecialization() -> std::string +{ + if constexpr(requires { T::gemm_specialization; }) + { + using GemmSpecType = decltype(T::gemm_specialization); constexpr bool convertible = std::convertible_to; - return " → T::gemm_specialization: " + std::string(CHECK_MARK(convertible)) + + return " → T::gemm_specialization: " + std::string(CHECK_MARK(convertible)) + (convertible ? 
"" : std::string(detail::get_type_info())) + "\n"; - } else { + } + else + { return " → T::gemm_specialization: [✗] (missing member)\n"; } } template -consteval auto detailed_diagnostic_SpecifiesNumPrefetchStages() -> std::string { - if constexpr (requires { T::num_gemm_k_prefetch_stages; }) { - using NumPrefetchType = decltype(T::num_gemm_k_prefetch_stages); +consteval auto detailed_diagnostic_SpecifiesNumPrefetchStages() -> std::string +{ + if constexpr(requires { T::num_gemm_k_prefetch_stages; }) + { + using NumPrefetchType = decltype(T::num_gemm_k_prefetch_stages); constexpr bool convertible = std::convertible_to; - return " → T::num_gemm_k_prefetch_stages: " + std::string(CHECK_MARK(convertible)) + + return " → T::num_gemm_k_prefetch_stages: " + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(detail::get_type_info())) + "\n"; - } else { + } + else + { return " → T::num_gemm_k_prefetch_stages: [✗] (missing member)\n"; } } template -consteval auto detailed_diagnostic_SpecifiesNumGroupsToMerge() -> std::string { - if constexpr (requires { T::num_conv_groups_to_merge; }) { - using NumGroupsType = decltype(T::num_conv_groups_to_merge); +consteval auto detailed_diagnostic_SpecifiesNumGroupsToMerge() -> std::string +{ + if constexpr(requires { T::num_conv_groups_to_merge; }) + { + using NumGroupsType = decltype(T::num_conv_groups_to_merge); constexpr bool convertible = std::convertible_to; - return " → T::num_conv_groups_to_merge: " + std::string(CHECK_MARK(convertible)) + + return " → T::num_conv_groups_to_merge: " + std::string(CHECK_MARK(convertible)) + (convertible ? 
"" : std::string(detail::get_type_info())) + "\n"; - } else { + } + else + { return " → T::num_conv_groups_to_merge: [✗] (missing member)\n"; } } template -consteval auto detailed_diagnostic_SpecifiesLoopScheduler() -> std::string { - if constexpr (requires { T::loop_scheduler; }) { - using LoopSchedulerType = decltype(T::loop_scheduler); +consteval auto detailed_diagnostic_SpecifiesLoopScheduler() -> std::string +{ + if constexpr(requires { T::loop_scheduler; }) + { + using LoopSchedulerType = decltype(T::loop_scheduler); constexpr bool convertible = std::convertible_to; - return " → T::loop_scheduler: " + std::string(CHECK_MARK(convertible)) + + return " → T::loop_scheduler: " + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(detail::get_type_info())) + "\n"; - } else { + } + else + { return " → T::loop_scheduler: [✗] (missing member)\n"; } } template -consteval auto detailed_diagnostic_SpecifiesLargeTensorSupport() -> std::string { +consteval auto detailed_diagnostic_SpecifiesLargeTensorSupport() -> std::string +{ std::string msg; - if constexpr (requires { T::specialization; }) { - using SpecType = decltype(T::specialization); + if constexpr(requires { T::specialization; }) + { + using SpecType = decltype(T::specialization); constexpr bool convertible = std::convertible_to; - msg += " → T::specialization: " + std::string(CHECK_MARK(convertible)) + + msg += " → T::specialization: " + std::string(CHECK_MARK(convertible)) + (convertible ? 
"" : std::string(detail::get_type_info())) + "\n"; - - if constexpr (convertible) { - constexpr bool is_large_tensor = (T::specialization == ConvAlgorithmSpecialization::LARGE_TENSOR); - msg += " → specialization == LARGE_TENSOR: " + std::string(CHECK_MARK(is_large_tensor)) + "\n"; + + if constexpr(convertible) + { + constexpr bool is_large_tensor = + (T::specialization == ConvAlgorithmSpecialization::LARGE_TENSOR); + msg += " → specialization == LARGE_TENSOR: " + + std::string(CHECK_MARK(is_large_tensor)) + "\n"; } - } else { + } + else + { msg += " → T::specialization: [✗] (missing member)\n"; } - + return msg; } template -consteval auto detailed_diagnostic_SpecifiesReferenceAlgorithm() -> std::string { +consteval auto detailed_diagnostic_SpecifiesReferenceAlgorithm() -> std::string +{ std::string msg; - if constexpr (requires { T::specialization; }) { - using SpecType = decltype(T::specialization); + if constexpr(requires { T::specialization; }) + { + using SpecType = decltype(T::specialization); constexpr bool convertible = std::convertible_to; - msg += " → T::specialization: " + std::string(CHECK_MARK(convertible)) + + msg += " → T::specialization: " + std::string(CHECK_MARK(convertible)) + (convertible ? 
"" : std::string(detail::get_type_info())) + "\n"; - - if constexpr (convertible) { - constexpr bool is_reference = (T::specialization == ConvAlgorithmSpecialization::REFERENCE); - msg += " → specialization == REFERENCE: " + std::string(CHECK_MARK(is_reference)) + "\n"; + + if constexpr(convertible) + { + constexpr bool is_reference = + (T::specialization == ConvAlgorithmSpecialization::REFERENCE); + msg += " → specialization == REFERENCE: " + std::string(CHECK_MARK(is_reference)) + + "\n"; } - } else { + } + else + { msg += " → T::specialization: [✗] (missing member)\n"; } - + return msg; } template -consteval auto detailed_diagnostic_SpecifiesTwoStageSupport() -> std::string { +consteval auto detailed_diagnostic_SpecifiesTwoStageSupport() -> std::string +{ std::string msg; - if constexpr (requires { T::specialization; }) { - using SpecType = decltype(T::specialization); + if constexpr(requires { T::specialization; }) + { + using SpecType = decltype(T::specialization); constexpr bool convertible = std::convertible_to; - msg += " → T::specialization: " + std::string(CHECK_MARK(convertible)) + + msg += " → T::specialization: " + std::string(CHECK_MARK(convertible)) + (convertible ? 
"" : std::string(detail::get_type_info())) + "\n"; - - if constexpr (convertible) { - constexpr bool is_two_stage = (T::specialization == ConvAlgorithmSpecialization::TWO_STAGE); - msg += " → specialization == TWO_STAGE: " + std::string(CHECK_MARK(is_two_stage)) + "\n"; + + if constexpr(convertible) + { + constexpr bool is_two_stage = + (T::specialization == ConvAlgorithmSpecialization::TWO_STAGE); + msg += " → specialization == TWO_STAGE: " + std::string(CHECK_MARK(is_two_stage)) + + "\n"; } - } else { + } + else + { msg += " → T::specialization: [✗] (missing member)\n"; } - + return msg; } template -consteval auto detailed_diagnostic_SpecifiesMultipleDSupport() -> std::string { +consteval auto detailed_diagnostic_SpecifiesMultipleDSupport() -> std::string +{ std::string msg; - if constexpr (requires { T::specialization; }) { - using SpecType = decltype(T::specialization); + if constexpr(requires { T::specialization; }) + { + using SpecType = decltype(T::specialization); constexpr bool convertible = std::convertible_to; - msg += " → T::specialization: " + std::string(CHECK_MARK(convertible)) + + msg += " → T::specialization: " + std::string(CHECK_MARK(convertible)) + (convertible ? 
"" : std::string(detail::get_type_info())) + "\n"; - - if constexpr (convertible) { - constexpr bool is_multiple_d = (T::specialization == ConvAlgorithmSpecialization::MULTIPLE_D); - msg += " → specialization == MULTIPLE_D: " + std::string(CHECK_MARK(is_multiple_d)) + "\n"; + + if constexpr(convertible) + { + constexpr bool is_multiple_d = + (T::specialization == ConvAlgorithmSpecialization::MULTIPLE_D); + msg += + " → specialization == MULTIPLE_D: " + std::string(CHECK_MARK(is_multiple_d)) + + "\n"; } - } else { + } + else + { msg += " → T::specialization: [✗] (missing member)\n"; } - + return msg; } template -consteval auto detailed_diagnostic_SpecifiesGenericInstance() -> std::string { +consteval auto detailed_diagnostic_SpecifiesGenericInstance() -> std::string +{ std::string msg; - if constexpr (requires { T::specialization; }) { + if constexpr(requires { T::specialization; }) + { msg += " → T::specialization: [✗] (member should NOT exist for generic instance)\n"; msg += " → This concept requires the absence of the specialization member\n"; } @@ -783,423 +1076,692 @@ consteval auto detailed_diagnostic_SpecifiesGenericInstance() -> std::string { } template -consteval auto detailed_diagnostic_SpecifiesTransposeTransfer() -> std::string { +consteval auto detailed_diagnostic_SpecifiesTransposeTransfer() -> std::string +{ std::string msg; - - if constexpr (requires { T::max_transpose_transfer_src_scalar_per_vector; }) { - using SrcType = decltype(T::max_transpose_transfer_src_scalar_per_vector); + + if constexpr(requires { T::max_transpose_transfer_src_scalar_per_vector; }) + { + using SrcType = decltype(T::max_transpose_transfer_src_scalar_per_vector); constexpr bool convertible = std::convertible_to; - msg += " → T::max_transpose_transfer_src_scalar_per_vector: " + std::string(CHECK_MARK(convertible)) + + msg += " → T::max_transpose_transfer_src_scalar_per_vector: " + + std::string(CHECK_MARK(convertible)) + (convertible ? 
"" : std::string(detail::get_type_info())) + "\n"; - } else { + } + else + { msg += " → T::max_transpose_transfer_src_scalar_per_vector: [✗] (missing member)\n"; } - - if constexpr (requires { T::max_transpose_transfer_dst_scalar_per_vector; }) { - using DstType = decltype(T::max_transpose_transfer_dst_scalar_per_vector); + + if constexpr(requires { T::max_transpose_transfer_dst_scalar_per_vector; }) + { + using DstType = decltype(T::max_transpose_transfer_dst_scalar_per_vector); constexpr bool convertible = std::convertible_to; - msg += " → T::max_transpose_transfer_dst_scalar_per_vector: " + std::string(CHECK_MARK(convertible)) + + msg += " → T::max_transpose_transfer_dst_scalar_per_vector: " + + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(detail::get_type_info())) + "\n"; - } else { + } + else + { msg += " → T::max_transpose_transfer_dst_scalar_per_vector: [✗] (missing member)\n"; } - + return msg; } template -consteval auto detailed_diagnostic_TransposeTransferWellDefinedIfProvided() -> std::string { +consteval auto detailed_diagnostic_TransposeTransferWellDefinedIfProvided() -> std::string +{ std::string msg; - + constexpr bool has_src = requires { T::max_transpose_transfer_src_scalar_per_vector; }; constexpr bool has_dst = requires { T::max_transpose_transfer_dst_scalar_per_vector; }; constexpr bool has_transpose_transfer = has_src || has_dst; - - if constexpr (!has_transpose_transfer) { + + if constexpr(!has_transpose_transfer) + { msg += " → Transpose transfer fields not provided: [✓] (optional, not required)\n"; - } else { + } + else + { msg += " → Transpose transfer fields provided, checking if well-defined:\n"; - - if constexpr (has_src) { + + if constexpr(has_src) + { using SrcType = decltype(T::max_transpose_transfer_src_scalar_per_vector); constexpr bool src_convertible = std::convertible_to; - msg += " → T::max_transpose_transfer_src_scalar_per_vector: " + - std::string(CHECK_MARK(src_convertible)) + + msg += " → 
T::max_transpose_transfer_src_scalar_per_vector: " + + std::string(CHECK_MARK(src_convertible)) + (src_convertible ? "" : std::string(detail::get_type_info())) + "\n"; - } else { - msg += " → T::max_transpose_transfer_src_scalar_per_vector: [✗] (missing, but dst is provided)\n"; } - - if constexpr (has_dst) { + else + { + msg += " → T::max_transpose_transfer_src_scalar_per_vector: [✗] (missing, but " + "dst is provided)\n"; + } + + if constexpr(has_dst) + { using DstType = decltype(T::max_transpose_transfer_dst_scalar_per_vector); constexpr bool dst_convertible = std::convertible_to; - msg += " → T::max_transpose_transfer_dst_scalar_per_vector: " + - std::string(CHECK_MARK(dst_convertible)) + + msg += " → T::max_transpose_transfer_dst_scalar_per_vector: " + + std::string(CHECK_MARK(dst_convertible)) + (dst_convertible ? "" : std::string(detail::get_type_info())) + "\n"; - } else { - msg += " → T::max_transpose_transfer_dst_scalar_per_vector: [✗] (missing, but src is provided)\n"; + } + else + { + msg += " → T::max_transpose_transfer_dst_scalar_per_vector: [✗] (missing, but " + "src is provided)\n"; } } - + return msg; } template -consteval auto detailed_diagnostic_SpecifiesGemmBatchOptions() -> std::string { - if constexpr (requires { T::num_conv_groups_to_merge; }) { - using NumGroupsType = decltype(T::num_conv_groups_to_merge); +consteval auto detailed_diagnostic_SpecifiesGemmBatchOptions() -> std::string +{ + if constexpr(requires { T::num_conv_groups_to_merge; }) + { + using NumGroupsType = decltype(T::num_conv_groups_to_merge); constexpr bool convertible = std::convertible_to; - return " → T::num_conv_groups_to_merge: " + std::string(CHECK_MARK(convertible)) + + return " → T::num_conv_groups_to_merge: " + std::string(CHECK_MARK(convertible)) + (convertible ? 
"" : std::string(detail::get_type_info())) + "\n"; - } else { + } + else + { return " → T::num_conv_groups_to_merge: [✗] (missing member)\n"; } } template -consteval auto detailed_diagnostic_SpecifiesGridwiseWmmaGemm() -> std::string { +consteval auto detailed_diagnostic_SpecifiesGridwiseWmmaGemm() -> std::string +{ std::string msg; - constexpr bool has_gridwise_gemm = requires(T t) { { t.gridwise_gemm } -> GridwiseWmmaGemmDescriptor; }; + constexpr bool has_gridwise_gemm = requires(T t) { + { t.gridwise_gemm } -> GridwiseWmmaGemmDescriptor; + }; msg += " → T::gridwise_gemm member: " + std::string(CHECK_MARK(has_gridwise_gemm)) + "\n"; - - if constexpr (!has_gridwise_gemm) { + + if constexpr(!has_gridwise_gemm) + { return msg; } - - using GG = decltype(T::gridwise_gemm); - constexpr bool has_k1 = requires(GG t) { { t.k1 } -> std::convertible_to; }; - constexpr bool has_m_per_wmma = requires(GG t) { { t.m_per_wmma } -> std::convertible_to; }; - constexpr bool has_n_per_wmma = requires(GG t) { { t.n_per_wmma } -> std::convertible_to; }; - constexpr bool has_m_wmma_per_wave = requires(GG t) { { t.m_wmma_per_wave } -> std::convertible_to; }; - constexpr bool has_n_wmma_per_wave = requires(GG t) { { t.n_wmma_per_wave } -> std::convertible_to; }; - - msg += " → gridwise_gemm.k1: " + std::string(CHECK_MARK(has_k1)) + (has_k1 ? "\n" : " (missing or wrong type)\n"); - msg += " → gridwise_gemm.m_per_wmma: " + std::string(CHECK_MARK(has_m_per_wmma)) + (has_m_per_wmma ? "\n" : " (missing or wrong type)\n"); - msg += " → gridwise_gemm.n_per_wmma: " + std::string(CHECK_MARK(has_n_per_wmma)) + (has_n_per_wmma ? "\n" : " (missing or wrong type)\n"); - msg += " → gridwise_gemm.m_wmma_per_wave: " + std::string(CHECK_MARK(has_m_wmma_per_wave)) + (has_m_wmma_per_wave ? "\n" : " (missing or wrong type)\n"); - msg += " → gridwise_gemm.n_wmma_per_wave: " + std::string(CHECK_MARK(has_n_wmma_per_wave)) + (has_n_wmma_per_wave ? 
"\n" : " (missing or wrong type)\n"); - + + using GG = decltype(T::gridwise_gemm); + constexpr bool has_k1 = requires(GG t) { + { t.k1 } -> std::convertible_to; + }; + constexpr bool has_m_per_wmma = requires(GG t) { + { t.m_per_wmma } -> std::convertible_to; + }; + constexpr bool has_n_per_wmma = requires(GG t) { + { t.n_per_wmma } -> std::convertible_to; + }; + constexpr bool has_m_wmma_per_wave = requires(GG t) { + { t.m_wmma_per_wave } -> std::convertible_to; + }; + constexpr bool has_n_wmma_per_wave = requires(GG t) { + { t.n_wmma_per_wave } -> std::convertible_to; + }; + + msg += " → gridwise_gemm.k1: " + std::string(CHECK_MARK(has_k1)) + + (has_k1 ? "\n" : " (missing or wrong type)\n"); + msg += " → gridwise_gemm.m_per_wmma: " + std::string(CHECK_MARK(has_m_per_wmma)) + + (has_m_per_wmma ? "\n" : " (missing or wrong type)\n"); + msg += " → gridwise_gemm.n_per_wmma: " + std::string(CHECK_MARK(has_n_per_wmma)) + + (has_n_per_wmma ? "\n" : " (missing or wrong type)\n"); + msg += + " → gridwise_gemm.m_wmma_per_wave: " + std::string(CHECK_MARK(has_m_wmma_per_wave)) + + (has_m_wmma_per_wave ? "\n" : " (missing or wrong type)\n"); + msg += + " → gridwise_gemm.n_wmma_per_wave: " + std::string(CHECK_MARK(has_n_wmma_per_wave)) + + (has_n_wmma_per_wave ? 
"\n" : " (missing or wrong type)\n"); + return msg; } // Tile-specific diagnostics template -consteval auto detailed_diagnostic_SpecifiesTileThreadBlock() -> std::string { - if constexpr (!requires { { T::thread_block } -> TileThreadBlockDescriptor; }) { +consteval auto detailed_diagnostic_SpecifiesTileThreadBlock() -> std::string +{ + if constexpr(!requires { + { T::thread_block } -> TileThreadBlockDescriptor; + }) + { return " → T::thread_block member: [✗] (missing or wrong type)\n"; - } else { - using TB = decltype(T::thread_block); + } + else + { + using TB = decltype(T::thread_block); std::string msg = " → T::thread_block member: [✓]\n"; - - constexpr bool has_tile_m = requires(TB t) { { t.tile_size.m } -> std::convertible_to; }; - constexpr bool has_tile_n = requires(TB t) { { t.tile_size.n } -> std::convertible_to; }; - constexpr bool has_tile_k = requires(TB t) { { t.tile_size.k } -> std::convertible_to; }; - - msg += " → thread_block.tile_size.m: " + std::string(CHECK_MARK(has_tile_m)) + (has_tile_m ? "\n" : " (missing or wrong type)\n"); - msg += " → thread_block.tile_size.n: " + std::string(CHECK_MARK(has_tile_n)) + (has_tile_n ? "\n" : " (missing or wrong type)\n"); - msg += " → thread_block.tile_size.k: " + std::string(CHECK_MARK(has_tile_k)) + (has_tile_k ? "\n" : " (missing or wrong type)\n"); - + + constexpr bool has_tile_m = requires(TB t) { + { t.tile_size.m } -> std::convertible_to; + }; + constexpr bool has_tile_n = requires(TB t) { + { t.tile_size.n } -> std::convertible_to; + }; + constexpr bool has_tile_k = requires(TB t) { + { t.tile_size.k } -> std::convertible_to; + }; + + msg += " → thread_block.tile_size.m: " + std::string(CHECK_MARK(has_tile_m)) + + (has_tile_m ? "\n" : " (missing or wrong type)\n"); + msg += " → thread_block.tile_size.n: " + std::string(CHECK_MARK(has_tile_n)) + + (has_tile_n ? "\n" : " (missing or wrong type)\n"); + msg += " → thread_block.tile_size.k: " + std::string(CHECK_MARK(has_tile_k)) + + (has_tile_k ? 
"\n" : " (missing or wrong type)\n"); + return msg; } } template -consteval auto detailed_diagnostic_SpecifiesTileTransfer() -> std::string { +consteval auto detailed_diagnostic_SpecifiesTileTransfer() -> std::string +{ std::string msg; constexpr bool has_transfer = requires { T::transfer; }; msg += " → T::transfer member: " + std::string(CHECK_MARK(has_transfer)) + "\n"; - - if constexpr (!has_transfer) { + + if constexpr(!has_transfer) + { return msg; } - - constexpr bool has_a_scalar = requires { { T::transfer.a_scalar_per_vector } -> std::convertible_to; }; - constexpr bool has_b_scalar = requires { { T::transfer.b_scalar_per_vector } -> std::convertible_to; }; - constexpr bool has_c_scalar = requires { { T::transfer.c_scalar_per_vector } -> std::convertible_to; }; - - msg += " → transfer.a_scalar_per_vector: " + std::string(CHECK_MARK(has_a_scalar)) + (has_a_scalar ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.b_scalar_per_vector: " + std::string(CHECK_MARK(has_b_scalar)) + (has_b_scalar ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.c_scalar_per_vector: " + std::string(CHECK_MARK(has_c_scalar)) + (has_c_scalar ? "\n" : " (missing or wrong type)\n"); - + + constexpr bool has_a_scalar = requires { + { T::transfer.a_scalar_per_vector } -> std::convertible_to; + }; + constexpr bool has_b_scalar = requires { + { T::transfer.b_scalar_per_vector } -> std::convertible_to; + }; + constexpr bool has_c_scalar = requires { + { T::transfer.c_scalar_per_vector } -> std::convertible_to; + }; + + msg += " → transfer.a_scalar_per_vector: " + std::string(CHECK_MARK(has_a_scalar)) + + (has_a_scalar ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.b_scalar_per_vector: " + std::string(CHECK_MARK(has_b_scalar)) + + (has_b_scalar ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.c_scalar_per_vector: " + std::string(CHECK_MARK(has_c_scalar)) + + (has_c_scalar ? 
"\n" : " (missing or wrong type)\n"); + return msg; } template -consteval auto detailed_diagnostic_SpecifiesTileBlockGemm() -> std::string { +consteval auto detailed_diagnostic_SpecifiesTileBlockGemm() -> std::string +{ std::string msg; - constexpr bool has_block_gemm = requires { { T::block_gemm } -> TileBlockGemmDescriptor; }; + constexpr bool has_block_gemm = requires { + { T::block_gemm } -> TileBlockGemmDescriptor; + }; msg += " → T::block_gemm member: " + std::string(CHECK_MARK(has_block_gemm)) + "\n"; - - if constexpr (!has_block_gemm) { + + if constexpr(!has_block_gemm) + { return msg; } - - using BG = decltype(T::block_gemm); - constexpr bool has_warps_m = requires(BG t) { { t.warps.m } -> std::convertible_to; }; - constexpr bool has_warps_n = requires(BG t) { { t.warps.n } -> std::convertible_to; }; - constexpr bool has_warps_k = requires(BG t) { { t.warps.k } -> std::convertible_to; }; - constexpr bool has_warp_tile_m = requires(BG t) { { t.warp_tile.m } -> std::convertible_to; }; - constexpr bool has_warp_tile_n = requires(BG t) { { t.warp_tile.n } -> std::convertible_to; }; - constexpr bool has_warp_tile_k = requires(BG t) { { t.warp_tile.k } -> std::convertible_to; }; - constexpr bool has_double_smem = requires(BG t) { { t.double_smem_buffer } -> std::convertible_to; }; - constexpr bool has_num_wave_groups = requires(BG t) { { t.num_wave_groups } -> std::convertible_to; }; - constexpr bool has_pipeline = requires(BG t) { { t.pipeline_version } -> std::convertible_to; }; - constexpr bool has_scheduler = requires(BG t) { { t.scheduler } -> std::convertible_to; }; - - msg += " → block_gemm.warps.m: " + std::string(CHECK_MARK(has_warps_m)) + (has_warps_m ? "\n" : " (missing or wrong type)\n"); - msg += " → block_gemm.warps.n: " + std::string(CHECK_MARK(has_warps_n)) + (has_warps_n ? "\n" : " (missing or wrong type)\n"); - msg += " → block_gemm.warps.k: " + std::string(CHECK_MARK(has_warps_k)) + (has_warps_k ? 
"\n" : " (missing or wrong type)\n"); - msg += " → block_gemm.warp_tile.m: " + std::string(CHECK_MARK(has_warp_tile_m)) + (has_warp_tile_m ? "\n" : " (missing or wrong type)\n"); - msg += " → block_gemm.warp_tile.n: " + std::string(CHECK_MARK(has_warp_tile_n)) + (has_warp_tile_n ? "\n" : " (missing or wrong type)\n"); - msg += " → block_gemm.warp_tile.k: " + std::string(CHECK_MARK(has_warp_tile_k)) + (has_warp_tile_k ? "\n" : " (missing or wrong type)\n"); - msg += " → block_gemm.double_smem_buffer: " + std::string(CHECK_MARK(has_double_smem)) + (has_double_smem ? "\n" : " (missing or wrong type)\n"); - msg += " → block_gemm.num_wave_groups: " + std::string(CHECK_MARK(has_num_wave_groups)) + (has_num_wave_groups ? "\n" : " (missing or wrong type)\n"); - msg += " → block_gemm.pipeline_version: " + std::string(CHECK_MARK(has_pipeline)) + (has_pipeline ? "\n" : " (missing or wrong type)\n"); - msg += " → block_gemm.scheduler: " + std::string(CHECK_MARK(has_scheduler)) + (has_scheduler ? "\n" : " (missing or wrong type)\n"); - + + using BG = decltype(T::block_gemm); + constexpr bool has_warps_m = requires(BG t) { + { t.warps.m } -> std::convertible_to; + }; + constexpr bool has_warps_n = requires(BG t) { + { t.warps.n } -> std::convertible_to; + }; + constexpr bool has_warps_k = requires(BG t) { + { t.warps.k } -> std::convertible_to; + }; + constexpr bool has_warp_tile_m = requires(BG t) { + { t.warp_tile.m } -> std::convertible_to; + }; + constexpr bool has_warp_tile_n = requires(BG t) { + { t.warp_tile.n } -> std::convertible_to; + }; + constexpr bool has_warp_tile_k = requires(BG t) { + { t.warp_tile.k } -> std::convertible_to; + }; + constexpr bool has_double_smem = requires(BG t) { + { t.double_smem_buffer } -> std::convertible_to; + }; + constexpr bool has_num_wave_groups = requires(BG t) { + { t.num_wave_groups } -> std::convertible_to; + }; + constexpr bool has_pipeline = requires(BG t) { + { t.pipeline_version } -> std::convertible_to; + }; + constexpr bool 
has_scheduler = requires(BG t) { + { t.scheduler } -> std::convertible_to; + }; + + msg += " → block_gemm.warps.m: " + std::string(CHECK_MARK(has_warps_m)) + + (has_warps_m ? "\n" : " (missing or wrong type)\n"); + msg += " → block_gemm.warps.n: " + std::string(CHECK_MARK(has_warps_n)) + + (has_warps_n ? "\n" : " (missing or wrong type)\n"); + msg += " → block_gemm.warps.k: " + std::string(CHECK_MARK(has_warps_k)) + + (has_warps_k ? "\n" : " (missing or wrong type)\n"); + msg += " → block_gemm.warp_tile.m: " + std::string(CHECK_MARK(has_warp_tile_m)) + + (has_warp_tile_m ? "\n" : " (missing or wrong type)\n"); + msg += " → block_gemm.warp_tile.n: " + std::string(CHECK_MARK(has_warp_tile_n)) + + (has_warp_tile_n ? "\n" : " (missing or wrong type)\n"); + msg += " → block_gemm.warp_tile.k: " + std::string(CHECK_MARK(has_warp_tile_k)) + + (has_warp_tile_k ? "\n" : " (missing or wrong type)\n"); + msg += " → block_gemm.double_smem_buffer: " + std::string(CHECK_MARK(has_double_smem)) + + (has_double_smem ? "\n" : " (missing or wrong type)\n"); + msg += " → block_gemm.num_wave_groups: " + std::string(CHECK_MARK(has_num_wave_groups)) + + (has_num_wave_groups ? "\n" : " (missing or wrong type)\n"); + msg += " → block_gemm.pipeline_version: " + std::string(CHECK_MARK(has_pipeline)) + + (has_pipeline ? "\n" : " (missing or wrong type)\n"); + msg += " → block_gemm.scheduler: " + std::string(CHECK_MARK(has_scheduler)) + + (has_scheduler ? 
"\n" : " (missing or wrong type)\n"); + return msg; } template -consteval auto detailed_diagnostic_SpecifiesTileOptimizations() -> std::string { +consteval auto detailed_diagnostic_SpecifiesTileOptimizations() -> std::string +{ std::string msg; - constexpr bool has_optimizations = requires { { T::optimizations } -> TileOptimizationsDescriptor; }; + constexpr bool has_optimizations = requires { + { T::optimizations } -> TileOptimizationsDescriptor; + }; msg += " → T::optimizations member: " + std::string(CHECK_MARK(has_optimizations)) + "\n"; - - if constexpr (!has_optimizations) { + + if constexpr(!has_optimizations) + { return msg; } - - using OPT = decltype(T::optimizations); - constexpr bool has_num_groups = requires(OPT t) { { t.num_groups_to_merge } -> std::convertible_to; }; - constexpr bool has_split_image = requires(OPT t) { { t.split_image } -> std::convertible_to; }; - constexpr bool has_explicit_gemm = requires(OPT t) { { t.explicit_gemm } -> std::convertible_to; }; - - msg += " → optimizations.num_groups_to_merge: " + std::string(CHECK_MARK(has_num_groups)) + (has_num_groups ? "\n" : " (missing or wrong type)\n"); - msg += " → optimizations.split_image: " + std::string(CHECK_MARK(has_split_image)) + (has_split_image ? "\n" : " (missing or wrong type)\n"); - msg += " → optimizations.explicit_gemm: " + std::string(CHECK_MARK(has_explicit_gemm)) + (has_explicit_gemm ? "\n" : " (missing or wrong type)\n"); - + + using OPT = decltype(T::optimizations); + constexpr bool has_num_groups = requires(OPT t) { + { t.num_groups_to_merge } -> std::convertible_to; + }; + constexpr bool has_split_image = requires(OPT t) { + { t.split_image } -> std::convertible_to; + }; + constexpr bool has_explicit_gemm = requires(OPT t) { + { t.explicit_gemm } -> std::convertible_to; + }; + + msg += " → optimizations.num_groups_to_merge: " + std::string(CHECK_MARK(has_num_groups)) + + (has_num_groups ? 
"\n" : " (missing or wrong type)\n"); + msg += " → optimizations.split_image: " + std::string(CHECK_MARK(has_split_image)) + + (has_split_image ? "\n" : " (missing or wrong type)\n"); + msg += " → optimizations.explicit_gemm: " + std::string(CHECK_MARK(has_explicit_gemm)) + + (has_explicit_gemm ? "\n" : " (missing or wrong type)\n"); + return msg; } // DL-specific diagnostics template -consteval auto detailed_diagnostic_SpecifiesDlThreadConfig() -> std::string { +consteval auto detailed_diagnostic_SpecifiesDlThreadConfig() -> std::string +{ std::string msg; - constexpr bool has_thread_config = requires { { T::thread_config } -> DlThreadConfigDescriptor; }; + constexpr bool has_thread_config = requires { + { T::thread_config } -> DlThreadConfigDescriptor; + }; msg += " → T::thread_config member: " + std::string(CHECK_MARK(has_thread_config)) + "\n"; - - if constexpr (!has_thread_config) { + + if constexpr(!has_thread_config) + { return msg; } - - using TC = decltype(T::thread_config); - constexpr bool has_k0 = requires(TC t) { { t.k0_per_block } -> std::convertible_to; }; - constexpr bool has_k1 = requires(TC t) { { t.k1 } -> std::convertible_to; }; - constexpr bool has_m1 = requires(TC t) { { t.m1_per_thread } -> std::convertible_to; }; - constexpr bool has_n1 = requires(TC t) { { t.n1_per_thread } -> std::convertible_to; }; - constexpr bool has_k = requires(TC t) { { t.k_per_thread } -> std::convertible_to; }; - - msg += " → thread_config.k0_per_block: " + std::string(CHECK_MARK(has_k0)) + (has_k0 ? "\n" : " (missing or wrong type)\n"); - msg += " → thread_config.k1: " + std::string(CHECK_MARK(has_k1)) + (has_k1 ? "\n" : " (missing or wrong type)\n"); - msg += " → thread_config.m1_per_thread: " + std::string(CHECK_MARK(has_m1)) + (has_m1 ? "\n" : " (missing or wrong type)\n"); - msg += " → thread_config.n1_per_thread: " + std::string(CHECK_MARK(has_n1)) + (has_n1 ? 
"\n" : " (missing or wrong type)\n"); - msg += " → thread_config.k_per_thread: " + std::string(CHECK_MARK(has_k)) + (has_k ? "\n" : " (missing or wrong type)\n"); - + + using TC = decltype(T::thread_config); + constexpr bool has_k0 = requires(TC t) { + { t.k0_per_block } -> std::convertible_to; + }; + constexpr bool has_k1 = requires(TC t) { + { t.k1 } -> std::convertible_to; + }; + constexpr bool has_m1 = requires(TC t) { + { t.m1_per_thread } -> std::convertible_to; + }; + constexpr bool has_n1 = requires(TC t) { + { t.n1_per_thread } -> std::convertible_to; + }; + constexpr bool has_k = requires(TC t) { + { t.k_per_thread } -> std::convertible_to; + }; + + msg += " → thread_config.k0_per_block: " + std::string(CHECK_MARK(has_k0)) + + (has_k0 ? "\n" : " (missing or wrong type)\n"); + msg += " → thread_config.k1: " + std::string(CHECK_MARK(has_k1)) + + (has_k1 ? "\n" : " (missing or wrong type)\n"); + msg += " → thread_config.m1_per_thread: " + std::string(CHECK_MARK(has_m1)) + + (has_m1 ? "\n" : " (missing or wrong type)\n"); + msg += " → thread_config.n1_per_thread: " + std::string(CHECK_MARK(has_n1)) + + (has_n1 ? "\n" : " (missing or wrong type)\n"); + msg += " → thread_config.k_per_thread: " + std::string(CHECK_MARK(has_k)) + + (has_k ? 
"\n" : " (missing or wrong type)\n"); + return msg; } template -consteval auto detailed_diagnostic_SpecifiesDlThreadCluster() -> std::string { +consteval auto detailed_diagnostic_SpecifiesDlThreadCluster() -> std::string +{ std::string msg; - constexpr bool has_thread_cluster = requires { { T::thread_cluster } -> DlThreadClusterDescriptor; }; - msg += " → T::thread_cluster member: " + std::string(CHECK_MARK(has_thread_cluster)) + "\n"; - - if constexpr (!has_thread_cluster) { + constexpr bool has_thread_cluster = requires { + { T::thread_cluster } -> DlThreadClusterDescriptor; + }; + msg += + " → T::thread_cluster member: " + std::string(CHECK_MARK(has_thread_cluster)) + "\n"; + + if constexpr(!has_thread_cluster) + { return msg; } - - using TC = decltype(T::thread_cluster); - constexpr bool has_m1_xs = requires(TC t) { { t.m1_xs } -> std::convertible_to>; }; - constexpr bool has_n1_xs = requires(TC t) { { t.n1_xs } -> std::convertible_to>; }; - - msg += " → thread_cluster.m1_xs: " + std::string(CHECK_MARK(has_m1_xs)) + (has_m1_xs ? "\n" : " (missing or wrong type)\n"); - msg += " → thread_cluster.n1_xs: " + std::string(CHECK_MARK(has_n1_xs)) + (has_n1_xs ? "\n" : " (missing or wrong type)\n"); - + + using TC = decltype(T::thread_cluster); + constexpr bool has_m1_xs = requires(TC t) { + { t.m1_xs } -> std::convertible_to>; + }; + constexpr bool has_n1_xs = requires(TC t) { + { t.n1_xs } -> std::convertible_to>; + }; + + msg += " → thread_cluster.m1_xs: " + std::string(CHECK_MARK(has_m1_xs)) + + (has_m1_xs ? "\n" : " (missing or wrong type)\n"); + msg += " → thread_cluster.n1_xs: " + std::string(CHECK_MARK(has_n1_xs)) + + (has_n1_xs ? 
"\n" : " (missing or wrong type)\n"); + return msg; } template -consteval auto detailed_diagnostic_SpecifiesDlFwdBlockTransfer() -> std::string { +consteval auto detailed_diagnostic_SpecifiesDlFwdBlockTransfer() -> std::string +{ std::string msg; constexpr bool has_transfer = requires { T::transfer; }; msg += " → T::transfer member: " + std::string(CHECK_MARK(has_transfer)) + "\n"; - - if constexpr (!has_transfer) { + + if constexpr(!has_transfer) + { return msg; } - - constexpr bool has_a = requires { { T::transfer.a } -> DlBlockTransferDescriptor4D; }; - constexpr bool has_b = requires { { T::transfer.b } -> DlBlockTransferDescriptor4D; }; + + constexpr bool has_a = requires { + { T::transfer.a } -> DlBlockTransferDescriptor4D; + }; + constexpr bool has_b = requires { + { T::transfer.b } -> DlBlockTransferDescriptor4D; + }; msg += " → T::transfer.a: " + std::string(CHECK_MARK(has_a)) + "\n"; - - if constexpr (has_a) { - using ABT = decltype(T::transfer.a); - constexpr bool has_thread_slice = requires(ABT t) { { t.thread_slice_lengths } -> std::convertible_to>; }; - constexpr bool has_thread_cluster = requires(ABT t) { { t.thread_cluster_lengths } -> std::convertible_to>; }; - constexpr bool has_cluster_arrange = requires(ABT t) { { t.thread_cluster_arrange_order } -> std::convertible_to>; }; - constexpr bool has_src_access = requires(ABT t) { { t.src_access_order } -> std::convertible_to>; }; - constexpr bool has_src_vector = requires(ABT t) { { t.src_vector_tensor_lengths } -> std::convertible_to>; }; - constexpr bool has_src_contiguous = requires(ABT t) { { t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to>; }; - constexpr bool has_dst_vector = requires(ABT t) { { t.dst_vector_tensor_lengths } -> std::convertible_to>; }; - - msg += " → transfer.a.thread_slice_lengths (4D): " + std::string(CHECK_MARK(has_thread_slice)) + (has_thread_slice ? 
"\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.thread_cluster_lengths (4D): " + std::string(CHECK_MARK(has_thread_cluster)) + (has_thread_cluster ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.thread_cluster_arrange_order (4D): " + std::string(CHECK_MARK(has_cluster_arrange)) + (has_cluster_arrange ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.src_access_order (4D): " + std::string(CHECK_MARK(has_src_access)) + (has_src_access ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.src_vector_tensor_lengths (4D): " + std::string(CHECK_MARK(has_src_vector)) + (has_src_vector ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.src_vector_tensor_contiguous_dim_order (4D): " + std::string(CHECK_MARK(has_src_contiguous)) + (has_src_contiguous ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.dst_vector_tensor_lengths (4D): " + std::string(CHECK_MARK(has_dst_vector)) + (has_dst_vector ? "\n" : " (missing or wrong type)\n"); - } else { + + if constexpr(has_a) + { + using ABT = decltype(T::transfer.a); + constexpr bool has_thread_slice = requires(ABT t) { + { t.thread_slice_lengths } -> std::convertible_to>; + }; + constexpr bool has_thread_cluster = requires(ABT t) { + { t.thread_cluster_lengths } -> std::convertible_to>; + }; + constexpr bool has_cluster_arrange = requires(ABT t) { + { t.thread_cluster_arrange_order } -> std::convertible_to>; + }; + constexpr bool has_src_access = requires(ABT t) { + { t.src_access_order } -> std::convertible_to>; + }; + constexpr bool has_src_vector = requires(ABT t) { + { t.src_vector_tensor_lengths } -> std::convertible_to>; + }; + constexpr bool has_src_contiguous = requires(ABT t) { + { + t.src_vector_tensor_contiguous_dim_order + } -> std::convertible_to>; + }; + constexpr bool has_dst_vector = requires(ABT t) { + { t.dst_vector_tensor_lengths } -> std::convertible_to>; + }; + + msg += " → transfer.a.thread_slice_lengths (4D): " + + 
std::string(CHECK_MARK(has_thread_slice)) + + (has_thread_slice ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.thread_cluster_lengths (4D): " + + std::string(CHECK_MARK(has_thread_cluster)) + + (has_thread_cluster ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.thread_cluster_arrange_order (4D): " + + std::string(CHECK_MARK(has_cluster_arrange)) + + (has_cluster_arrange ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.src_access_order (4D): " + + std::string(CHECK_MARK(has_src_access)) + + (has_src_access ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.src_vector_tensor_lengths (4D): " + + std::string(CHECK_MARK(has_src_vector)) + + (has_src_vector ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.src_vector_tensor_contiguous_dim_order (4D): " + + std::string(CHECK_MARK(has_src_contiguous)) + + (has_src_contiguous ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.dst_vector_tensor_lengths (4D): " + + std::string(CHECK_MARK(has_dst_vector)) + + (has_dst_vector ? 
"\n" : " (missing or wrong type)\n"); + } + else + { msg += " → T::transfer.a (4D): [✗] (missing or wrong type)\n"; } - + msg += " → T::transfer.b: " + std::string(CHECK_MARK(has_b)) + "\n"; - - if constexpr (has_b) { + + if constexpr(has_b) + { msg += " → T::transfer.b (4D): [✓] (similar fields as transfer.a)\n"; - } else { + } + else + { msg += " → T::transfer.b (4D): [✗] (missing or wrong type)\n"; } - + return msg; } template -consteval auto detailed_diagnostic_SpecifiesDlBwdBlockTransfer() -> std::string { +consteval auto detailed_diagnostic_SpecifiesDlBwdBlockTransfer() -> std::string +{ std::string msg; constexpr bool has_transfer = requires { T::transfer; }; msg += " → T::transfer member: " + std::string(CHECK_MARK(has_transfer)) + "\n"; - - if constexpr (!has_transfer) { + + if constexpr(!has_transfer) + { return msg; } - - constexpr bool has_a = requires { { T::transfer.a } -> DlBlockTransferDescriptor5D; }; - constexpr bool has_b = requires { { T::transfer.b } -> DlBlockTransferDescriptor5D; }; + + constexpr bool has_a = requires { + { T::transfer.a } -> DlBlockTransferDescriptor5D; + }; + constexpr bool has_b = requires { + { T::transfer.b } -> DlBlockTransferDescriptor5D; + }; msg += " → T::transfer.a: " + std::string(CHECK_MARK(has_a)) + "\n"; - - if constexpr (has_a) { - using ABT = decltype(T::transfer.a); - constexpr bool has_thread_slice = requires(ABT t) { { t.thread_slice_lengths } -> std::convertible_to>; }; - constexpr bool has_thread_cluster = requires(ABT t) { { t.thread_cluster_lengths } -> std::convertible_to>; }; - constexpr bool has_cluster_arrange = requires(ABT t) { { t.thread_cluster_arrange_order } -> std::convertible_to>; }; - constexpr bool has_src_access = requires(ABT t) { { t.src_access_order } -> std::convertible_to>; }; - constexpr bool has_src_vector = requires(ABT t) { { t.src_vector_tensor_lengths } -> std::convertible_to>; }; - constexpr bool has_src_contiguous = requires(ABT t) { { t.src_vector_tensor_contiguous_dim_order 
} -> std::convertible_to>; }; - constexpr bool has_dst_vector = requires(ABT t) { { t.dst_vector_tensor_lengths } -> std::convertible_to>; }; - - msg += " → transfer.a.thread_slice_lengths (5D): " + std::string(CHECK_MARK(has_thread_slice)) + (has_thread_slice ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.thread_cluster_lengths (5D): " + std::string(CHECK_MARK(has_thread_cluster)) + (has_thread_cluster ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.thread_cluster_arrange_order (5D): " + std::string(CHECK_MARK(has_cluster_arrange)) + (has_cluster_arrange ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.src_access_order (5D): " + std::string(CHECK_MARK(has_src_access)) + (has_src_access ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.src_vector_tensor_lengths (5D): " + std::string(CHECK_MARK(has_src_vector)) + (has_src_vector ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.src_vector_tensor_contiguous_dim_order (5D): " + std::string(CHECK_MARK(has_src_contiguous)) + (has_src_contiguous ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.dst_vector_tensor_lengths (5D): " + std::string(CHECK_MARK(has_dst_vector)) + (has_dst_vector ? 
"\n" : " (missing or wrong type)\n"); - } else { + + if constexpr(has_a) + { + using ABT = decltype(T::transfer.a); + constexpr bool has_thread_slice = requires(ABT t) { + { t.thread_slice_lengths } -> std::convertible_to>; + }; + constexpr bool has_thread_cluster = requires(ABT t) { + { t.thread_cluster_lengths } -> std::convertible_to>; + }; + constexpr bool has_cluster_arrange = requires(ABT t) { + { t.thread_cluster_arrange_order } -> std::convertible_to>; + }; + constexpr bool has_src_access = requires(ABT t) { + { t.src_access_order } -> std::convertible_to>; + }; + constexpr bool has_src_vector = requires(ABT t) { + { t.src_vector_tensor_lengths } -> std::convertible_to>; + }; + constexpr bool has_src_contiguous = requires(ABT t) { + { + t.src_vector_tensor_contiguous_dim_order + } -> std::convertible_to>; + }; + constexpr bool has_dst_vector = requires(ABT t) { + { t.dst_vector_tensor_lengths } -> std::convertible_to>; + }; + + msg += " → transfer.a.thread_slice_lengths (5D): " + + std::string(CHECK_MARK(has_thread_slice)) + + (has_thread_slice ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.thread_cluster_lengths (5D): " + + std::string(CHECK_MARK(has_thread_cluster)) + + (has_thread_cluster ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.thread_cluster_arrange_order (5D): " + + std::string(CHECK_MARK(has_cluster_arrange)) + + (has_cluster_arrange ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.src_access_order (5D): " + + std::string(CHECK_MARK(has_src_access)) + + (has_src_access ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.src_vector_tensor_lengths (5D): " + + std::string(CHECK_MARK(has_src_vector)) + + (has_src_vector ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.src_vector_tensor_contiguous_dim_order (5D): " + + std::string(CHECK_MARK(has_src_contiguous)) + + (has_src_contiguous ? 
"\n" : " (missing or wrong type)\n"); + msg += " → transfer.a.dst_vector_tensor_lengths (5D): " + + std::string(CHECK_MARK(has_dst_vector)) + + (has_dst_vector ? "\n" : " (missing or wrong type)\n"); + } + else + { msg += " → T::transfer.a (5D): [✗] (missing or wrong type)\n"; } - + msg += " → T::transfer.b: " + std::string(CHECK_MARK(has_b)) + "\n"; - - if constexpr (has_b) { + + if constexpr(has_b) + { msg += " → T::transfer.b (5D): [✓] (similar fields as transfer.a)\n"; - } else { + } + else + { msg += " → T::transfer.b (5D): [✗] (missing or wrong type)\n"; } - + return msg; } template -consteval auto detailed_diagnostic_SpecifiesDlEpilogue() -> std::string { +consteval auto detailed_diagnostic_SpecifiesDlEpilogue() -> std::string +{ std::string msg; constexpr bool has_transfer = requires { T::transfer; }; - if constexpr (!has_transfer) { + if constexpr(!has_transfer) + { return " → T::transfer member: [✗] (not found)\n"; } - + constexpr bool has_c = requires { T::transfer.c; }; msg += " → T::transfer.c: " + std::string(CHECK_MARK(has_c)) + "\n"; - - if constexpr (has_c && requires { T::transfer.c.src_dst_access_order; }) { - using C = decltype(T::transfer.c); - constexpr bool has_src_dst_access = requires(C t) { { t.src_dst_access_order } -> std::convertible_to>; }; - constexpr bool has_src_dst_vector_dim = requires(C t) { { t.src_dst_vector_dim } -> std::convertible_to; }; - constexpr bool has_dst_scalar = requires(C t) { { t.dst_scalar_per_vector } -> std::convertible_to; }; - - msg += " → transfer.c.src_dst_access_order: " + std::string(CHECK_MARK(has_src_dst_access)) + (has_src_dst_access ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.c.src_dst_vector_dim: " + std::string(CHECK_MARK(has_src_dst_vector_dim)) + (has_src_dst_vector_dim ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.c.dst_scalar_per_vector: " + std::string(CHECK_MARK(has_dst_scalar)) + (has_dst_scalar ? 
"\n" : " (missing or wrong type)\n"); - } else if constexpr (has_c) { + + if constexpr(has_c && requires { T::transfer.c.src_dst_access_order; }) + { + using C = decltype(T::transfer.c); + constexpr bool has_src_dst_access = requires(C t) { + { t.src_dst_access_order } -> std::convertible_to>; + }; + constexpr bool has_src_dst_vector_dim = requires(C t) { + { t.src_dst_vector_dim } -> std::convertible_to; + }; + constexpr bool has_dst_scalar = requires(C t) { + { t.dst_scalar_per_vector } -> std::convertible_to; + }; + + msg += " → transfer.c.src_dst_access_order: " + + std::string(CHECK_MARK(has_src_dst_access)) + + (has_src_dst_access ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.c.src_dst_vector_dim: " + + std::string(CHECK_MARK(has_src_dst_vector_dim)) + + (has_src_dst_vector_dim ? "\n" : " (missing or wrong type)\n"); + msg += " → transfer.c.dst_scalar_per_vector: " + + std::string(CHECK_MARK(has_dst_scalar)) + + (has_dst_scalar ? "\n" : " (missing or wrong type)\n"); + } + else if constexpr(has_c) + { msg += " → T::transfer.c (DlEpilogue): [✗] (missing required fields)\n"; } - + return msg; } template -consteval auto detailed_diagnostic_SpecifiesTileConvSpecialization() -> std::string { - if constexpr (requires { T::specialization; }) { - using SpecType = decltype(T::specialization); +consteval auto detailed_diagnostic_SpecifiesTileConvSpecialization() -> std::string +{ + if constexpr(requires { T::specialization; }) + { + using SpecType = decltype(T::specialization); constexpr bool convertible = std::convertible_to; - return " → T::specialization: " + std::string(CHECK_MARK(convertible)) + + return " → T::specialization: " + std::string(CHECK_MARK(convertible)) + (convertible ? 
"" : std::string(detail::get_type_info())) + "\n"; - } else { + } + else + { return " → T::specialization: [✗] (missing member)\n"; } } template -consteval auto detailed_diagnostic_SpecifiesLdsTransfer() -> std::string { +consteval auto detailed_diagnostic_SpecifiesLdsTransfer() -> std::string +{ std::string msg; - + constexpr bool has_transfer = requires { T::transfer; }; msg += " → T::transfer member: " + std::string(CHECK_MARK(has_transfer)) + "\n"; - - if constexpr (!has_transfer) { + + if constexpr(!has_transfer) + { return msg; } - - constexpr bool has_a = requires { { T::transfer.a.lds_transfer } -> LdsTransferDescriptor; }; + + constexpr bool has_a = requires { + { T::transfer.a.lds_transfer } -> LdsTransferDescriptor; + }; msg += " → T::transfer.a: " + std::string(CHECK_MARK(has_a)) + "\n"; - if constexpr (!has_a) { + if constexpr(!has_a) + { msg += " → T::transfer.a.lds_transfer: [✗] (missing or wrong type)\n"; } - - constexpr bool has_b = requires { { T::transfer.b.lds_transfer } -> LdsTransferDescriptor; }; + + constexpr bool has_b = requires { + { T::transfer.b.lds_transfer } -> LdsTransferDescriptor; + }; msg += " → T::transfer.b: " + std::string(CHECK_MARK(has_b)) + "\n"; - if constexpr (!has_b) { + if constexpr(!has_b) + { msg += " → T::transfer.b.lds_transfer: [✗] (missing or wrong type)\n"; } - - constexpr bool has_c = requires { { T::transfer.c.epilogue } -> EpilogueDescriptor; }; + + constexpr bool has_c = requires { + { T::transfer.c.epilogue } -> EpilogueDescriptor; + }; msg += " → T::transfer.c: " + std::string(CHECK_MARK(has_c)) + "\n"; - if constexpr (!has_c) { + if constexpr(!has_c) + { msg += " → T::transfer.c.epilogue: [✗] (missing or wrong type)\n"; } - + return msg; } template -consteval auto detailed_diagnostic_SpecifieGridwiseGemmPipeline() -> std::string { - if constexpr (requires { T::pipeline_version; }) { - using PipelineType = decltype(T::pipeline_version); +consteval auto detailed_diagnostic_SpecifieGridwiseGemmPipeline() -> 
std::string +{ + if constexpr(requires { T::pipeline_version; }) + { + using PipelineType = decltype(T::pipeline_version); constexpr bool convertible = std::convertible_to; - return " → T::pipeline_version: " + std::string(CHECK_MARK(convertible)) + + return " → T::pipeline_version: " + std::string(CHECK_MARK(convertible)) + (convertible ? "" : std::string(detail::get_type_info())) + "\n"; - } else { + } + else + { return " → T::pipeline_version: [✗] (missing member)\n"; } } diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index eb02aae584d..9b02b4c9fc0 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -10,27 +10,28 @@ namespace ck_tile::builder::factory { using namespace ck_tile::builder::diagnostics; template -struct ReferenceAlgorithm { +struct ReferenceAlgorithm +{ CHECK_CONCEPT(T, ConvAlgorithmDescriptor) CHECK_CONCEPT(T, SpecifiesReferenceAlgorithm) static constexpr bool c1 = c_ConvAlgorithmDescriptor; static constexpr bool c2 = c_SpecifiesReferenceAlgorithm; - static consteval bool is_valid() { - return c1 && c2; - } + static consteval bool is_valid() { return c1 && c2; } - static consteval auto message() -> std::string { + static consteval auto message() -> std::string + { return std::string("\n=== Reference Algorithm Diagnostic (closest match) ===\n" - "Concepts for Reference Algorithm:\n") + + "Concepts for Reference Algorithm:\n") + DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesReferenceAlgorithm); } }; template -struct FwdXdlV3Algorithm { +struct FwdXdlV3Algorithm +{ CHECK_CONCEPT(T, ConvAlgorithmDescriptor) CHECK_CONCEPT(T, SpecifiesThreadBlock) CHECK_CONCEPT(T, SpecifiesBlockTransfer) @@ -42,39 +43,39 @@ struct FwdXdlV3Algorithm { CHECK_CONCEPT(T, SpecifiesGemmSpecialization) CHECK_CONCEPT(T, 
SpecifiesBlockGemm) - static constexpr bool c1 = c_ConvAlgorithmDescriptor; - static constexpr bool c2 = c_SpecifiesThreadBlock; - static constexpr bool c3 = c_SpecifiesBlockTransfer; - static constexpr bool c4 = c_SpecifiesLdsTransfer; - static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder; - static constexpr bool c6 = c_SpecifiesSourceAccessOrder; - static constexpr bool c7 = c_SpecifiesGridwiseFwdXdlGemm; - static constexpr bool c8 = c_SpecifiesFwdConvSpecialization; - static constexpr bool c9 = c_SpecifiesGemmSpecialization; + static constexpr bool c1 = c_ConvAlgorithmDescriptor; + static constexpr bool c2 = c_SpecifiesThreadBlock; + static constexpr bool c3 = c_SpecifiesBlockTransfer; + static constexpr bool c4 = c_SpecifiesLdsTransfer; + static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder; + static constexpr bool c6 = c_SpecifiesSourceAccessOrder; + static constexpr bool c7 = c_SpecifiesGridwiseFwdXdlGemm; + static constexpr bool c8 = c_SpecifiesFwdConvSpecialization; + static constexpr bool c9 = c_SpecifiesGemmSpecialization; static constexpr bool c10 = c_SpecifiesBlockGemm; - static consteval bool is_valid() { + static consteval bool is_valid() + { return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10; } - static consteval auto message() -> std::string { - return std::string("\n=== Forward XDL V3 Algorithm Diagnostic (closest match) ===\n" - "Concepts for FwdXdlV3 Algorithm:\n") + - DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + - DIAGNOSTIC_LINE(SpecifiesThreadBlock) + - DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + - DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + + static consteval auto message() -> std::string + { + return std::string("\n=== Forward XDL V3 Algorithm Diagnostic (closest match) ===\n" + "Concepts for FwdXdlV3 Algorithm:\n") + + DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + + DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + 
DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + DIAGNOSTIC_LINE(SpecifiesGridwiseFwdXdlGemm) + DIAGNOSTIC_LINE(SpecifiesFwdConvSpecialization) + - DIAGNOSTIC_LINE(SpecifiesGemmSpecialization) + - DIAGNOSTIC_LINE(SpecifiesBlockGemm); + DIAGNOSTIC_LINE(SpecifiesGemmSpecialization) + DIAGNOSTIC_LINE(SpecifiesBlockGemm); } }; template -struct FwdXdlAlgorithmBase { +struct FwdXdlAlgorithmBase +{ CHECK_CONCEPT(T, ConvAlgorithmDescriptor) CHECK_CONCEPT(T, SpecifiesThreadBlock) CHECK_CONCEPT(T, SpecifiesBlockTransfer) @@ -88,60 +89,58 @@ struct FwdXdlAlgorithmBase { CHECK_CONCEPT(T, SpecifiesNumGroupsToMerge) CHECK_CONCEPT(T, SpecifiesLoopScheduler) - static constexpr bool c1 = c_ConvAlgorithmDescriptor; - static constexpr bool c2 = c_SpecifiesThreadBlock; - static constexpr bool c3 = c_SpecifiesBlockTransfer; - static constexpr bool c4 = c_SpecifiesLdsTransfer; - static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder; - static constexpr bool c6 = c_SpecifiesSourceAccessOrder; - static constexpr bool c7 = c_SpecifiesGridwiseFwdXdlGemm; - static constexpr bool c8 = c_SpecifiesFwdConvSpecialization; - static constexpr bool c9 = c_SpecifiesGemmSpecialization; + static constexpr bool c1 = c_ConvAlgorithmDescriptor; + static constexpr bool c2 = c_SpecifiesThreadBlock; + static constexpr bool c3 = c_SpecifiesBlockTransfer; + static constexpr bool c4 = c_SpecifiesLdsTransfer; + static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder; + static constexpr bool c6 = c_SpecifiesSourceAccessOrder; + static constexpr bool c7 = c_SpecifiesGridwiseFwdXdlGemm; + static constexpr bool c8 = c_SpecifiesFwdConvSpecialization; + static constexpr bool c9 = c_SpecifiesGemmSpecialization; static constexpr bool c10 = c_SpecifiesNumPrefetchStages; static constexpr bool c11 = c_SpecifiesNumGroupsToMerge; static constexpr bool c12 = c_SpecifiesLoopScheduler; - static consteval bool is_valid() { + static consteval bool is_valid() + { 
return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10 && c11 && c12; } - static consteval auto message() -> std::string { - return - DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + - DIAGNOSTIC_LINE(SpecifiesThreadBlock) + - DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + - DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + + static consteval auto message() -> std::string + { + return DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + + DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + DIAGNOSTIC_LINE(SpecifiesGridwiseFwdXdlGemm) + DIAGNOSTIC_LINE(SpecifiesFwdConvSpecialization) + DIAGNOSTIC_LINE(SpecifiesGemmSpecialization) + DIAGNOSTIC_LINE(SpecifiesNumPrefetchStages) + - DIAGNOSTIC_LINE(SpecifiesNumGroupsToMerge) + - DIAGNOSTIC_LINE(SpecifiesLoopScheduler); + DIAGNOSTIC_LINE(SpecifiesNumGroupsToMerge) + DIAGNOSTIC_LINE(SpecifiesLoopScheduler); } }; template -struct FwdXdlAlgorithm : public FwdXdlAlgorithmBase{ +struct FwdXdlAlgorithm : public FwdXdlAlgorithmBase +{ CHECK_CONCEPT(T, SpecifiesGenericInstance) - + static constexpr bool c13 = c_SpecifiesGenericInstance; - static consteval bool is_valid() { - return c13 && FwdXdlAlgorithmBase::is_valid(); - } + static consteval bool is_valid() { return c13 && FwdXdlAlgorithmBase::is_valid(); } - static consteval auto message() -> std::string { + static consteval auto message() -> std::string + { return std::string("\n=== Forward XDL Algorithm Diagnostic (closest match) ===\n" - "Concepts for FwdXdl Algorithm:\n") + - FwdXdlAlgorithmBase::message() + - DIAGNOSTIC_LINE(SpecifiesGenericInstance); + "Concepts for FwdXdl Algorithm:\n") + + FwdXdlAlgorithmBase::message() + DIAGNOSTIC_LINE(SpecifiesGenericInstance); } }; template -struct FwdWmmaAlgorithm { +struct FwdWmmaAlgorithm +{ CHECK_CONCEPT(T, ConvAlgorithmDescriptor) CHECK_CONCEPT(T, SpecifiesThreadBlock) 
CHECK_CONCEPT(T, SpecifiesBlockTransfer) @@ -155,43 +154,44 @@ struct FwdWmmaAlgorithm { CHECK_CONCEPT(T, SpecifiesLoopScheduler) CHECK_CONCEPT(T, SpecifiesGridwiseGemmPipeline) - static constexpr bool c1 = c_ConvAlgorithmDescriptor; - static constexpr bool c2 = c_SpecifiesThreadBlock; - static constexpr bool c3 = c_SpecifiesBlockTransfer; - static constexpr bool c4 = c_SpecifiesLdsTransfer; - static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder; - static constexpr bool c6 = c_SpecifiesSourceAccessOrder; - static constexpr bool c7 = c_SpecifiesGridwiseWmmaGemm; - static constexpr bool c8 = c_SpecifiesFwdConvSpecialization; - static constexpr bool c9 = c_SpecifiesGemmSpecialization; + static constexpr bool c1 = c_ConvAlgorithmDescriptor; + static constexpr bool c2 = c_SpecifiesThreadBlock; + static constexpr bool c3 = c_SpecifiesBlockTransfer; + static constexpr bool c4 = c_SpecifiesLdsTransfer; + static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder; + static constexpr bool c6 = c_SpecifiesSourceAccessOrder; + static constexpr bool c7 = c_SpecifiesGridwiseWmmaGemm; + static constexpr bool c8 = c_SpecifiesFwdConvSpecialization; + static constexpr bool c9 = c_SpecifiesGemmSpecialization; static constexpr bool c10 = c_SpecifiesNumPrefetchStages; static constexpr bool c11 = c_SpecifiesLoopScheduler; static constexpr bool c12 = c_SpecifiesGridwiseGemmPipeline; - static consteval bool is_valid() { + static consteval bool is_valid() + { return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10 && c11 && c12; } - static consteval auto message() -> std::string { + static consteval auto message() -> std::string + { return std::string("\n=== Forward WMMA Algorithm Diagnostic (closest match) ===\n" - "Concepts for FwdWmma Algorithm:\n") + - DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + - DIAGNOSTIC_LINE(SpecifiesThreadBlock) + - DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + - DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + + "Concepts for FwdWmma Algorithm:\n") + + 
DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + + DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + DIAGNOSTIC_LINE(SpecifiesGridwiseWmmaGemm) + DIAGNOSTIC_LINE(SpecifiesFwdConvSpecialization) + DIAGNOSTIC_LINE(SpecifiesGemmSpecialization) + DIAGNOSTIC_LINE(SpecifiesNumPrefetchStages) + - DIAGNOSTIC_LINE(SpecifiesLoopScheduler) + + DIAGNOSTIC_LINE(SpecifiesLoopScheduler) + DIAGNOSTIC_LINE(SpecifiesGridwiseGemmPipeline); } }; template -struct FwdDlAlgorithm { +struct FwdDlAlgorithm +{ CHECK_CONCEPT(T, ConvAlgorithmDescriptor) CHECK_CONCEPT(T, SpecifiesThreadBlock) CHECK_CONCEPT(T, SpecifiesFwdConvSpecialization) @@ -210,26 +210,24 @@ struct FwdDlAlgorithm { static constexpr bool c7 = c_SpecifiesDlFwdBlockTransfer; static constexpr bool c8 = c_SpecifiesDlEpilogue; - static consteval bool is_valid() { - return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8; - } + static consteval bool is_valid() { return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8; } - static consteval auto message() -> std::string { + static consteval auto message() -> std::string + { return std::string("\n=== Forward DL Algorithm Diagnostic (closest match) ===\n" - "Concepts for FwdDl Algorithm:\n") + - DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + - DIAGNOSTIC_LINE(SpecifiesThreadBlock) + + "Concepts for FwdDl Algorithm:\n") + + DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + DIAGNOSTIC_LINE(SpecifiesFwdConvSpecialization) + DIAGNOSTIC_LINE(SpecifiesGemmSpecialization) + DIAGNOSTIC_LINE(SpecifiesDlThreadConfig) + DIAGNOSTIC_LINE(SpecifiesDlThreadCluster) + - DIAGNOSTIC_LINE(SpecifiesDlFwdBlockTransfer) + - DIAGNOSTIC_LINE(SpecifiesDlEpilogue); + DIAGNOSTIC_LINE(SpecifiesDlFwdBlockTransfer) + DIAGNOSTIC_LINE(SpecifiesDlEpilogue); } }; template -struct TileAlgorithm { +struct TileAlgorithm +{ 
CHECK_CONCEPT(T, ConvAlgorithmDescriptor) CHECK_CONCEPT(T, SpecifiesTileThreadBlock) CHECK_CONCEPT(T, SpecifiesTileTransfer) @@ -244,16 +242,14 @@ struct TileAlgorithm { static constexpr bool c5 = c_SpecifiesTileBlockGemm; static constexpr bool c6 = c_SpecifiesTileOptimizations; - static consteval bool is_valid() { - return c1 && c2 && c3 && c4 && c5 && c6; - } + static consteval bool is_valid() { return c1 && c2 && c3 && c4 && c5 && c6; } - static consteval auto message() -> std::string { - return std::string("\n=== CK Tile Algorithm Diagnostic (closest match) ===\n" - "Concepts for CK Tile Conv Algorithm:\n") + + static consteval auto message() -> std::string + { + return std::string("\n=== CK Tile Algorithm Diagnostic (closest match) ===\n" + "Concepts for CK Tile Conv Algorithm:\n") + DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + - DIAGNOSTIC_LINE(SpecifiesTileThreadBlock) + - DIAGNOSTIC_LINE(SpecifiesTileTransfer) + + DIAGNOSTIC_LINE(SpecifiesTileThreadBlock) + DIAGNOSTIC_LINE(SpecifiesTileTransfer) + DIAGNOSTIC_LINE(SpecifiesTileConvSpecialization) + DIAGNOSTIC_LINE(SpecifiesTileBlockGemm) + DIAGNOSTIC_LINE(SpecifiesTileOptimizations); @@ -267,21 +263,24 @@ struct LargeTensorAlgorithm : public FwdXdlAlgorithmBase static constexpr bool c13 = c_SpecifiesLargeTensorSupport; - static consteval bool is_valid() { + static consteval bool is_valid() + { // Note: Check first if the specialization is set. 
return c13 && FwdXdlAlgorithmBase::is_valid(); } - static consteval auto message() -> std::string { - return std::string("\n=== Forward XDL Large Tensor Algorithm Diagnostic (closest match) ===\n" - "Concepts for FwdLargeTensorXdl Algorithm:\n") + - FwdXdlAlgorithmBase::message() + - DIAGNOSTIC_LINE(SpecifiesLargeTensorSupport); + static consteval auto message() -> std::string + { + return std::string( + "\n=== Forward XDL Large Tensor Algorithm Diagnostic (closest match) ===\n" + "Concepts for FwdLargeTensorXdl Algorithm:\n") + + FwdXdlAlgorithmBase::message() + DIAGNOSTIC_LINE(SpecifiesLargeTensorSupport); } }; template -struct BwdXdlAlgorithmBase { +struct BwdXdlAlgorithmBase +{ CHECK_CONCEPT(T, ConvAlgorithmDescriptor) CHECK_CONCEPT(T, SpecifiesThreadBlock) CHECK_CONCEPT(T, SpecifiesBlockTransfer4D) @@ -300,16 +299,12 @@ struct BwdXdlAlgorithmBase { static constexpr bool c7 = c_SpecifiesGridwiseBwdXdlGemm; static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; - static consteval bool is_valid() { - return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8; - } + static consteval bool is_valid() { return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8; } - static consteval auto message() -> std::string { - return - DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + - DIAGNOSTIC_LINE(SpecifiesThreadBlock) + - DIAGNOSTIC_LINE(SpecifiesBlockTransfer4D) + - DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + + static consteval auto message() -> std::string + { + return DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + + DIAGNOSTIC_LINE(SpecifiesBlockTransfer4D) + DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + DIAGNOSTIC_LINE(SpecifiesGridwiseBwdXdlGemm) + @@ -318,46 +313,45 @@ struct BwdXdlAlgorithmBase { }; template -struct BwdXdlAlgorithm : public BwdXdlAlgorithmBase{ +struct BwdXdlAlgorithm : public BwdXdlAlgorithmBase +{ CHECK_CONCEPT(T, 
SpecifiesTransposeTransfer) CHECK_CONCEPT(T, SpecifiesGenericInstance) - static constexpr bool c9 = c_SpecifiesTransposeTransfer; + static constexpr bool c9 = c_SpecifiesTransposeTransfer; static constexpr bool c10 = c_SpecifiesGenericInstance; - static consteval bool is_valid() { - return c9 && c10 && BwdXdlAlgorithmBase::is_valid(); - } + static consteval bool is_valid() { return c9 && c10 && BwdXdlAlgorithmBase::is_valid(); } - static consteval auto message() -> std::string { + static consteval auto message() -> std::string + { return std::string("\n=== Backward XDL Algorithm Diagnostic (closest match) ===\n" - "Concepts for BwdXdl Algorithm:\n") + - BwdXdlAlgorithmBase::message() + - DIAGNOSTIC_LINE(SpecifiesTransposeTransfer) + + "Concepts for BwdXdl Algorithm:\n") + + BwdXdlAlgorithmBase::message() + DIAGNOSTIC_LINE(SpecifiesTransposeTransfer) + DIAGNOSTIC_LINE(SpecifiesGenericInstance); } }; template -struct BwdMultiDXdlAlgorithm : public BwdXdlAlgorithmBase{ +struct BwdMultiDXdlAlgorithm : public BwdXdlAlgorithmBase +{ CHECK_CONCEPT(T, SpecifiesMultipleDSupport) static constexpr bool c9 = c_SpecifiesMultipleDSupport; - static consteval bool is_valid() { - return c9 && BwdXdlAlgorithmBase::is_valid(); - } + static consteval bool is_valid() { return c9 && BwdXdlAlgorithmBase::is_valid(); } - static consteval auto message() -> std::string { + static consteval auto message() -> std::string + { return std::string("\n=== Backward XDL Algorithm Diagnostic (closest match) ===\n" - "Concepts for BwdXdl Algorithm:\n") + - BwdXdlAlgorithmBase::message() + - DIAGNOSTIC_LINE(SpecifiesMultipleDSupport); + "Concepts for BwdXdl Algorithm:\n") + + BwdXdlAlgorithmBase::message() + DIAGNOSTIC_LINE(SpecifiesMultipleDSupport); } }; template -struct BwdXdlV3AlgorithmBase { +struct BwdXdlV3AlgorithmBase +{ CHECK_CONCEPT(T, ConvAlgorithmDescriptor) CHECK_CONCEPT(T, SpecifiesThreadBlock) CHECK_CONCEPT(T, SpecifiesBlockTransfer) @@ -378,16 +372,12 @@ struct BwdXdlV3AlgorithmBase { 
static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; static constexpr bool c9 = c_SpecifiesBlockGemm; - static consteval bool is_valid() { - return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9; - } + static consteval bool is_valid() { return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9; } - static consteval auto message() -> std::string { - return - DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + - DIAGNOSTIC_LINE(SpecifiesThreadBlock) + - DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + - DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + + static consteval auto message() -> std::string + { + return DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + + DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + DIAGNOSTIC_LINE(SpecifiesGridwiseBwdXdlGemm) + @@ -397,25 +387,25 @@ struct BwdXdlV3AlgorithmBase { }; template -struct BwdXdlV3Algorithm : public BwdXdlV3AlgorithmBase{ +struct BwdXdlV3Algorithm : public BwdXdlV3AlgorithmBase +{ CHECK_CONCEPT(T, SpecifiesGenericInstance) static constexpr bool c10 = c_SpecifiesGenericInstance; - static consteval bool is_valid() { - return c10 && BwdXdlV3AlgorithmBase::is_valid(); - } + static consteval bool is_valid() { return c10 && BwdXdlV3AlgorithmBase::is_valid(); } - static consteval auto message() -> std::string { + static consteval auto message() -> std::string + { return std::string("\n=== Backward XDL V3 Algorithm Diagnostic (closest match) ===\n" - "Concepts for BwdXdlV3 Algorithm:\n") + - BwdXdlV3AlgorithmBase::message() + - DIAGNOSTIC_LINE(SpecifiesGenericInstance); + "Concepts for BwdXdlV3 Algorithm:\n") + + BwdXdlV3AlgorithmBase::message() + DIAGNOSTIC_LINE(SpecifiesGenericInstance); } }; template -struct BwdTwoStageXdlAlgorithm : public BwdXdlV3AlgorithmBase{ +struct BwdTwoStageXdlAlgorithm : public BwdXdlV3AlgorithmBase +{ CHECK_CONCEPT(T, 
SpecifiesTransposeTransfer) CHECK_CONCEPT(T, SpecifiesGemmBatchOptions) CHECK_CONCEPT(T, SpecifiesTwoStageSupport) @@ -424,22 +414,24 @@ struct BwdTwoStageXdlAlgorithm : public BwdXdlV3AlgorithmBase{ static constexpr bool c11 = c_SpecifiesGemmBatchOptions; static constexpr bool c12 = c_SpecifiesTwoStageSupport; - static consteval bool is_valid() { + static consteval bool is_valid() + { return c10 && c11 && c12 && BwdXdlV3AlgorithmBase::is_valid(); } - static consteval auto message() -> std::string { + static consteval auto message() -> std::string + { return std::string("\n=== Backward two stage XDL Algorithm Diagnostic (closest match) ===\n" - "Concepts for BwdXdlV3 Algorithm:\n") + - BwdXdlV3AlgorithmBase::message() + - DIAGNOSTIC_LINE(SpecifiesTransposeTransfer) + + "Concepts for BwdXdlV3 Algorithm:\n") + + BwdXdlV3AlgorithmBase::message() + DIAGNOSTIC_LINE(SpecifiesTransposeTransfer) + DIAGNOSTIC_LINE(SpecifiesGemmBatchOptions) + DIAGNOSTIC_LINE(SpecifiesTwoStageSupport); } }; template -struct BwdWmmaAlgorithmBase { +struct BwdWmmaAlgorithmBase +{ CHECK_CONCEPT(T, ConvAlgorithmDescriptor) CHECK_CONCEPT(T, SpecifiesThreadBlock) CHECK_CONCEPT(T, SpecifiesBlockTransfer) @@ -458,16 +450,12 @@ struct BwdWmmaAlgorithmBase { static constexpr bool c7 = c_SpecifiesGridwiseWmmaGemm; static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; - static consteval bool is_valid() { - return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8; - } + static consteval bool is_valid() { return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8; } - static consteval auto message() -> std::string { - return - DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + - DIAGNOSTIC_LINE(SpecifiesThreadBlock) + - DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + - DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + + static consteval auto message() -> std::string + { + return DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + + DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + 
DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + DIAGNOSTIC_LINE(SpecifiesGridwiseWmmaGemm) + @@ -476,26 +464,28 @@ struct BwdWmmaAlgorithmBase { }; template -struct BwdWmmaAlgorithm : public BwdWmmaAlgorithmBase { +struct BwdWmmaAlgorithm : public BwdWmmaAlgorithmBase +{ CHECK_CONCEPT(T, SpecifiesNumPrefetchStages) CHECK_CONCEPT(T, SpecifiesLoopScheduler) CHECK_CONCEPT(T, SpecifiesGridwiseGemmPipeline) CHECK_CONCEPT(T, SpecifiesGenericInstance) - static constexpr bool c9 = c_SpecifiesNumPrefetchStages; + static constexpr bool c9 = c_SpecifiesNumPrefetchStages; static constexpr bool c10 = c_SpecifiesLoopScheduler; static constexpr bool c11 = c_SpecifiesGridwiseGemmPipeline; static constexpr bool c12 = c_SpecifiesGenericInstance; - static consteval bool is_valid() { + static consteval bool is_valid() + { return c9 && c10 && c11 && c12 && BwdWmmaAlgorithmBase::is_valid(); } - static consteval auto message() -> std::string { + static consteval auto message() -> std::string + { return std::string("\n=== Backward WMMA Algorithm Diagnostic (closest match) ===\n" - "Concepts for BwdWmma Algorithm:\n") + - BwdWmmaAlgorithmBase::message() + - DIAGNOSTIC_LINE(SpecifiesNumPrefetchStages) + + "Concepts for BwdWmma Algorithm:\n") + + BwdWmmaAlgorithmBase::message() + DIAGNOSTIC_LINE(SpecifiesNumPrefetchStages) + DIAGNOSTIC_LINE(SpecifiesLoopScheduler) + DIAGNOSTIC_LINE(SpecifiesGridwiseGemmPipeline) + DIAGNOSTIC_LINE(SpecifiesGenericInstance); @@ -503,7 +493,8 @@ struct BwdWmmaAlgorithm : public BwdWmmaAlgorithmBase { }; template -struct BwdWmmaV3AlgorithmBase { +struct BwdWmmaV3AlgorithmBase +{ CHECK_CONCEPT(T, ConvAlgorithmDescriptor) CHECK_CONCEPT(T, SpecifiesThreadBlock) CHECK_CONCEPT(T, SpecifiesBlockTransfer) @@ -524,16 +515,12 @@ struct BwdWmmaV3AlgorithmBase { static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; static constexpr bool c9 = c_SpecifiesBlockGemm; - 
static consteval bool is_valid() { - return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9; - } + static consteval bool is_valid() { return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9; } - static consteval auto message() -> std::string { - return - DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + - DIAGNOSTIC_LINE(SpecifiesThreadBlock) + - DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + - DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + + static consteval auto message() -> std::string + { + return DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + + DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + DIAGNOSTIC_LINE(SpecifiesGridwiseWmmaGemm) + @@ -543,25 +530,24 @@ struct BwdWmmaV3AlgorithmBase { }; template -struct BwdMultiDWmmaV3Algorithm : public BwdWmmaV3AlgorithmBase { +struct BwdMultiDWmmaV3Algorithm : public BwdWmmaV3AlgorithmBase +{ CHECK_CONCEPT(T, SpecifiesMultipleDSupport) static constexpr bool c10 = c_SpecifiesMultipleDSupport; - static consteval bool is_valid() { - return c10 && BwdWmmaAlgorithmBase::is_valid(); - } + static consteval bool is_valid() { return c10 && BwdWmmaAlgorithmBase::is_valid(); } - static consteval auto message() -> std::string { + static consteval auto message() -> std::string + { return std::string("\n=== Backward WMMA Algorithm Diagnostic (closest match) ===\n" - "Concepts for BwdMultiDWmma Algorithm:\n") + - BwdWmmaAlgorithmBase::message() + - DIAGNOSTIC_LINE(SpecifiesMultipleDSupport); + "Concepts for BwdMultiDWmma Algorithm:\n") + + BwdWmmaAlgorithmBase::message() + DIAGNOSTIC_LINE(SpecifiesMultipleDSupport); } }; template -struct BwdWmmaV3Algorithm : public BwdWmmaV3AlgorithmBase +struct BwdWmmaV3Algorithm : public BwdWmmaV3AlgorithmBase { CHECK_CONCEPT(T, SpecifiesTransposeTransfer) CHECK_CONCEPT(T, SpecifiesGenericInstance) @@ -569,21 +555,19 @@ struct BwdWmmaV3Algorithm : 
public BwdWmmaV3AlgorithmBase static constexpr bool c10 = c_SpecifiesTransposeTransfer; static constexpr bool c11 = c_SpecifiesGenericInstance; - static consteval bool is_valid() { - return c10 && c11 && BwdWmmaV3AlgorithmBase::is_valid(); - } + static consteval bool is_valid() { return c10 && c11 && BwdWmmaV3AlgorithmBase::is_valid(); } - static consteval auto message() -> std::string { + static consteval auto message() -> std::string + { return std::string("\n=== Backward WMMA V3 Algorithm Diagnostic (closest match) ===\n" - "Concepts for BwdWmmaV3 Algorithm:\n") + - BwdWmmaV3AlgorithmBase::message() + - DIAGNOSTIC_LINE(SpecifiesTransposeTransfer) + + "Concepts for BwdWmmaV3 Algorithm:\n") + + BwdWmmaV3AlgorithmBase::message() + DIAGNOSTIC_LINE(SpecifiesTransposeTransfer) + DIAGNOSTIC_LINE(SpecifiesGenericInstance); } }; template -struct BwdTwoStageWmmaV3Algorithm : public BwdWmmaV3AlgorithmBase +struct BwdTwoStageWmmaV3Algorithm : public BwdWmmaV3AlgorithmBase { CHECK_CONCEPT(T, SpecifiesTransposeTransfer) CHECK_CONCEPT(T, SpecifiesTwoStageSupport) @@ -593,22 +577,25 @@ struct BwdTwoStageWmmaV3Algorithm : public BwdWmmaV3AlgorithmBase static constexpr bool c11 = c_SpecifiesTwoStageSupport; static constexpr bool c12 = c_SpecifiesGemmBatchOptions; - static consteval bool is_valid() { + static consteval bool is_valid() + { return c10 && c11 && c12 && BwdWmmaV3AlgorithmBase::is_valid(); } - static consteval auto message() -> std::string { - return std::string("\n=== Backward Two Stage WMMA V3 Algorithm Diagnostic (closest match) ===\n" - "Concepts for BwdTwoStageWmmaV3 Algorithm:\n") + - BwdWmmaV3AlgorithmBase::message() + - DIAGNOSTIC_LINE(SpecifiesTransposeTransfer) + + static consteval auto message() -> std::string + { + return std::string( + "\n=== Backward Two Stage WMMA V3 Algorithm Diagnostic (closest match) ===\n" + "Concepts for BwdTwoStageWmmaV3 Algorithm:\n") + + BwdWmmaV3AlgorithmBase::message() + DIAGNOSTIC_LINE(SpecifiesTransposeTransfer) + 
DIAGNOSTIC_LINE(SpecifiesGemmBatchOptions) + DIAGNOSTIC_LINE(SpecifiesTwoStageSupport); } }; template -struct BwdDlAlgorithm { +struct BwdDlAlgorithm +{ CHECK_CONCEPT(T, ConvAlgorithmDescriptor) CHECK_CONCEPT(T, SpecifiesThreadBlock) CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization) @@ -625,286 +612,345 @@ struct BwdDlAlgorithm { static constexpr bool c6 = c_SpecifiesDlBwdBlockTransfer; static constexpr bool c7 = c_SpecifiesDlEpilogue; - static consteval bool is_valid() { - return c1 && c2 && c3 && c4 && c5 && c6 && c7; - } + static consteval bool is_valid() { return c1 && c2 && c3 && c4 && c5 && c6 && c7; } - static consteval auto message() -> std::string { + static consteval auto message() -> std::string + { return std::string("\n=== Backward DL Algorithm Diagnostic (closest match) ===\n" - "Concepts for BwdDl Algorithm:\n") + - DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + - DIAGNOSTIC_LINE(SpecifiesThreadBlock) + + "Concepts for BwdDl Algorithm:\n") + + DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) + DIAGNOSTIC_LINE(SpecifiesDlThreadConfig) + DIAGNOSTIC_LINE(SpecifiesDlThreadCluster) + - DIAGNOSTIC_LINE(SpecifiesDlBwdBlockTransfer) + - DIAGNOSTIC_LINE(SpecifiesDlEpilogue); + DIAGNOSTIC_LINE(SpecifiesDlBwdBlockTransfer) + DIAGNOSTIC_LINE(SpecifiesDlEpilogue); } }; template -consteval int count_matches_fwd_xdl_v3() { +consteval int count_matches_fwd_xdl_v3() +{ using Alg = FwdXdlV3Algorithm; - return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10; + return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + + Alg::c10; } template -consteval int count_matches_fwd_xdl() { +consteval int count_matches_fwd_xdl() +{ using Alg = FwdXdlAlgorithm; - return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11 + Alg::c12 + Alg::c13; + 
return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + + Alg::c10 + Alg::c11 + Alg::c12 + Alg::c13; } template -consteval int count_matches_fwd_wmma() { +consteval int count_matches_fwd_wmma() +{ using Alg = FwdWmmaAlgorithm; - return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11; + return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + + Alg::c10 + Alg::c11; } template -consteval int count_matches_fwd_dl() { +consteval int count_matches_fwd_dl() +{ using Alg = FwdDlAlgorithm; return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8; } template -consteval int count_matches_bwd_xdl() { +consteval int count_matches_bwd_xdl() +{ using Alg = BwdXdlAlgorithm; return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9; } template -consteval int count_matches_bwd_multi_d_xdl() { +consteval int count_matches_bwd_multi_d_xdl() +{ using Alg = BwdMultiDXdlAlgorithm; return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9; } template -consteval int count_matches_bwd_xdl_v3() { +consteval int count_matches_bwd_xdl_v3() +{ using Alg = BwdXdlV3Algorithm; return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9; } template -consteval int count_matches_bwd_two_stage_xdl() { +consteval int count_matches_bwd_two_stage_xdl() +{ using Alg = BwdTwoStageXdlAlgorithm; - return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11 + Alg::c12; + return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + + Alg::c10 + Alg::c11 + Alg::c12; } template -consteval int count_matches_bwd_wmma() { +consteval int count_matches_bwd_wmma() +{ using Alg = BwdWmmaAlgorithm; - return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + 
Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11 + Alg::c12; + return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + + Alg::c10 + Alg::c11 + Alg::c12; } template -consteval int count_matches_bwd_multi_d_wmma() { +consteval int count_matches_bwd_multi_d_wmma() +{ using Alg = BwdMultiDWmmaV3Algorithm; - return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11 + Alg::c12; + return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + + Alg::c10 + Alg::c11 + Alg::c12; } template -consteval int count_matches_bwd_wmma_v3() { +consteval int count_matches_bwd_wmma_v3() +{ using Alg = BwdWmmaV3Algorithm; - return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11; + return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + + Alg::c10 + Alg::c11; } template -consteval int count_matches_bwd_two_stage_wmma_v3() { +consteval int count_matches_bwd_two_stage_wmma_v3() +{ using Alg = BwdTwoStageWmmaV3Algorithm; - return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11 + Alg::c12; + return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + + Alg::c10 + Alg::c11 + Alg::c12; } template -consteval int count_matches_bwd_dl() { +consteval int count_matches_bwd_dl() +{ using Alg = BwdDlAlgorithm; return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7; } template -consteval int count_matches_large_tensor() { +consteval int count_matches_large_tensor() +{ using Alg = LargeTensorAlgorithm; - return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11 + Alg::c12 + Alg::c13; + return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + 
Alg::c9 + + Alg::c10 + Alg::c11 + Alg::c12 + Alg::c13; } template -consteval int count_matches_tile() { +consteval int count_matches_tile() +{ using Alg = TileAlgorithm; return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6; } template -consteval void diagnose_fwd_algorithm_signature() +consteval void diagnose_fwd_algorithm_signature() { // Find closest matching variant - constexpr int xdl_v3_matches = count_matches_fwd_xdl_v3(); - constexpr int xdl_matches = count_matches_fwd_xdl(); - constexpr int wmma_matches = count_matches_fwd_wmma(); - constexpr int dl_matches = count_matches_fwd_dl(); + constexpr int xdl_v3_matches = count_matches_fwd_xdl_v3(); + constexpr int xdl_matches = count_matches_fwd_xdl(); + constexpr int wmma_matches = count_matches_fwd_wmma(); + constexpr int dl_matches = count_matches_fwd_dl(); constexpr int large_tensor_matches = count_matches_large_tensor(); - constexpr int tile_matches = count_matches_tile(); - + constexpr int tile_matches = count_matches_tile(); + // Check whether we have XDL or WMMA algorithm - if constexpr (SpecifiesGridwiseFwdXdlGemm) + if constexpr(SpecifiesGridwiseFwdXdlGemm) { - constexpr int max_1 = xdl_v3_matches > xdl_matches ? xdl_v3_matches : xdl_matches; - constexpr int max_2 = max_1 > dl_matches ? max_1 : dl_matches; + constexpr int max_1 = xdl_v3_matches > xdl_matches ? xdl_v3_matches : xdl_matches; + constexpr int max_2 = max_1 > dl_matches ? max_1 : dl_matches; constexpr int max_matches = large_tensor_matches > max_2 ? 
large_tensor_matches : max_2; - if constexpr(max_matches == xdl_v3_matches) { + if constexpr(max_matches == xdl_v3_matches) + { using Alg = FwdXdlV3Algorithm; static_assert(Alg::is_valid(), Alg::message()); - } else if constexpr(max_matches == xdl_matches) { + } + else if constexpr(max_matches == xdl_matches) + { using Alg = FwdXdlAlgorithm; static_assert(Alg::is_valid(), Alg::message()); - } else if constexpr(max_matches == dl_matches) { + } + else if constexpr(max_matches == dl_matches) + { using Alg = FwdDlAlgorithm; static_assert(Alg::is_valid(), Alg::message()); - } else if constexpr (max_matches == large_tensor_matches) { + } + else if constexpr(max_matches == large_tensor_matches) + { using Alg = LargeTensorAlgorithm; static_assert(Alg::is_valid(), Alg::message()); } } - else if constexpr (SpecifiesGridwiseWmmaGemm) + else if constexpr(SpecifiesGridwiseWmmaGemm) { using Alg = FwdWmmaAlgorithm; static_assert(Alg::is_valid(), Alg::message()); } - else + else { // Find maximum matches across all variants - constexpr int max_1 = xdl_v3_matches > xdl_matches ? xdl_v3_matches : xdl_matches; - constexpr int max_2 = wmma_matches > dl_matches ? wmma_matches : dl_matches; - constexpr int max_3 = max_1 > max_2 ? max_1 : max_2; - constexpr int max_4 = max_3 > large_tensor_matches ? max_3 : large_tensor_matches; + constexpr int max_1 = xdl_v3_matches > xdl_matches ? xdl_v3_matches : xdl_matches; + constexpr int max_2 = wmma_matches > dl_matches ? wmma_matches : dl_matches; + constexpr int max_3 = max_1 > max_2 ? max_1 : max_2; + constexpr int max_4 = max_3 > large_tensor_matches ? max_3 : large_tensor_matches; constexpr int max_matches = max_4 > tile_matches ? max_4 : tile_matches; // If we cannot match with neither WMMA nor XDL, try all algorithms for diagnostics // and see whichi is the closest match. 
- if constexpr(max_matches == xdl_v3_matches) { + if constexpr(max_matches == xdl_v3_matches) + { using Alg = FwdXdlV3Algorithm; static_assert(Alg::is_valid(), Alg::message()); - } else if constexpr(max_matches == xdl_matches) { + } + else if constexpr(max_matches == xdl_matches) + { using Alg = FwdXdlAlgorithm; static_assert(Alg::is_valid(), Alg::message()); - } else if constexpr(max_matches == wmma_matches) { + } + else if constexpr(max_matches == wmma_matches) + { using Alg = FwdWmmaAlgorithm; static_assert(Alg::is_valid(), Alg::message()); - } else if constexpr(max_matches == dl_matches) { + } + else if constexpr(max_matches == dl_matches) + { using Alg = FwdDlAlgorithm; static_assert(Alg::is_valid(), Alg::message()); - } else if constexpr (max_matches == large_tensor_matches) { + } + else if constexpr(max_matches == large_tensor_matches) + { using Alg = LargeTensorAlgorithm; static_assert(Alg::is_valid(), Alg::message()); - } else if constexpr (max_matches == tile_matches) { + } + else if constexpr(max_matches == tile_matches) + { using Alg = TileAlgorithm; static_assert(Alg::is_valid(), Alg::message()); } - else { + else + { // This should never happen - static_assert(false, "Internal Error: No matching algorithm variant found for diagnostics."); + static_assert(false, + "Internal Error: No matching algorithm variant found for diagnostics."); } } } template -consteval void diagnose_bwd_weight_algorithm_signature() -{ - constexpr int xdl_matches = count_matches_bwd_xdl(); - constexpr int xdl_v3_matches = count_matches_fwd_xdl_v3(); - constexpr int two_stage_xdl_matches = count_matches_bwd_two_stage_xdl(); - constexpr int dl_matches = count_matches_bwd_dl(); - constexpr int multi_d_xdl_matches = count_matches_bwd_multi_d_xdl(); - constexpr int wmma_v3_matches = count_matches_bwd_wmma_v3(); +consteval void diagnose_bwd_weight_algorithm_signature() +{ + constexpr int xdl_matches = count_matches_bwd_xdl(); + constexpr int xdl_v3_matches = 
count_matches_fwd_xdl_v3(); + constexpr int two_stage_xdl_matches = count_matches_bwd_two_stage_xdl(); + constexpr int dl_matches = count_matches_bwd_dl(); + constexpr int multi_d_xdl_matches = count_matches_bwd_multi_d_xdl(); + constexpr int wmma_v3_matches = count_matches_bwd_wmma_v3(); constexpr int two_stage_wmma_v3_matches = count_matches_bwd_two_stage_wmma_v3(); - constexpr int wmma_matches = count_matches_bwd_wmma(); - constexpr int multi_d_wmma_matches = count_matches_bwd_multi_d_wmma(); + constexpr int wmma_matches = count_matches_bwd_wmma(); + constexpr int multi_d_wmma_matches = count_matches_bwd_multi_d_wmma(); // Check whether we have XDL or WMMA algorithm - if constexpr (SpecifiesGridwiseBwdXdlGemm) + if constexpr(SpecifiesGridwiseBwdXdlGemm) { - constexpr int max1 = xdl_v3_matches > xdl_matches ? xdl_v3_matches : xdl_matches; - constexpr int max2 = max1 > two_stage_xdl_matches ? max1 : two_stage_xdl_matches; - constexpr int max3 = max2 > dl_matches ? max2 : dl_matches; + constexpr int max1 = xdl_v3_matches > xdl_matches ? xdl_v3_matches : xdl_matches; + constexpr int max2 = max1 > two_stage_xdl_matches ? max1 : two_stage_xdl_matches; + constexpr int max3 = max2 > dl_matches ? max2 : dl_matches; constexpr int max_matches = max3 > multi_d_xdl_matches ? 
max3 : multi_d_xdl_matches; - if constexpr (max_matches == xdl_matches) { + if constexpr(max_matches == xdl_matches) + { using Alg = BwdXdlAlgorithm; static_assert(Alg::is_valid(), Alg::message()); - } - else if constexpr (max_matches == xdl_v3_matches) { + } + else if constexpr(max_matches == xdl_v3_matches) + { using Alg = BwdXdlV3Algorithm; static_assert(Alg::is_valid(), Alg::message()); } - else if constexpr (max_matches == two_stage_xdl_matches) { + else if constexpr(max_matches == two_stage_xdl_matches) + { using Alg = BwdTwoStageXdlAlgorithm; static_assert(Alg::is_valid(), Alg::message()); } - else if constexpr (max_matches == dl_matches) { + else if constexpr(max_matches == dl_matches) + { using Alg = BwdDlAlgorithm; static_assert(Alg::is_valid(), Alg::message()); } - else if constexpr (max_matches == multi_d_xdl_matches) { + else if constexpr(max_matches == multi_d_xdl_matches) + { using Alg = BwdMultiDXdlAlgorithm; static_assert(Alg::is_valid(), Alg::message()); } } - else if constexpr (SpecifiesGridwiseWmmaGemm) - { - constexpr int max_1 = wmma_v3_matches > two_stage_wmma_v3_matches ? wmma_v3_matches : two_stage_wmma_v3_matches; - constexpr int max_2 = max_1 > wmma_matches ? max_1 : wmma_matches; + else if constexpr(SpecifiesGridwiseWmmaGemm) + { + constexpr int max_1 = wmma_v3_matches > two_stage_wmma_v3_matches + ? wmma_v3_matches + : two_stage_wmma_v3_matches; + constexpr int max_2 = max_1 > wmma_matches ? max_1 : wmma_matches; constexpr int max_matches = multi_d_wmma_matches > max_2 ? 
multi_d_wmma_matches : max_2; - if constexpr (max_matches == wmma_v3_matches) { + if constexpr(max_matches == wmma_v3_matches) + { using Alg = BwdWmmaV3Algorithm; static_assert(Alg::is_valid(), Alg::message()); } - else if constexpr (max_matches == two_stage_wmma_v3_matches) { + else if constexpr(max_matches == two_stage_wmma_v3_matches) + { using Alg = BwdTwoStageWmmaV3Algorithm; static_assert(Alg::is_valid(), Alg::message()); } - else if constexpr (max_matches == wmma_matches) { + else if constexpr(max_matches == wmma_matches) + { using Alg = BwdWmmaAlgorithm; static_assert(Alg::is_valid(), Alg::message()); } - else if constexpr (max_matches == multi_d_wmma_matches) { + else if constexpr(max_matches == multi_d_wmma_matches) + { using Alg = BwdMultiDWmmaV3Algorithm; static_assert(Alg::is_valid(), Alg::message()); } } - else + else { // If we cannot match with neither WMMA nor XDL, try all algorithms for diagnostics // and see which is the closest match. - constexpr int max1 = xdl_v3_matches > xdl_matches ? xdl_v3_matches : xdl_matches; - constexpr int max2 = max1 > two_stage_xdl_matches ? max1 : two_stage_xdl_matches; - constexpr int max3 = max2 > dl_matches ? max2 : dl_matches; + constexpr int max1 = xdl_v3_matches > xdl_matches ? xdl_v3_matches : xdl_matches; + constexpr int max2 = max1 > two_stage_xdl_matches ? max1 : two_stage_xdl_matches; + constexpr int max3 = max2 > dl_matches ? max2 : dl_matches; constexpr int max_matches = max3 > multi_d_xdl_matches ? 
max3 : multi_d_xdl_matches; - if constexpr (max_matches == xdl_matches) { + if constexpr(max_matches == xdl_matches) + { using Alg = BwdXdlAlgorithm; static_assert(Alg::is_valid(), Alg::message()); - } - else if constexpr (max_matches == xdl_v3_matches) { + } + else if constexpr(max_matches == xdl_v3_matches) + { using Alg = BwdXdlV3Algorithm; static_assert(Alg::is_valid(), Alg::message()); } - else if constexpr (max_matches == two_stage_xdl_matches) { + else if constexpr(max_matches == two_stage_xdl_matches) + { using Alg = BwdTwoStageXdlAlgorithm; static_assert(Alg::is_valid(), Alg::message()); } - else if constexpr (max_matches == dl_matches) { + else if constexpr(max_matches == dl_matches) + { using Alg = BwdDlAlgorithm; static_assert(Alg::is_valid(), Alg::message()); } - else if constexpr (max_matches == multi_d_xdl_matches) { + else if constexpr(max_matches == multi_d_xdl_matches) + { using Alg = BwdMultiDXdlAlgorithm; static_assert(Alg::is_valid(), Alg::message()); } - else { + else + { // This should never happen - static_assert(false, "Internal Error: No matching algorithm variant found for diagnostics."); + static_assert(false, + "Internal Error: No matching algorithm variant found for diagnostics."); } } } -} +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp index 203280c0630..80143427c78 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp @@ -23,12 +23,13 @@ template ; - using Types = internal::BwdWeightConvTensorDataTypes; - using Ops = internal::ElementwiseOps; - using AlgorithmType = decltype(ALGORITHM); + using Layouts = internal::ConvTensorLayouts; + using Types = internal::BwdWeightConvTensorDataTypes; + using Ops = internal::ElementwiseOps; + 
using AlgorithmType = decltype(ALGORITHM); - static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdWeightConvSpecialization(); static constexpr auto BLOCK = internal::SetThreadBlockInfo(); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp index 9d94404a88d..6f5c679b595 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp @@ -26,15 +26,16 @@ template ; - using Types = internal::BwdWeightConvTensorDataTypes; - using Ops = internal::ElementwiseOps; - using AlgorithmType = decltype(ALGORITHM); + using Layouts = internal::ConvTensorLayouts; + using Types = internal::BwdWeightConvTensorDataTypes; + using Ops = internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); - static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdWeightConvSpecialization(); - static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -47,58 +48,63 @@ struct ConvBwdWeightMultiDWmmaV3Factory static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); static_assert(InputVectorTransferLimits, "Invalid B block transfer config"); static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); - 
static_assert(AccessOrderLimits3D, "Invalid A thread cluster access order"); - static_assert(AccessOrderLimits3D, "Invalid B thread cluster access order"); - static_assert(AccessOrderLimits3D, "Invalid A source access order"); - static_assert(AccessOrderLimits3D, "Invalid B source access order"); + static_assert(AccessOrderLimits3D, + "Invalid A thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid B thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid A source access order"); + static_assert(AccessOrderLimits3D, + "Invalid B source access order"); // The forward convolution kernel class instance. - using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< - SPATIAL_DIM, - typename Layouts::InLayout, - typename Layouts::WeiLayout, - typename Layouts::OutLayout, - typename Layouts::DsLayout, - typename Types::InDataType, - typename Types::WeiDataType, - typename Types::OutDataType, - typename Types::AccDataType, - typename Types::DsDataType, - typename Ops::InElementwiseOp, - typename Ops::WeiElementwiseOp, - typename Ops::OutElementwiseOp, - BWD_CONV_SPECIALIZATION, - BLOCK.block_size, - BLOCK.per_block.m, - BLOCK.per_block.n, - BLOCK.per_block.k, - GRIDWISE_GEMM.k1, - GRIDWISE_GEMM.m_per_wmma, - GRIDWISE_GEMM.n_per_wmma, - GRIDWISE_GEMM.m_wmma_per_wave, - GRIDWISE_GEMM.n_wmma_per_wave, - to_sequence_v, - to_sequence_v, - to_sequence_v, - A_BLOCK_TRANSFER.src_vector_dim, - A_BLOCK_TRANSFER.src_scalar_per_vector, - A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, - A_BLOCK_TRANSFER.lds_padding, - to_sequence_v, - to_sequence_v, - to_sequence_v, - B_BLOCK_TRANSFER.src_vector_dim, - B_BLOCK_TRANSFER.src_scalar_per_vector, - B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, - B_BLOCK_TRANSFER.lds_padding, - C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, - C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, - to_sequence_v, - C_BLOCK_TRANSFER.scalar_per_vector, - BLOCK_GEMM.scheduler, - 
BLOCK_GEMM.pipeline_version, - typename Types::OutComputeType, - typename Types::InComputeType>; + using Instance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Layouts::DsLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Types::DsDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + GRIDWISE_GEMM.m_per_wmma, + GRIDWISE_GEMM.n_per_wmma, + GRIDWISE_GEMM.m_wmma_per_wave, + GRIDWISE_GEMM.n_wmma_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + BLOCK_GEMM.scheduler, + BLOCK_GEMM.pipeline_version, + typename Types::OutComputeType, + typename Types::InComputeType>; }; } // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp index dae333d99e1..9f76568ca89 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp +++ 
b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp @@ -26,16 +26,17 @@ template ; - using Types = internal::BwdWeightConvTensorDataTypes; - using Ops = internal::ElementwiseOps; - using AlgorithmType = decltype(ALGORITHM); + using Layouts = internal::ConvTensorLayouts; + using Types = internal::BwdWeightConvTensorDataTypes; + using Ops = internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); - static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdWeightConvSpecialization(); - static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; - static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp index 8288874f0e3..86c48fe3220 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp @@ -26,15 +26,16 @@ template ; - using Types = internal::BwdWeightConvTensorDataTypes; - using Ops = internal::ElementwiseOps; - using AlgorithmType = decltype(ALGORITHM); + using Layouts = internal::ConvTensorLayouts; + using Types = internal::BwdWeightConvTensorDataTypes; + using Ops = internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); - static constexpr auto 
BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdWeightConvSpecialization(); - static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -47,59 +48,64 @@ struct ConvBwdWeightTwoStageWmmaV3Factory static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); static_assert(InputVectorTransferLimits, "Invalid B block transfer config"); static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); - static_assert(AccessOrderLimits3D, "Invalid A thread cluster access order"); - static_assert(AccessOrderLimits3D, "Invalid B thread cluster access order"); - static_assert(AccessOrderLimits3D, "Invalid A source access order"); - static_assert(AccessOrderLimits3D, "Invalid B source access order"); + static_assert(AccessOrderLimits3D, + "Invalid A thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid B thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid A source access order"); + static_assert(AccessOrderLimits3D, + "Invalid B source access order"); // The forward convolution kernel class instance. 
- using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< - SPATIAL_DIM, - typename Layouts::InLayout, - typename Layouts::WeiLayout, - typename Layouts::OutLayout, - typename Types::InDataType, - typename Types::WeiDataType, - typename Types::OutDataType, - typename Types::AccDataType, - typename Ops::InElementwiseOp, - typename Ops::WeiElementwiseOp, - typename Ops::OutElementwiseOp, - BWD_CONV_SPECIALIZATION, - BLOCK.block_size, - BLOCK.per_block.m, - BLOCK.per_block.n, - BLOCK.per_block.k, - GRIDWISE_GEMM.k1, - GRIDWISE_GEMM.m_per_wmma, - GRIDWISE_GEMM.n_per_wmma, - GRIDWISE_GEMM.m_wmma_per_wave, - GRIDWISE_GEMM.n_wmma_per_wave, - to_sequence_v, - to_sequence_v, - to_sequence_v, - A_BLOCK_TRANSFER.src_vector_dim, - A_BLOCK_TRANSFER.src_scalar_per_vector, - A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, - A_BLOCK_TRANSFER.lds_padding, - to_sequence_v, - to_sequence_v, - to_sequence_v, - B_BLOCK_TRANSFER.src_vector_dim, - B_BLOCK_TRANSFER.src_scalar_per_vector, - B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, - B_BLOCK_TRANSFER.lds_padding, - C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, - C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, - to_sequence_v, - C_BLOCK_TRANSFER.scalar_per_vector, - BLOCK_GEMM.scheduler, - BLOCK_GEMM.pipeline_version, - ALGORITHM.num_conv_groups_to_merge, - typename Types::OutComputeType, - typename Types::InComputeType, - ALGORITHM.max_transpose_transfer_src_scalar_per_vector, - ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>; + using Instance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + 
BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + GRIDWISE_GEMM.m_per_wmma, + GRIDWISE_GEMM.n_per_wmma, + GRIDWISE_GEMM.m_wmma_per_wave, + GRIDWISE_GEMM.n_wmma_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + BLOCK_GEMM.scheduler, + BLOCK_GEMM.pipeline_version, + ALGORITHM.num_conv_groups_to_merge, + typename Types::OutComputeType, + typename Types::InComputeType, + ALGORITHM.max_transpose_transfer_src_scalar_per_vector, + ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>; }; } // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp index df3e4c01b2e..9c37beae46c 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp @@ -26,16 +26,17 @@ template ; - using Types = internal::BwdWeightConvTensorDataTypes; - using Ops = internal::ElementwiseOps; - using AlgorithmType = decltype(ALGORITHM); + using Layouts = internal::ConvTensorLayouts; + using Types = internal::BwdWeightConvTensorDataTypes; + using Ops = internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); - static constexpr auto BWD_CONV_SPECIALIZATION = 
internal::SetBwdWeightConvSpecialization(); + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdWeightConvSpecialization(); - static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; - static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -48,10 +49,14 @@ struct ConvBwdWeightTwoStageXdlFactory static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); static_assert(InputVectorTransferLimits, "Invalid B block transfer config"); static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); - static_assert(AccessOrderLimits3D, "Invalid A thread cluster access order"); - static_assert(AccessOrderLimits3D, "Invalid B thread cluster access order"); - static_assert(AccessOrderLimits3D, "Invalid A source access order"); - static_assert(AccessOrderLimits3D, "Invalid B source access order"); + static_assert(AccessOrderLimits3D, + "Invalid A thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid B thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid A source access order"); + static_assert(AccessOrderLimits3D, + "Invalid B source access order"); // The forward convolution kernel class instance. 
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp index b8c3bc8f9b2..817432081b0 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp @@ -26,15 +26,16 @@ template ; - using Types = internal::BwdWeightConvTensorDataTypes; - using Ops = internal::ElementwiseOps; - using AlgorithmType = decltype(ALGORITHM); + using Layouts = internal::ConvTensorLayouts; + using Types = internal::BwdWeightConvTensorDataTypes; + using Ops = internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); - static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdWeightConvSpecialization(); - static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION = internal::SetGridwiseGemmPipelineVersion(); static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); @@ -50,10 +51,14 @@ struct ConvBwdWeightWmmaFactory static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); static_assert(InputVectorTransferLimits, "Invalid B block transfer config"); static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); - static_assert(AccessOrderLimits4D, "Invalid A thread cluster access order"); - static_assert(AccessOrderLimits4D, "Invalid B thread cluster access order"); - static_assert(AccessOrderLimits4D, "Invalid A source access order"); - 
static_assert(AccessOrderLimits4D, "Invalid B source access order"); + static_assert(AccessOrderLimits4D, + "Invalid A thread cluster access order"); + static_assert(AccessOrderLimits4D, + "Invalid B thread cluster access order"); + static_assert(AccessOrderLimits4D, + "Invalid A source access order"); + static_assert(AccessOrderLimits4D, + "Invalid B source access order"); // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle< diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp index d1cfbe9e8dc..baf84402c34 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp @@ -26,15 +26,16 @@ template ; - using Types = internal::BwdWeightConvTensorDataTypes; - using Ops = internal::ElementwiseOps; - using AlgorithmType = decltype(ALGORITHM); + using Layouts = internal::ConvTensorLayouts; + using Types = internal::BwdWeightConvTensorDataTypes; + using Ops = internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); - static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdWeightConvSpecialization(); - static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -47,10 +48,14 @@ struct ConvBwdWeightWmmaV3Factory static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); 
static_assert(InputVectorTransferLimits, "Invalid B block transfer config"); static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); - static_assert(AccessOrderLimits3D, "Invalid A thread cluster access order"); - static_assert(AccessOrderLimits3D, "Invalid B thread cluster access order"); - static_assert(AccessOrderLimits3D, "Invalid A source access order"); - static_assert(AccessOrderLimits3D, "Invalid B source access order"); + static_assert(AccessOrderLimits3D, + "Invalid A thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid B thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid A source access order"); + static_assert(AccessOrderLimits3D, + "Invalid B source access order"); // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp index e89d227f821..91c19d2bd0d 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp @@ -26,16 +26,17 @@ template ; - using Types = internal::BwdWeightConvTensorDataTypes; - using Ops = internal::ElementwiseOps; - using AlgorithmType = decltype(ALGORITHM); + using Layouts = internal::ConvTensorLayouts; + using Types = internal::BwdWeightConvTensorDataTypes; + using Ops = internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); - static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdWeightConvSpecialization(); - static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; - static 
constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp index e93b1456cd7..f3edd0e6d93 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp @@ -26,16 +26,17 @@ template ; - using Types = internal::BwdWeightConvTensorDataTypes; - using Ops = internal::ElementwiseOps; - using AlgorithmType = decltype(ALGORITHM); + using Layouts = internal::ConvTensorLayouts; + using Types = internal::BwdWeightConvTensorDataTypes; + using Ops = internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); - static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdWeightConvSpecialization(); - static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; - static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -48,10 +49,14 @@ struct ConvBwdWeightXdlV3Factory static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); 
static_assert(InputVectorTransferLimits, "Invalid B block transfer config"); static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); - static_assert(AccessOrderLimits3D, "Invalid A thread cluster access order"); - static_assert(AccessOrderLimits3D, "Invalid B thread cluster access order"); - static_assert(AccessOrderLimits3D, "Invalid A source access order"); - static_assert(AccessOrderLimits3D, "Invalid B source access order"); + static_assert(AccessOrderLimits3D, + "Invalid A thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid B thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid A source access order"); + static_assert(AccessOrderLimits3D, + "Invalid B source access order"); // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 9bca0177350..892ad832d8b 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -155,41 +155,45 @@ constexpr auto make_conv_instance() // Backward weight direction (will expand with more algorithms in the future) else if constexpr(ConvDirectionIsBackwardWeight) { - if constexpr (BwdXdlAlgorithm::is_valid()) + if constexpr(BwdXdlAlgorithm::is_valid()) { return typename ConvBwdWeightXdlFactory::Instance{}; } - else if constexpr (BwdXdlV3Algorithm::is_valid()) + else if constexpr(BwdXdlV3Algorithm::is_valid()) { return typename ConvBwdWeightXdlV3Factory::Instance{}; } - else if constexpr (BwdTwoStageXdlAlgorithm::is_valid()) + else if constexpr(BwdTwoStageXdlAlgorithm::is_valid()) { - return typename ConvBwdWeightTwoStageXdlFactory::Instance{}; + return + typename ConvBwdWeightTwoStageXdlFactory::Instance{}; 
} - else if constexpr (BwdDlAlgorithm::is_valid()) + else if constexpr(BwdDlAlgorithm::is_valid()) { return typename ConvBwdWeightDlFactory::Instance{}; } - else if constexpr (BwdMultiDXdlAlgorithm::is_valid()) + else if constexpr(BwdMultiDXdlAlgorithm::is_valid()) { - return typename ConvBwdWeightMultiDXdlFactory::Instance{}; + return + typename ConvBwdWeightMultiDXdlFactory::Instance{}; } - else if constexpr (BwdWmmaV3Algorithm::is_valid()) + else if constexpr(BwdWmmaV3Algorithm::is_valid()) { return typename ConvBwdWeightWmmaV3Factory::Instance{}; } - else if constexpr (BwdTwoStageWmmaV3Algorithm::is_valid()) + else if constexpr(BwdTwoStageWmmaV3Algorithm::is_valid()) { - return typename ConvBwdWeightTwoStageWmmaV3Factory::Instance{}; + return typename ConvBwdWeightTwoStageWmmaV3Factory:: + Instance{}; } - else if constexpr (BwdWmmaAlgorithm::is_valid()) + else if constexpr(BwdWmmaAlgorithm::is_valid()) { return typename ConvBwdWeightWmmaFactory::Instance{}; } - else if constexpr (BwdMultiDWmmaV3Algorithm::is_valid()) + else if constexpr(BwdMultiDWmmaV3Algorithm::is_valid()) { - return typename ConvBwdWeightMultiDWmmaV3Factory::Instance{}; + return typename ConvBwdWeightMultiDWmmaV3Factory:: + Instance{}; } else { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp index 07f0976b74a..31246eb5a8b 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp @@ -24,10 +24,10 @@ template ; - using Types = internal::FwdConvTensorDataTypes; - using Ops = internal::ElementwiseOps; - using AlgorithmType = decltype(ALGORITHM); + using Layouts = internal::ConvTensorLayouts; + using Types = internal::FwdConvTensorDataTypes; + using Ops = internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); static constexpr auto 
FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization(); static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization(); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp index 77a930d1ce1..9cd56ad7ad2 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp @@ -26,14 +26,13 @@ template ; - using Types = internal::FwdConvTensorDataTypes; - using Ops = internal::ElementwiseOps; - using AlgorithmType = decltype(ALGORITHM); + using Layouts = internal::ConvTensorLayouts; + using Types = internal::FwdConvTensorDataTypes; + using Ops = internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); - static constexpr auto FWD_CONV_SPECIALIZATION = - internal::SetFwdConvSpecialization(); - static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization(); + static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization(); + static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization(); static constexpr internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION, .gemm_spec = GEMM_SPECIALIZATION}; @@ -45,8 +44,7 @@ struct ConvFwdLargeTensorFactory internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = internal::SetFwdConvBlockTransfer(); - static constexpr auto C_BLOCK_TRANSFER = - internal::SetCBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); // Check limits for the algorithm parameters. 
static_assert(InputVectorTransferLimits); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp index e34f39965f2..7d889a0c01b 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp @@ -26,10 +26,10 @@ template ; - using Types = internal::FwdConvTensorDataTypes; - using Ops = internal::ElementwiseOps; - using AlgorithmType = decltype(ALGORITHM); + using Layouts = internal::ConvTensorLayouts; + using Types = internal::FwdConvTensorDataTypes; + using Ops = internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); static_assert(ALGORITHM.transfer.a.lds_transfer.is_direct_load == ALGORITHM.transfer.b.lds_transfer.is_direct_load, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp index dbaa8651eb2..3506f5d1a92 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp @@ -26,10 +26,10 @@ template ; - using Types = internal::FwdConvTensorDataTypes; - using Ops = internal::ElementwiseOps; - using AlgorithmType = decltype(ALGORITHM); + using Layouts = internal::ConvTensorLayouts; + using Types = internal::FwdConvTensorDataTypes; + using Ops = internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization(); static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization(); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp index 5a0084d6da3..446ceceda25 100644 --- 
a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp @@ -26,10 +26,10 @@ template ; - using Types = internal::FwdConvTensorDataTypes; - using Ops = internal::ElementwiseOps; - using AlgorithmType = decltype(ALGORITHM); + using Layouts = internal::ConvTensorLayouts; + using Types = internal::FwdConvTensorDataTypes; + using Ops = internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization(); static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization(); diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp index 69facce41ba..4ef2f533c98 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp @@ -15,11 +15,11 @@ struct BlockTransfer ck::Array thread_cluster_dims{}; // k0, m, k1 ck::Array thread_cluster_order{}; ck::Array src_access_order{}; - size_t src_vector_dim = 0; - size_t src_scalar_per_vector = 0; - size_t lds_dst_scalar_per_vector = 0; - bool is_direct_load = false; - bool lds_padding = false; + size_t src_vector_dim = 0; + size_t src_scalar_per_vector = 0; + size_t lds_dst_scalar_per_vector = 0; + bool is_direct_load = false; + bool lds_padding = false; }; template @@ -28,10 +28,10 @@ struct BwdBlockTransfer ck::Array thread_cluster_dims{}; ck::Array thread_cluster_order{}; ck::Array src_access_order{}; - size_t src_vector_dim = 0; - size_t src_scalar_per_vector = 0; - size_t lds_dst_scalar_per_vector = 0; - bool lds_padding = false; + size_t src_vector_dim = 0; + size_t src_scalar_per_vector = 0; + size_t lds_dst_scalar_per_vector = 0; + bool lds_padding = false; }; 
template @@ -66,11 +66,13 @@ constexpr auto SetBwdConvBlockTransfer() static_assert(block_order.order.size() == src_order.order.size(), "Mismatched size between block order and src order"); - if constexpr (array_length == 3) + if constexpr(array_length == 3) { return BwdBlockTransfer<3>{ .thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1}, - .thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2]}, + .thread_cluster_order = {block_order.order[0], + block_order.order[1], + block_order.order[2]}, .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]}, .src_vector_dim = lds_cfg.src_vector_dim, .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, @@ -78,14 +80,23 @@ constexpr auto SetBwdConvBlockTransfer() .lds_padding = lds_cfg.lds_padding, }; } - else if constexpr (array_length == 4) + else if constexpr(array_length == 4) { return BwdBlockTransfer<4>{ - .thread_cluster_dims = {block_xfer.k_batch_size, block_xfer.k0, block_xfer.m_n, block_xfer.k1}, - .thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2], block_order.order[3]}, - .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2], src_order.order[3]}, - .src_vector_dim = lds_cfg.src_vector_dim, - .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, + .thread_cluster_dims = {block_xfer.k_batch_size, + block_xfer.k0, + block_xfer.m_n, + block_xfer.k1}, + .thread_cluster_order = {block_order.order[0], + block_order.order[1], + block_order.order[2], + block_order.order[3]}, + .src_access_order = {src_order.order[0], + src_order.order[1], + src_order.order[2], + src_order.order[3]}, + .src_vector_dim = lds_cfg.src_vector_dim, + .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, .lds_padding = lds_cfg.lds_padding, }; diff --git 
a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp index 00205d414ec..c7d9ce0ac6a 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp @@ -64,19 +64,19 @@ consteval auto GetElementwiseOp() template struct ElementwiseOps { -private: + private: static constexpr auto input_op = GetElementwiseOp(); static constexpr auto weight_op = GetElementwiseOp(); static constexpr auto output_op = GetElementwiseOp(); - static constexpr bool is_forward = ConvDirectionIsForward; + static constexpr bool is_forward = ConvDirectionIsForward; static constexpr bool is_bwd_weight = ConvDirectionIsBackwardWeight; using InputOp = typename decltype(input_op)::Op; using WeightOp = typename decltype(weight_op)::Op; using OutputOp = typename decltype(output_op)::Op; -public: + public: // Forward convolution elementwise ops using AElementwiseOp = std::conditional_t; using BElementwiseOp = std::conditional_t; diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp index fd6df5f09ab..d08dddff83a 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp @@ -222,8 +222,8 @@ template ValidConvOutputLayoutForSpatialDim) struct ConvTensorLayouts { -private: - static constexpr bool is_forward = ConvDirectionIsForward; + private: + static constexpr bool is_forward = ConvDirectionIsForward; static constexpr bool is_bwd_weight = ConvDirectionIsBackwardWeight; using InputLayout = decltype(TensorLayoutToCK()); @@ -231,12 +231,12 @@ struct ConvTensorLayouts using OutputLayout = 
decltype(TensorLayoutToCK()); using AuxLayout = decltype(GetAuxiliaryTensorLayouts())::type; -public: + public: // Forward convolution layouts - using ALayout = std::conditional_t; - using BLayout = std::conditional_t; - using ELayout = std::conditional_t; - + using ALayout = std::conditional_t; + using BLayout = std::conditional_t; + using ELayout = std::conditional_t; + // Backward weight convolution layouts using InLayout = std::conditional_t; using WeiLayout = std::conditional_t; diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp index 2ab1c40882c..1cecb8d43b7 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp @@ -186,8 +186,8 @@ struct BwdWeightConvTensorDataTypes static constexpr auto output_types = GetTensorDataAndComputeTypes(); - using InDataType = typename decltype(input_types.first)::type; - using InComputeType = typename decltype(input_types.second)::type; + using InDataType = typename decltype(input_types.first)::type; + using InComputeType = typename decltype(input_types.second)::type; using WeiDataType = typename decltype(weight_types.first)::type; using WeiComputeType = typename decltype(weight_types.second)::type; using OutDataType = typename decltype(output_types.first)::type; diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp index dde8e19f434..9ed1eebc3c0 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp @@ -160,17 +160,19 @@ consteval ck::tensor_operation::device::ConvolutionForwardSpecialization 
SetFwdC } template -consteval ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization SetBwdWeightConvSpecialization() +consteval ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization +SetBwdWeightConvSpecialization() { constexpr auto specialization = ALGORITHM.bwd_weight_specialization; - using ck_conv_spec = ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization; + using ck_conv_spec = ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization; switch(specialization) { case ConvSpecialization::DEFAULT: return ck_conv_spec::Default; case ConvSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0; case ConvSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0; case ConvSpecialization::ODD_C: return ck_conv_spec::OddC; - case ConvSpecialization::FILTER_3x3: throw "FILTER_3x3 is not supported for backward weight convolution."; + case ConvSpecialization::FILTER_3x3: + throw "FILTER_3x3 is not supported for backward weight convolution."; default: throw "Unsupported ConvSpecialization"; } } diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_dl.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_dl.cpp index 99cadb6d20d..584bce2f1bb 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_dl.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_dl.cpp @@ -19,11 +19,11 @@ constexpr auto SIGNATURE = .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl{} - .with_thread_block(cku::ThreadBlock_256_128x128x16) - .with_bwd_specialization(cku::ConvSpecialization::DEFAULT) - .with_dl_thread_config(cku::DlThreadConfig_16x1x4x4x1) - .with_dl_thread_cluster(cku::DlThreadCluster_8x2) - .with_dl_transfer(cku::DlTransfer5D); + .with_thread_block(cku::ThreadBlock_256_128x128x16) + .with_bwd_specialization(cku::ConvSpecialization::DEFAULT) + 
.with_dl_thread_config(cku::DlThreadConfig_16x1x4x4x1) + .with_dl_thread_cluster(cku::DlThreadCluster_8x2) + .with_dl_transfer(cku::DlTransfer5D); using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp index e581da49510..782f33f8450 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp @@ -11,14 +11,13 @@ namespace ckt = ck_tile::builder::test; namespace cku = ck_tile::builder::test_utils; using enum ck_tile::builder::TensorLayout; -constexpr auto SIGNATURE = - ckt::ConvSignature{.spatial_dim = 2, - .direction = ckb::ConvDirection::BACKWARD_WEIGHT, - .data_type = ckb::DataType::FP16, - .accumulation_data_type = ckb::DataType::FP32, - .input = {.config = {.layout = NGCHW}}, - .weight = {.config = {.layout = GKYXC}}, - .output = {.config = {.layout = NGKHW}}}; +constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::FP16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = NGCHW}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = NGKHW}}}; constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3{} .with_thread_block(cku::ThreadBlock_64_32x32x32) @@ -27,7 +26,7 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStag .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) .with_block_gemm(cku::BlockGemmDesc_v1_intrawave) .with_num_conv_groups_to_merge(2) - .with_transpose_params(2,2); + .with_transpose_params(2, 2); using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; diff --git 
a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp index 9a8b9573fa6..a2a877dbcd4 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp @@ -39,6 +39,6 @@ TEST(BwdWeight_2DBf16_TwoStage_CShuffle, Create) "Default", "GNHWC,GKYXC,GNHWK", "PassThrough,PassThrough,PassThrough", - "Intrawave,v2", // pipeline versions + "Intrawave,v2", // pipeline versions "bf16,bf16,2,4>"}); // compute types and transpose params } diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp index 47e07d07220..0981ea6c11b 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp @@ -11,14 +11,13 @@ namespace ckt = ck_tile::builder::test; namespace cku = ck_tile::builder::test_utils; using enum ck_tile::builder::TensorLayout; -constexpr auto SIGNATURE = - ckt::ConvSignature{.spatial_dim = 3, - .direction = ckb::ConvDirection::BACKWARD_WEIGHT, - .data_type = ckb::DataType::BF16, - .accumulation_data_type = ckb::DataType::FP32, - .input = {.config = {.layout = NGCDHW}}, - .weight = {.config = {.layout = GKZYXC}}, - .output = {.config = {.layout = NGKDHW}}}; +constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 3, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::BF16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = NGCDHW}}, + .weight = {.config = {.layout = GKZYXC}}, + .output = {.config = {.layout = NGKDHW}}}; constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle{} 
.with_thread_block(cku::ThreadBlock_64_32x32x32) diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp index f58ec11129e..60f7d5bd643 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp @@ -11,22 +11,22 @@ namespace ckt = ck_tile::builder::test; namespace cku = ck_tile::builder::test_utils; using enum ck_tile::builder::TensorLayout; -constexpr auto SIGNATURE = - ckt::ConvSignature{.spatial_dim = 1, - .direction = ckb::ConvDirection::BACKWARD_WEIGHT, - .data_type = ckb::DataType::BF16, - .accumulation_data_type = ckb::DataType::FP32, - .input = {.config = {.layout = NGCW}}, - .weight = {.config = {.layout = GKXC}}, - .output = {.config = {.layout = NGKW}}}; +constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 1, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::BF16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = NGCW}}, + .weight = {.config = {.layout = GKXC}}, + .output = {.config = {.layout = NGKW}}}; -constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3{} - .with_thread_block(cku::ThreadBlock_64_32x32x32) - .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) - .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) - .with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0) - .with_block_gemm(cku::BlockGemmDesc_v1_intrawave) - .with_transpose_params(4,4); +constexpr auto ALGORITHM = + cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3{} + .with_thread_block(cku::ThreadBlock_64_32x32x32) + .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) + .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) + .with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0) + 
.with_block_gemm(cku::BlockGemmDesc_v1_intrawave) + .with_transpose_params(4, 4); using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp index fdc17fba2a8..4ad97209e5e 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp @@ -11,21 +11,21 @@ namespace ckt = ck_tile::builder::test; namespace cku = ck_tile::builder::test_utils; using enum ck_tile::builder::TensorLayout; -constexpr auto SIGNATURE = - ckt::ConvSignature{.spatial_dim = 1, - .direction = ckb::ConvDirection::BACKWARD_WEIGHT, - .data_type = ckb::DataType::BF16, - .accumulation_data_type = ckb::DataType::FP32, - .input = {.config = {.layout = NGCW}}, - .weight = {.config = {.layout = GKXC}}, - .output = {.config = {.layout = NGKW}}}; +constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 1, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::BF16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = NGCW}}, + .weight = {.config = {.layout = GKXC}}, + .output = {.config = {.layout = NGKW}}}; -constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3{} - .with_thread_block(cku::ThreadBlock_64_32x32x32) - .with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave) - .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) - .with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0) - .with_block_gemm(cku::BlockGemmDesc_v2_intrawave); +constexpr auto ALGORITHM = + cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3{} + .with_thread_block(cku::ThreadBlock_64_32x32x32) + .with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave) + .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) + 
.with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0) + .with_block_gemm(cku::BlockGemmDesc_v2_intrawave); using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp index 3eca32c8970..8d85370b268 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp @@ -34,7 +34,7 @@ TEST(FwdConvInstances, .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(Transfer_4x64x1) .with_fwd_specializations(ConvSpecialization::FILTER_1X1_STRIDE1_PAD0, - GemmSpecialization::MNKPadding) + GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v2_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp index a8df3c0d98d..e4675c89a7a 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp @@ -25,7 +25,7 @@ TEST(FwdConvInstances, .accumulation_data_type = INT32, .input = {.config = {.layout = GNWC}}, .weight = {.config = {.layout = GKXC}}, - .output = {.config = {.layout = GNWK}}}; + .output = {.config = {.layout = GNWK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle{} diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp index 0ce530c1f85..610e2fad5fe 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp @@ -67,7 +67,8 @@ TEST(FwdConvInstances, .with_thread_block(ThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(Transfer_4x64x1) - 
.with_fwd_specializations(ConvSpecialization::FILTER_3x3, GemmSpecialization::MNKPadding) + .with_fwd_specializations(ConvSpecialization::FILTER_3x3, + GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v5_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp index 79d4827feea..58171cd5302 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp @@ -62,7 +62,7 @@ TEST(FwdConvInstances, ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{} .with_thread_block(ThreadBlock_256_128x128x16) .with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0, - GemmSpecialization::MNKPadding) + GemmSpecialization::MNKPadding) .with_dl_thread_config(DlThreadConfig_16x2x4x4x1) .with_dl_thread_cluster(DlThreadCluster_8x2) .with_dl_transfer(DlTransfer4D); diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp index 5e114cb7614..8ae88a09174 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp @@ -25,7 +25,7 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xd .with_gemm_config(cku::FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(cku::Transfer_4x64x1) .with_fwd_specializations(ckb::ConvSpecialization::DEFAULT, - ckb::GemmSpecialization::MNKPadding) + ckb::GemmSpecialization::MNKPadding) .with_block_gemm(cku::BlockGemmDesc_v3_intrawave); using Builder = ckb::ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp index a8f97c417e3..bb35c53ba06 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp +++ 
b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp @@ -30,7 +30,7 @@ TEST(FwdConvInstances, .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(Transfer_4x64x1) .with_fwd_specializations(ConvSpecialization::FILTER_1X1_STRIDE1_PAD0, - GemmSpecialization::MNKPadding) + GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v4_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp index 303875b3489..97bc0a00e5d 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp @@ -26,13 +26,12 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{} - .with_thread_block(ThreadBlock_256_256x128x32) - .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) - .with_transfer(Transfer_4x16x1) - .with_fwd_specializations(ConvSpecialization::DEFAULT, - GemmSpecialization::MNKPadding) - .with_prefetch_config(1, PipelineScheduler::DEFAULT) - .with_num_conv_groups_to_merge(1); + .with_thread_block(ThreadBlock_256_256x128x32) + .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) + .with_transfer(Transfer_4x16x1) + .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_prefetch_config(1, PipelineScheduler::DEFAULT) + .with_num_conv_groups_to_merge(1); using Builder = ConvBuilder; @@ -63,13 +62,13 @@ TEST( constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{} - .with_thread_block(ThreadBlock_128_128x128x32) - .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) - .with_transfer(Transfer_4x16x1) - .with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0, - GemmSpecialization::MNKPadding) - .with_prefetch_config(1, 
PipelineScheduler::DEFAULT) - .with_num_conv_groups_to_merge(1); + .with_thread_block(ThreadBlock_128_128x128x32) + .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) + .with_transfer(Transfer_4x16x1) + .with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0, + GemmSpecialization::MNKPadding) + .with_prefetch_config(1, PipelineScheduler::DEFAULT) + .with_num_conv_groups_to_merge(1); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp index 46b82b2407b..56d4b8be590 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp @@ -31,7 +31,7 @@ TEST(FwdConvInstances, .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) .with_transfer(Transfer_4x64x1) .with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0, - GemmSpecialization::MNKPadding) + GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v4_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp index 6b69ca89a69..df8339241bc 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp @@ -31,7 +31,7 @@ TEST(FwdConvInstances, .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(Transfer_4x64x1) .with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0, - GemmSpecialization::MNKPadding) + GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v1_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp index 677a3043715..b79fdf513a7 100644 --- 
a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp @@ -10,13 +10,14 @@ using namespace ck_tile::builder::test_utils; TEST(BwdDataConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC) { - constexpr ConvSignature BwdDataConvSignature{.spatial_dim = 2, - .direction = ConvDirection::BACKWARD_DATA, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::NHWGK}}}; + constexpr ConvSignature BwdDataConvSignature{ + .spatial_dim = 2, + .direction = ConvDirection::BACKWARD_DATA, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NHWGC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::NHWGK}}}; constexpr auto BwdDataConvAlgorithm = ConvAlgorithm_Tile_GroupedConvolutionKernel{} diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp index f3de0bb762b..a5801b0e853 100644 --- a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp @@ -10,13 +10,14 @@ using namespace ck_tile::builder::test_utils; TEST(BwdWeightConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC) { - constexpr ConvSignature BwdWeightConvSignature{.spatial_dim = 2, - .direction = ConvDirection::BACKWARD_WEIGHT, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout 
= TensorLayout::NHWGK}}}; + constexpr ConvSignature BwdWeightConvSignature{ + .spatial_dim = 2, + .direction = ConvDirection::BACKWARD_WEIGHT, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NHWGC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::NHWGK}}}; constexpr auto BwdWeightConvAlgorithm = ConvAlgorithm_Tile_GroupedConvolutionKernel{} diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 57ef018ae32..797f440bdce 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -41,15 +41,15 @@ static_assert(ckb::GridwiseXdlGemmDescriptor); struct GridwiseFwdXdlGemm { // NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!! - size_t ak1 = 0; - size_t bk1 = 0; + size_t ak1 = 0; + size_t bk1 = 0; XdlParams xdl_params; }; static_assert(ckb::GridwiseFwdXdlGemmDescriptor); struct GridwiseBwdXdlGemm { - size_t k1 = 0; + size_t k1 = 0; XdlParams xdl_params; }; static_assert(ckb::GridwiseBwdXdlGemmDescriptor); @@ -284,17 +284,20 @@ struct DlTransfer_ struct TwoStageSpecialization_ { - static constexpr ConvAlgorithmSpecialization specialization = ConvAlgorithmSpecialization::TWO_STAGE; + static constexpr ConvAlgorithmSpecialization specialization = + ConvAlgorithmSpecialization::TWO_STAGE; }; struct MultipleDSpecialization_ { - static constexpr ConvAlgorithmSpecialization specialization = ConvAlgorithmSpecialization::MULTIPLE_D; + static constexpr ConvAlgorithmSpecialization specialization = + ConvAlgorithmSpecialization::MULTIPLE_D; }; struct LargeTensorSpecialization_ { - static constexpr ConvAlgorithmSpecialization specialization = ConvAlgorithmSpecialization::LARGE_TENSOR; + static constexpr ConvAlgorithmSpecialization specialization = + 
ConvAlgorithmSpecialization::LARGE_TENSOR; }; // Specify thread block dimensions for a GEMM (CK Tile). @@ -395,7 +398,8 @@ struct ConvAlgorithmTemplate : Components... { result.gridwise_gemm = gemm; } - else { + else + { static_assert(false, "Unrecognized GemmConfig type"); } return result; @@ -412,7 +416,7 @@ struct ConvAlgorithmTemplate : Components... } constexpr auto with_fwd_specializations(ConvSpecialization fwd_spec, - GemmSpecialization gemm_spec) const + GemmSpecialization gemm_spec) const { static_assert(std::is_base_of_v); auto result = *this; @@ -424,7 +428,7 @@ struct ConvAlgorithmTemplate : Components... constexpr auto with_bwd_specialization(ConvSpecialization bwd_spec) const { static_assert(std::is_base_of_v); - auto result = *this; + auto result = *this; result.bwd_weight_specialization = bwd_spec; return result; } @@ -442,7 +446,7 @@ struct ConvAlgorithmTemplate : Components... size_t max_dst_scalar_per_vector) const { static_assert(std::is_base_of_v); - auto result = *this; + auto result = *this; result.max_transpose_transfer_src_scalar_per_vector = max_src_scalar_per_vector; result.max_transpose_transfer_dst_scalar_per_vector = max_dst_scalar_per_vector; return result; @@ -451,7 +455,7 @@ struct ConvAlgorithmTemplate : Components... constexpr auto with_num_conv_groups_to_merge(size_t num_groups_to_merge) const { static_assert(std::is_base_of_v); - auto result = *this; + auto result = *this; result.num_conv_groups_to_merge = num_groups_to_merge; return result; } @@ -460,7 +464,7 @@ struct ConvAlgorithmTemplate : Components... constexpr auto with_block_gemm(const BG& bg) const { static_assert(std::is_base_of_v); - auto result = *this; + auto result = *this; result.block_gemm_pipeline = bg; return result; } @@ -468,7 +472,7 @@ struct ConvAlgorithmTemplate : Components... 
constexpr auto with_gridwise_gemm_pipeline(const PipelineVersion plv) const { static_assert(std::is_base_of_v); - auto result = *this; + auto result = *this; result.pipeline_version = plv; return result; } @@ -550,13 +554,28 @@ struct ConvAlgorithmTemplate : Components... // Fwd algorithm types using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = - ConvAlgorithmTemplate, ConvSpecializationFwd_, Prefetch_, GemmBatchOptions_>; + ConvAlgorithmTemplate, + ConvSpecializationFwd_, + Prefetch_, + GemmBatchOptions_>; using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = - ConvAlgorithmTemplate, ConvSpecializationFwd_, BlockGemm_>; + ConvAlgorithmTemplate, + ConvSpecializationFwd_, + BlockGemm_>; using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle = - ConvAlgorithmTemplate, ConvSpecializationFwd_, GridGemm_, Prefetch_, GemmBatchOptions_>; + ConvAlgorithmTemplate, + ConvSpecializationFwd_, + GridGemm_, + Prefetch_, + GemmBatchOptions_>; using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = ConvAlgorithmTemplate>; using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = - ConvAlgorithmTemplate, ConvSpecializationFwd_, Prefetch_, GemmBatchOptions_, LargeTensorSpecialization_>; + ConvAlgorithmTemplate, + ConvSpecializationFwd_, + Prefetch_, + GemmBatchOptions_, + LargeTensorSpecialization_>; // CK Tile algorithm using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, TransposeParams_>; - -using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle = - ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, BlockGemm_, TransposeParams_, GemmBatchOptions_, TwoStageSpecialization_>; +using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle = + ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_, + TransposeParams_>; -using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 = - ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, 
BlockGemm_>; +using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle = + ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_, + BlockGemm_, + TransposeParams_, + GemmBatchOptions_, + TwoStageSpecialization_>; + +using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 = + ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_, + BlockGemm_>; using ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl = - ConvAlgorithmTemplate, ConvSpecializationBwdWeight_>; - -using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle = - ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, MultipleDSpecialization_>; - -using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 = - ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, BlockGemm_, TransposeParams_>; + ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_>; -using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 = - ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, BlockGemm_, TransposeParams_, GemmBatchOptions_, TwoStageSpecialization_>; +using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle = + ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_, + MultipleDSpecialization_>; -using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle = - ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, GridGemm_, Prefetch_>; +using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 = + ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_, + BlockGemm_, + TransposeParams_>; -using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3 = - ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, BlockGemm_, MultipleDSpecialization_>; +using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 = + ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_, + BlockGemm_, + TransposeParams_, + GemmBatchOptions_, + TwoStageSpecialization_>; + +using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle = + ConvAlgorithmTemplate, + 
ConvSpecializationBwdWeight_, + GridGemm_, + Prefetch_>; +using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3 = + ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_, + BlockGemm_, + MultipleDSpecialization_>; } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/test_concept_diagnostics_sync.cpp b/experimental/builder/test/test_concept_diagnostics_sync.cpp index 0bc786dbdf5..ca08ae92ed7 100644 --- a/experimental/builder/test/test_concept_diagnostics_sync.cpp +++ b/experimental/builder/test/test_concept_diagnostics_sync.cpp @@ -22,40 +22,40 @@ namespace ck_tile::builder::test { -using ck_tile::builder::ThreadBlockDescriptor; -using ck_tile::builder::GridwiseXdlGemmDescriptor; -using ck_tile::builder::BlockTransferDescriptor; -using ck_tile::builder::ThreadClusterDescriptor; -using ck_tile::builder::LdsTransferDescriptor; -using ck_tile::builder::EpilogueDescriptor; using ck_tile::builder::AccessOrderDescriptor; using ck_tile::builder::BlockGemmDescriptor; -using ck_tile::builder::GridwiseWmmaGemmDescriptor; -using ck_tile::builder::TileThreadBlockDescriptor; -using ck_tile::builder::TileTransferDescriptor; -using ck_tile::builder::TileBlockGemmDescriptor; -using ck_tile::builder::TileOptimizationsDescriptor; -using ck_tile::builder::DlThreadConfigDescriptor; -using ck_tile::builder::DlThreadClusterDescriptor; +using ck_tile::builder::BlockTransferDescriptor; +using ck_tile::builder::ConvAlgorithmDescriptor; using ck_tile::builder::DlBlockTransferDescriptor; using ck_tile::builder::DlEpilogueDescriptor; -using ck_tile::builder::ConvAlgorithmDescriptor; -using ck_tile::builder::SpecifiesThreadBlock; -using ck_tile::builder::SpecifiesGridwiseFwdXdlGemm; -using ck_tile::builder::SpecifiesGridwiseBwdXdlGemm; +using ck_tile::builder::DlThreadClusterDescriptor; +using ck_tile::builder::DlThreadConfigDescriptor; +using ck_tile::builder::EpilogueDescriptor; +using ck_tile::builder::GridwiseWmmaGemmDescriptor; +using 
ck_tile::builder::GridwiseXdlGemmDescriptor; +using ck_tile::builder::LdsTransferDescriptor; using ck_tile::builder::SpecifiesBlockGemm; -using ck_tile::builder::SpecifiesFwdConvSpecialization; using ck_tile::builder::SpecifiesBwdWeightConvSpecialization; +using ck_tile::builder::SpecifiesDlThreadCluster; +using ck_tile::builder::SpecifiesDlThreadConfig; +using ck_tile::builder::SpecifiesFwdConvSpecialization; using ck_tile::builder::SpecifiesGemmSpecialization; -using ck_tile::builder::SpecifiesNumPrefetchStages; +using ck_tile::builder::SpecifiesGridwiseBwdXdlGemm; +using ck_tile::builder::SpecifiesGridwiseFwdXdlGemm; using ck_tile::builder::SpecifiesLoopScheduler; -using ck_tile::builder::SpecifiesTileThreadBlock; -using ck_tile::builder::SpecifiesTileTransfer; +using ck_tile::builder::SpecifiesNumPrefetchStages; +using ck_tile::builder::SpecifiesThreadBlock; using ck_tile::builder::SpecifiesTileBlockGemm; -using ck_tile::builder::SpecifiesTileOptimizations; using ck_tile::builder::SpecifiesTileConvSpecialization; -using ck_tile::builder::SpecifiesDlThreadConfig; -using ck_tile::builder::SpecifiesDlThreadCluster; +using ck_tile::builder::SpecifiesTileOptimizations; +using ck_tile::builder::SpecifiesTileThreadBlock; +using ck_tile::builder::SpecifiesTileTransfer; +using ck_tile::builder::ThreadBlockDescriptor; +using ck_tile::builder::ThreadClusterDescriptor; +using ck_tile::builder::TileBlockGemmDescriptor; +using ck_tile::builder::TileOptimizationsDescriptor; +using ck_tile::builder::TileThreadBlockDescriptor; +using ck_tile::builder::TileTransferDescriptor; // Helper to check if a string contains a substring bool contains(const std::string& str, const std::string& substr) @@ -331,18 +331,24 @@ TEST(ConceptDiagnosticsSync, LdsTransferDescriptor_Invalid) TEST(ConceptDiagnosticsSync, CompleteAlgorithmTypes) { // Test that complete algorithm types satisfy their concepts - static_assert(ConvAlgorithmDescriptor); - static_assert(ConvAlgorithmDescriptor); - 
static_assert(ConvAlgorithmDescriptor); + static_assert( + ConvAlgorithmDescriptor); + static_assert( + ConvAlgorithmDescriptor); + static_assert( + ConvAlgorithmDescriptor); static_assert(ConvAlgorithmDescriptor); static_assert(ConvAlgorithmDescriptor); - + // Test specific requirements for each algorithm type static_assert(SpecifiesThreadBlock); - static_assert(SpecifiesGridwiseFwdXdlGemm); - static_assert(SpecifiesFwdConvSpecialization); - static_assert(SpecifiesNumPrefetchStages); - + static_assert( + SpecifiesGridwiseFwdXdlGemm); + static_assert( + SpecifiesFwdConvSpecialization); + static_assert( + SpecifiesNumPrefetchStages); + static_assert(SpecifiesTileThreadBlock); static_assert(SpecifiesTileBlockGemm); static_assert(SpecifiesTileOptimizations); @@ -356,9 +362,12 @@ TEST(ConceptDiagnosticsSync, DiagnosticMessages) { // Test that diagnostics can be called (even if messages may be empty at compile-time) // The key is that the diagnostic functions exist and compile - std::string diag1 = ck_tile::builder::diagnostics::detailed_diagnostic_SpecifiesThreadBlock(); - std::string diag2 = ck_tile::builder::diagnostics::detailed_diagnostic_SpecifiesGridwiseFwdXdlGemm(); - + std::string diag1 = ck_tile::builder::diagnostics::detailed_diagnostic_SpecifiesThreadBlock< + invalid_types::MissingBlockSize>(); + std::string diag2 = + ck_tile::builder::diagnostics::detailed_diagnostic_SpecifiesGridwiseFwdXdlGemm< + invalid_types::MissingMPerXdl>(); + // These may be empty depending on the implementation, but they should compile EXPECT_TRUE(diag1.empty() || contains(diag1, "thread_block") || contains(diag1, "missing")); EXPECT_TRUE(diag2.empty() || contains(diag2, "gridwise_gemm") || contains(diag2, "missing")); @@ -370,7 +379,7 @@ TEST(ConceptDiagnosticsSync, DiagnosticMessages) /** * @brief Verify that all concepts defined in conv_algorithm_concepts.hpp have tests - * + * * This test serves as documentation of which concepts are tested. 
If new concepts * are added, this test should be updated to include them. */ @@ -386,13 +395,13 @@ TEST(ConceptDiagnosticsSync, ConceptCoverage) EXPECT_TRUE((LdsTransferDescriptor)); EXPECT_TRUE((EpilogueDescriptor)); EXPECT_TRUE((AccessOrderDescriptor)); - - // Tile Descriptor Concepts + + // Tile Descriptor Concepts EXPECT_TRUE((TileThreadBlockDescriptor)); EXPECT_TRUE((TileTransferDescriptor)); EXPECT_TRUE((TileBlockGemmDescriptor)); EXPECT_TRUE((TileOptimizationsDescriptor)); - + // DL Descriptor Concepts EXPECT_TRUE((DlThreadConfigDescriptor)); EXPECT_TRUE((DlThreadClusterDescriptor)); diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index 8808f9bee7c..498de9a42fd 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -120,14 +120,10 @@ struct DefaultAlgorithm ckb::test::ThreadBlock thread_block{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; - ckb::test::GridwiseFwdXdlGemm gridwise_gemm{.ak1 = 8, - .bk1 = 8, - .xdl_params = - { - .m_per_xdl = 16, - .n_per_xdl = 16, - .m_xdl_per_wave = 8, - .n_xdl_per_wave = 8}}; + ckb::test::GridwiseFwdXdlGemm gridwise_gemm{ + .ak1 = 8, + .bk1 = 8, + .xdl_params = {.m_per_xdl = 16, .n_per_xdl = 16, .m_xdl_per_wave = 8, .n_xdl_per_wave = 8}}; ckb::test::Transfer<> transfer{ .a = @@ -163,8 +159,8 @@ struct DefaultAlgorithm }, }; - ckb::ConvSpecialization fwd_specialization = ckb::ConvSpecialization::DEFAULT; - ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default; + ckb::ConvSpecialization fwd_specialization = ckb::ConvSpecialization::DEFAULT; + ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default; ckb::test::BlockGemm block_gemm{.pipeline_version = ckb::PipelineVersion::V4, .scheduler = ckb::PipelineScheduler::INTRAWAVE}; }; diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp 
b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index e1f5a34e20c..3b83ead2d0d 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -20,37 +20,35 @@ constexpr DlThreadConfig DlThreadConfig_16x1x4x4x1{ constexpr DlThreadCluster DlThreadCluster_8x2{.m1_xs = {8, 2}, .n1_xs = {8, 2}}; -constexpr DlBlockTransfer<4> DlBlockTransfer_8x1x1x2 - {.thread_slice_lengths = {8, 1, 1, 2}, - .thread_cluster_lengths = {2, 1, 128, 1}, - .thread_cluster_arrange_order = {1, 2, 0, 3}, - .src_access_order = {1, 2, 0, 3}, - .src_vector_tensor_lengths = {4, 1, 1, 2}, - .src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3}, - .dst_vector_tensor_lengths = {1, 1, 1, 2}}; - -constexpr DlTransfer<4> DlTransfer4D {.a = DlBlockTransfer_8x1x1x2, - .b = DlBlockTransfer_8x1x1x2, - .c = { - .src_dst_access_order = {0, 1, 2, 3, 4, 5}, - .src_dst_vector_dim = 5, - .dst_scalar_per_vector = 4}}; - -constexpr DlBlockTransfer<5> DlBlockTransfer_1x8x1x1x1 - {.thread_slice_lengths = {1, 8, 1, 1, 1}, - .thread_cluster_lengths = {1, 2, 1, 128, 1}, - .thread_cluster_arrange_order = {0, 2, 3, 1, 4}, - .src_access_order = {0, 2, 3, 1, 4}, - .src_vector_tensor_lengths = {1, 1, 1, 1, 1}, - .src_vector_tensor_contiguous_dim_order = {0, 2, 3, 1, 4}, - .dst_vector_tensor_lengths = {1, 1, 1, 1, 1}}; - -constexpr DlTransfer<5> DlTransfer5D {.a = DlBlockTransfer_1x8x1x1x1, - .b = DlBlockTransfer_1x8x1x1x1, - .c = { - .src_dst_access_order = {0, 1, 2, 3, 4, 5}, - .src_dst_vector_dim = 5, - .dst_scalar_per_vector = 1}}; +constexpr DlBlockTransfer<4> DlBlockTransfer_8x1x1x2{ + .thread_slice_lengths = {8, 1, 1, 2}, + .thread_cluster_lengths = {2, 1, 128, 1}, + .thread_cluster_arrange_order = {1, 2, 0, 3}, + .src_access_order = {1, 2, 0, 3}, + .src_vector_tensor_lengths = {4, 1, 1, 2}, + .src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3}, + .dst_vector_tensor_lengths = {1, 1, 1, 2}}; + +constexpr DlTransfer<4> 
DlTransfer4D{.a = DlBlockTransfer_8x1x1x2, + .b = DlBlockTransfer_8x1x1x2, + .c = {.src_dst_access_order = {0, 1, 2, 3, 4, 5}, + .src_dst_vector_dim = 5, + .dst_scalar_per_vector = 4}}; + +constexpr DlBlockTransfer<5> DlBlockTransfer_1x8x1x1x1{ + .thread_slice_lengths = {1, 8, 1, 1, 1}, + .thread_cluster_lengths = {1, 2, 1, 128, 1}, + .thread_cluster_arrange_order = {0, 2, 3, 1, 4}, + .src_access_order = {0, 2, 3, 1, 4}, + .src_vector_tensor_lengths = {1, 1, 1, 1, 1}, + .src_vector_tensor_contiguous_dim_order = {0, 2, 3, 1, 4}, + .dst_vector_tensor_lengths = {1, 1, 1, 1, 1}}; + +constexpr DlTransfer<5> DlTransfer5D{.a = DlBlockTransfer_1x8x1x1x1, + .b = DlBlockTransfer_1x8x1x1x1, + .c = {.src_dst_access_order = {0, 1, 2, 3, 4, 5}, + .src_dst_vector_dim = 5, + .dst_scalar_per_vector = 1}}; constexpr Transfer<> Transfer_4x64x1{ .a = @@ -252,40 +250,38 @@ constexpr Transfer<> Transfer_4x32x1{ }; constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_4x4_per_wave{ - .k1 = 8, + .k1 = 8, .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}}; constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_1x1_per_wave{ - .k1 = 8, + .k1 = 8, .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 1, .n_xdl_per_wave = 1}}; constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_4x4_per_wave{ - .ak1 = 8, .bk1 = 8, + .ak1 = 8, + .bk1 = 8, .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}}; constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_4x2_per_wave{ - .ak1 = 8, .bk1 = 8, + .ak1 = 8, + .bk1 = 8, .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2}}; constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_2x2_per_wave{ - .ak1 = 8, .bk1 = 8, + .ak1 = 8, + .bk1 = 8, .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}}; constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_2x1_per_wave{ - .ak1 = 8, .bk1 = 8, + .ak1 = 8, + .bk1 = 8, .xdl_params = {.m_per_xdl = 32, 
.n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1}}; -constexpr GridwiseWmmaGemm GemmParams_Wmma_2x1_per_wave{.k1 = 8, - .m_per_wmma = 32, - .n_per_wmma = 32, - .m_wmma_per_wave = 2, - .n_wmma_per_wave = 1}; +constexpr GridwiseWmmaGemm GemmParams_Wmma_2x1_per_wave{ + .k1 = 8, .m_per_wmma = 32, .n_per_wmma = 32, .m_wmma_per_wave = 2, .n_wmma_per_wave = 1}; -constexpr GridwiseWmmaGemm GemmParams_Wmma_16x16_2x1_per_wave{.k1 = 8, - .m_per_wmma = 16, - .n_per_wmma = 16, - .m_wmma_per_wave = 2, - .n_wmma_per_wave = 1}; +constexpr GridwiseWmmaGemm GemmParams_Wmma_16x16_2x1_per_wave{ + .k1 = 8, .m_per_wmma = 16, .n_per_wmma = 16, .m_wmma_per_wave = 2, .n_wmma_per_wave = 1}; constexpr ThreadBlock ThreadBlock_256_256x256x32{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; @@ -300,13 +296,13 @@ constexpr ThreadBlock ThreadBlock_256_128x128x16{.block_size = 256, .tile_size = {.m = 128, .n = 128, .k = 16}}; constexpr ThreadBlock ThreadBlock_256_128x128x8{.block_size = 256, - .tile_size = {.m = 128, .n = 128, .k = 8}}; + .tile_size = {.m = 128, .n = 128, .k = 8}}; constexpr ThreadBlock ThreadBlock_64_64x32x32{.block_size = 64, .tile_size = {.m = 64, .n = 32, .k = 32}}; constexpr ThreadBlock ThreadBlock_64_32x32x32{.block_size = 64, - .tile_size = {.m = 32, .n = 32, .k = 32}}; + .tile_size = {.m = 32, .n = 32, .k = 32}}; constexpr ThreadBlock ThreadBlock_128_128x128x32{.block_size = 128, .tile_size = {.m = 128, .n = 128, .k = 32}}; @@ -314,19 +310,19 @@ constexpr ThreadBlock ThreadBlock_128_128x128x32{.block_size = 128, constexpr ThreadBlock ThreadBlock_128_64x64x64{.block_size = 128, .tile_size = {.m = 64, .n = 64, .k = 64}}; -constexpr BlockGemmPipeline BlockGemmDesc_v1_intrawave = {.pipeline_version = PipelineVersion::V1, - .scheduler = PipelineScheduler::INTRAWAVE}; +constexpr BlockGemmPipeline BlockGemmDesc_v1_intrawave = { + .pipeline_version = PipelineVersion::V1, .scheduler = PipelineScheduler::INTRAWAVE}; -constexpr BlockGemmPipeline 
BlockGemmDesc_v2_intrawave = {.pipeline_version = PipelineVersion::V2, - .scheduler = PipelineScheduler::INTRAWAVE}; +constexpr BlockGemmPipeline BlockGemmDesc_v2_intrawave = { + .pipeline_version = PipelineVersion::V2, .scheduler = PipelineScheduler::INTRAWAVE}; -constexpr BlockGemmPipeline BlockGemmDesc_v3_intrawave = {.pipeline_version = PipelineVersion::V3, - .scheduler = PipelineScheduler::INTRAWAVE}; +constexpr BlockGemmPipeline BlockGemmDesc_v3_intrawave = { + .pipeline_version = PipelineVersion::V3, .scheduler = PipelineScheduler::INTRAWAVE}; -constexpr BlockGemmPipeline BlockGemmDesc_v4_intrawave = {.pipeline_version = PipelineVersion::V4, - .scheduler = PipelineScheduler::INTRAWAVE}; +constexpr BlockGemmPipeline BlockGemmDesc_v4_intrawave = { + .pipeline_version = PipelineVersion::V4, .scheduler = PipelineScheduler::INTRAWAVE}; -constexpr BlockGemmPipeline BlockGemmDesc_v5_intrawave = {.pipeline_version = PipelineVersion::V5, - .scheduler = PipelineScheduler::INTRAWAVE}; +constexpr BlockGemmPipeline BlockGemmDesc_v5_intrawave = { + .pipeline_version = PipelineVersion::V5, .scheduler = PipelineScheduler::INTRAWAVE}; } // namespace ck_tile::builder::test_utils diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index 6ad61797800..b705ab1e471 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -98,8 +98,8 @@ template <> inline std::string to_string(GridwiseFwdXdlGemm t) { std::ostringstream oss; - oss << t.ak1 << "," << t.bk1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl << "," - << t.xdl_params.m_xdl_per_wave << "," << t.xdl_params.n_xdl_per_wave; + oss << t.ak1 << "," << t.bk1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl + << "," << t.xdl_params.m_xdl_per_wave << "," << t.xdl_params.n_xdl_per_wave; return oss.str(); } @@ -123,15 +123,15 
@@ inline std::string to_string(BlockGemmPipeline t) template inline std::string to_string(BlockTransfer t) { - if constexpr (ThreadSliceDim == 4) + if constexpr(ThreadSliceDim == 4) { return array_to_seq(std::array{t.k_batch_size, t.k0, t.m_n, t.k1}); } - else if constexpr (ThreadSliceDim == 3) + else if constexpr(ThreadSliceDim == 3) { return array_to_seq(std::array{t.k0, t.m_n, t.k1}); } - else + else { static_assert(ThreadSliceDim == 3 || ThreadSliceDim == 4, "Unsupported ThreadSliceDim"); } @@ -160,7 +160,7 @@ inline std::string to_string(AccessOrder t) return array_to_seq(t.order); } -template +template inline std::string to_string(InputTransfer t) { std::ostringstream oss; @@ -314,8 +314,7 @@ template <> inline std::string to_string(Prefetch_ t) { std::ostringstream oss; - oss << t.num_gemm_k_prefetch_stages << "," - << to_string(t.loop_scheduler); + oss << t.num_gemm_k_prefetch_stages << "," << to_string(t.loop_scheduler); return oss.str(); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index dbce8e8ccf6..80e4ea983e1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -861,30 +861,32 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 { if(gemm_arg.KBatch > 1) { - const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy>; + const auto kernel = + kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< + GridwiseGemm, + 
remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; Run(kernel); } else { - const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy>; + const auto kernel = + kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; Run(kernel); } } @@ -900,30 +902,32 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 { if(gemm_arg.KBatch > 1) { - const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - false, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy>; + const auto kernel = + kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; Run(kernel); } else { - const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - false, - InMemoryDataOperationEnum::Set, - 
minimum_occupancy>; + const auto kernel = + kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; Run(kernel); } } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 0910da35a6c..d3bf2a364a0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -52,19 +52,20 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batched_gemm_xdlops_bwd_weight_multiple_d(const FloatA* __restrict__ p_a_grid, - const FloatB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const index_t batch_count, - const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, - const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2CTileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + kernel_batched_gemm_xdlops_bwd_weight_multiple_d( + const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const index_t batch_count, + const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1 
b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) From 201039646e1d514de5a6dd9345372952ebf0a352 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Mon, 5 Jan 2026 05:32:09 -0500 Subject: [PATCH 50/81] Move compile-time diagnostics to a separate branch. --- .../builder/conv_algorithm_diagnostics.hpp | 1769 ----------------- .../builder/factory/conv_algorithms.hpp | 981 +-------- .../builder/factory/conv_dispatcher.hpp | 45 +- 3 files changed, 102 insertions(+), 2693 deletions(-) delete mode 100644 experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp deleted file mode 100644 index 6db35a9ba0f..00000000000 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp +++ /dev/null @@ -1,1769 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include "ck_tile/builder/conv_algorithm_concepts.hpp" - -namespace ck_tile::builder::diagnostics { - -#define CHECK_MARK(cond) (cond ? "[✓]" : "[✗]") - -// Macro to check a concept and generate both the boolean and the string representation -#define CHECK_CONCEPT(Type, Concept) \ - static constexpr bool c_##Concept = Concept; \ - static constexpr const char* s_##Concept = #Concept; - -// Helper to create diagnostic message line -#define DIAGNOSTIC_LINE(Concept) \ - " " + std::string(s_##Concept) + ": " + std::string(CHECK_MARK(c_##Concept)) + "\n" + \ - (c_##Concept ? 
std::string("") : detailed_diagnostic_##Concept()) - -namespace detail { - -// Helper to get type information -template -consteval auto get_type_info() -> const char* -{ - // Returns a descriptive string about the type - if constexpr(std::is_same_v) - { - return " (type: size_t)"; - } - else if constexpr(std::is_same_v) - { - return " (type: int)"; - } - else if constexpr(std::is_same_v) - { - return " (type: bool)"; - } - else if constexpr(std::is_same_v) - { - return " (type: PipelineVersion)"; - } - else if constexpr(std::is_same_v) - { - return " (type: PipelineScheduler)"; - } - else if constexpr(std::is_same_v) - { - return " (type: ConvSpecialization)"; - } - else if constexpr(std::is_same_v) - { - return " (type: GemmSpecialization)"; - } - else if constexpr(std::is_same_v) - { - return " (type: TileConvSpecialization)"; - } - else if constexpr(std::is_same_v) - { - return " (type: ConvAlgorithmSpecialization)"; - } - else if constexpr(std::is_same_v>) - { - return " (type: std::array)"; - } - else if constexpr(std::is_same_v>) - { - return " (type: std::array)"; - } - else if constexpr(std::is_same_v>) - { - return " (type: std::array)"; - } - else if constexpr(std::is_same_v>) - { - return " (type: std::array)"; - } - else - { - return " (type: found but unknown)"; - } -} - -// ThreadBlockDescriptor diagnostics -template -consteval auto diagnose_thread_block_descriptor() -> std::string -{ - if constexpr(!requires { T::thread_block; }) - { - return " → T::thread_block member: [✗] (missing member)\n"; - } - else - { - using TB = decltype(T::thread_block); - std::string msg; - - if constexpr(requires(TB t) { t.block_size; }) - { - using BlockSizeType = decltype(std::declval().block_size); - constexpr bool convertible = SizeType; - msg += " → thread_block.block_size: " + std::string(CHECK_MARK(convertible)) + - (convertible ? 
"" : std::string(get_type_info())) + "\n"; - } - else - { - msg += " → thread_block.block_size: [✗] (missing member)\n"; - } - - if constexpr(requires(TB t) { t.tile_size.m; }) - { - using TileMType = decltype(std::declval().tile_size.m); - constexpr bool convertible = SizeType; - msg += " → thread_block.tile_size.m: " + std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(get_type_info())) + "\n"; - } - else - { - msg += " → thread_block.tile_size.m: [✗] (missing member)\n"; - } - - if constexpr(requires(TB t) { t.tile_size.n; }) - { - using TileNType = decltype(std::declval().tile_size.n); - constexpr bool convertible = SizeType; - msg += " → thread_block.tile_size.n: " + std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(get_type_info())) + "\n"; - } - else - { - msg += " → thread_block.tile_size.n: [✗] (missing member)\n"; - } - - if constexpr(requires(TB t) { t.tile_size.k; }) - { - using TileKType = decltype(std::declval().tile_size.k); - constexpr bool convertible = SizeType; - msg += " → thread_block.tile_size.k: " + std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(get_type_info())) + "\n"; - } - else - { - msg += " → thread_block.tile_size.k: [✗] (missing member)\n"; - } - - return msg; - } -} - -// GridwiseXdlGemmDescriptor diagnostics -template -consteval auto diagnose_xdl_params() -> std::string -{ - std::string msg; - - if constexpr(requires(XdlParams t) { t.m_per_xdl; }) - { - using MPerXdlType = decltype(std::declval().m_per_xdl); - constexpr bool convertible = SizeType; - msg += " → xdl_params.m_per_xdl: " + std::string(CHECK_MARK(convertible)) + - (convertible ? 
"" : std::string(get_type_info())) + "\n"; - } - else - { - msg += " → xdl_params.m_per_xdl: [✗] (missing member)\n"; - } - - if constexpr(requires(XdlParams t) { t.n_per_xdl; }) - { - using NPerXdlType = decltype(std::declval().n_per_xdl); - constexpr bool convertible = SizeType; - msg += " → xdl_params.n_per_xdl: " + std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(get_type_info())) + "\n"; - } - else - { - msg += " → xdl_params.n_per_xdl: [✗] (missing member)\n"; - } - - if constexpr(requires(XdlParams t) { t.m_xdl_per_wave; }) - { - using MXdlPerWaveType = decltype(std::declval().m_xdl_per_wave); - constexpr bool convertible = SizeType; - msg += " → xdl_params.m_xdl_per_wave: " + std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(get_type_info())) + "\n"; - } - else - { - msg += " → xdl_params.m_xdl_per_wave: [✗] (missing member)\n"; - } - - if constexpr(requires(XdlParams t) { t.n_xdl_per_wave; }) - { - using NXdlPerWaveType = decltype(std::declval().n_xdl_per_wave); - constexpr bool convertible = SizeType; - msg += " → xdl_params.n_xdl_per_wave: " + std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(get_type_info())) + "\n"; - } - else - { - msg += " → xdl_params.n_xdl_per_wave: [✗] (missing member)\n"; - } - - return msg; -} - -// BlockTransferDescriptor diagnostics -template -consteval auto diagnose_block_transfer(const char* prefix) -> std::string -{ - std::string msg; - - if constexpr(requires(BT t) { t.k0; }) - { - using K0Type = decltype(std::declval().k0); - constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + - ".k0: " + std::string(CHECK_MARK(convertible)) + - (convertible ? 
"" : std::string(get_type_info())) + "\n"; - } - else - { - msg += std::string(" → ") + prefix + ".k0: [✗] (missing member)\n"; - } - - if constexpr(requires(BT t) { t.m_n; }) - { - using MNType = decltype(std::declval().m_n); - constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + - ".m_n: " + std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(get_type_info())) + "\n"; - } - else - { - msg += std::string(" → ") + prefix + ".m_n: [✗] (missing member)\n"; - } - - if constexpr(requires(BT t) { t.k1; }) - { - using K1Type = decltype(std::declval().k1); - constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + - ".k1: " + std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(get_type_info())) + "\n"; - } - else - { - msg += std::string(" → ") + prefix + ".k1: [✗] (missing member)\n"; - } - - return msg; -} - -// BlockTransferDescriptor4D diagnostics (requires k_batch_size) -template -consteval auto diagnose_block_transfer_4d(const char* prefix) -> std::string -{ - std::string msg; - - if constexpr(requires(BT t) { t.k0; }) - { - using K0Type = decltype(std::declval().k0); - constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + - ".k0: " + std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(get_type_info())) + "\n"; - } - else - { - msg += std::string(" → ") + prefix + ".k0: [✗] (missing member)\n"; - } - - if constexpr(requires(BT t) { t.m_n; }) - { - using MNType = decltype(std::declval().m_n); - constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + - ".m_n: " + std::string(CHECK_MARK(convertible)) + - (convertible ? 
"" : std::string(get_type_info())) + "\n"; - } - else - { - msg += std::string(" → ") + prefix + ".m_n: [✗] (missing member)\n"; - } - - if constexpr(requires(BT t) { t.k1; }) - { - using K1Type = decltype(std::declval().k1); - constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + - ".k1: " + std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(get_type_info())) + "\n"; - } - else - { - msg += std::string(" → ") + prefix + ".k1: [✗] (missing member)\n"; - } - - // k_batch_size is required for Bwd descriptor - if constexpr(requires(BT t) { t.k_batch_size; }) - { - using KBatchType = decltype(std::declval().k_batch_size); - constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + - ".k_batch_size: " + std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(get_type_info())) + "\n"; - } - else - { - msg += std::string(" → ") + prefix + ".k_batch_size: [✗] (missing member)\n"; - } - - return msg; -} - -// LdsTransferDescriptor diagnostics -template -consteval auto diagnose_lds_transfer(const char* prefix) -> std::string -{ - std::string msg; - - if constexpr(requires(LT t) { t.src_vector_dim; }) - { - using SrcVectorDimType = decltype(std::declval().src_vector_dim); - constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + - ".src_vector_dim: " + std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(get_type_info())) + "\n"; - } - else - { - msg += std::string(" → ") + prefix + ".src_vector_dim: [✗] (missing member)\n"; - } - - if constexpr(requires(LT t) { t.src_scalar_per_vector; }) - { - using SrcScalarType = decltype(std::declval().src_scalar_per_vector); - constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + - ".src_scalar_per_vector: " + std::string(CHECK_MARK(convertible)) + - (convertible ? 
"" : std::string(get_type_info())) + "\n"; - } - else - { - msg += - std::string(" → ") + prefix + ".src_scalar_per_vector: [✗] (missing member)\n"; - } - - if constexpr(requires(LT t) { t.lds_dst_scalar_per_vector; }) - { - using LdsDstScalarType = decltype(std::declval().lds_dst_scalar_per_vector); - constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + - ".lds_dst_scalar_per_vector: " + std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(get_type_info())) + "\n"; - } - else - { - msg += std::string(" → ") + prefix + - ".lds_dst_scalar_per_vector: [✗] (missing member)\n"; - } - - if constexpr(requires(LT t) { t.is_direct_load; }) - { - using IsDirectLoadType = decltype(std::declval().is_direct_load); - constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + - ".is_direct_load: " + std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(get_type_info())) + "\n"; - } - else - { - msg += std::string(" → ") + prefix + ".is_direct_load: [✗] (missing member)\n"; - } - - if constexpr(requires(LT t) { t.lds_padding; }) - { - using LdsPaddingType = decltype(std::declval().lds_padding); - constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + - ".lds_padding: " + std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(get_type_info())) + "\n"; - } - else - { - msg += std::string(" → ") + prefix + ".lds_padding: [✗] (missing member)\n"; - } - - return msg; -} - -// ThreadClusterDescriptor diagnostics -template -consteval auto diagnose_thread_cluster(const char* prefix) -> std::string -{ - std::string msg; - - if constexpr(requires(TC t) { t.m_block; }) - { - using MBlockType = decltype(std::declval().m_block); - constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + - ".m_block: " + std::string(CHECK_MARK(convertible)) + - (convertible ? 
"" : std::string(get_type_info())) + "\n"; - } - else - { - msg += std::string(" → ") + prefix + ".m_block: [✗] (missing member)\n"; - } - - if constexpr(requires(TC t) { t.m_wave_per_xdl; }) - { - using MWaveType = decltype(std::declval().m_wave_per_xdl); - constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + - ".m_wave_per_xdl: " + std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(get_type_info())) + "\n"; - } - else - { - msg += std::string(" → ") + prefix + ".m_wave_per_xdl: [✗] (missing member)\n"; - } - - if constexpr(requires(TC t) { t.n_block; }) - { - using NBlockType = decltype(std::declval().n_block); - constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + - ".n_block: " + std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(get_type_info())) + "\n"; - } - else - { - msg += std::string(" → ") + prefix + ".n_block: [✗] (missing member)\n"; - } - - if constexpr(requires(TC t) { t.n_wave_per_xdl; }) - { - using NWaveType = decltype(std::declval().n_wave_per_xdl); - constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + - ".n_wave_per_xdl: " + std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(get_type_info())) + "\n"; - } - else - { - msg += std::string(" → ") + prefix + ".n_wave_per_xdl: [✗] (missing member)\n"; - } - - return msg; -} - -// AccessOrderDescriptor diagnostics -template -consteval auto diagnose_access_order(const char* prefix) -> std::string -{ - std::string msg; - - if constexpr(requires(AO t) { t.order; }) - { - using OrderType = decltype(std::declval().order); - constexpr bool convertible_3 = std::convertible_to>; - constexpr bool convertible_4 = std::convertible_to>; - constexpr bool convertible = convertible_3 || convertible_4; - msg += std::string(" → ") + prefix + - ".order: " + std::string(CHECK_MARK(convertible)) + - (convertible ? 
"" : std::string(get_type_info())) + "\n"; - } - else - { - msg += std::string(" → ") + prefix + ".order: [✗] (missing member)\n"; - } - - return msg; -} - -// EpilogueDescriptor diagnostics -template -consteval auto diagnose_epilogue(const char* prefix) -> std::string -{ - std::string msg; - - if constexpr(requires(E t) { t.m_xdl_per_wave_per_shuffle; }) - { - using MXdlType = decltype(std::declval().m_xdl_per_wave_per_shuffle); - constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + - ".m_xdl_per_wave_per_shuffle: " + std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(get_type_info())) + "\n"; - } - else - { - msg += std::string(" → ") + prefix + - ".m_xdl_per_wave_per_shuffle: [✗] (missing member)\n"; - } - - if constexpr(requires(E t) { t.n_per_wave_per_shuffle; }) - { - using NPerWaveType = decltype(std::declval().n_per_wave_per_shuffle); - constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + - ".n_per_wave_per_shuffle: " + std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(get_type_info())) + "\n"; - } - else - { - msg += std::string(" → ") + prefix + - ".n_per_wave_per_shuffle: [✗] (missing member)\n"; - } - - if constexpr(requires(E t) { t.scalar_per_vector; }) - { - using ScalarType = decltype(std::declval().scalar_per_vector); - constexpr bool convertible = std::convertible_to; - msg += std::string(" → ") + prefix + - ".scalar_per_vector: " + std::string(CHECK_MARK(convertible)) + - (convertible ? 
"" : std::string(get_type_info())) + "\n"; - } - else - { - msg += std::string(" → ") + prefix + ".scalar_per_vector: [✗] (missing member)\n"; - } - - return msg; -} - -} // namespace detail - -// Detailed diagnostic functions for high-level concepts -template -consteval auto detailed_diagnostic_ConvAlgorithmDescriptor() -> std::string -{ - return ""; // Base concept, no sub-requirements to check -} - -template -consteval auto detailed_diagnostic_SpecifiesThreadBlock() -> std::string -{ - if constexpr(!requires { - { T::thread_block } -> ThreadBlockDescriptor; - }) - { - return " → T::thread_block member: [✗] (missing or wrong type)\n"; - } - else - { - return " → T::thread_block member: [✓]\n" + - detail::diagnose_thread_block_descriptor(); - } -} - -template -consteval auto detailed_diagnostic_SpecifiesGridwiseFwdXdlGemm() -> std::string -{ - std::string msg; - - if constexpr(!requires(T t) { - { t.gridwise_gemm } -> GridwiseFwdXdlGemmDescriptor; - }) - { - return " → T::gridwise_gemm member: [✗] (missing or wrong type)\n"; - } - - msg += " → T::gridwise_gemm member: [✓]\n"; - using GG = decltype(T::gridwise_gemm); - - if constexpr(requires(GG t) { t.ak1; }) - { - using AK1Type = decltype(std::declval().ak1); - constexpr bool convertible = std::convertible_to; - msg += " → gridwise_gemm.ak1: " + std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(detail::get_type_info())) + "\n"; - } - else - { - msg += " → gridwise_gemm.ak1: [✗] (missing member)\n"; - } - - if constexpr(requires(GG t) { t.bk1; }) - { - using BK1Type = decltype(std::declval().bk1); - constexpr bool convertible = std::convertible_to; - msg += " → gridwise_gemm.bk1: " + std::string(CHECK_MARK(convertible)) + - (convertible ? 
"" : std::string(detail::get_type_info())) + "\n"; - } - else - { - msg += " → gridwise_gemm.bk1: [✗] (missing member)\n"; - } - - if constexpr(requires(GG t) { t.xdl_params; }) - { - msg += " → gridwise_gemm.xdl_params member: [✓]\n"; - msg += detail::diagnose_xdl_params().xdl_params)>(); - } - else - { - msg += " → gridwise_gemm.xdl_params: [✗] (missing member)\n"; - } - - return msg; -} - -template -consteval auto detailed_diagnostic_SpecifiesGridwiseBwdXdlGemm() -> std::string -{ - std::string msg; - - if constexpr(!requires(T t) { - { t.gridwise_gemm } -> GridwiseBwdXdlGemmDescriptor; - }) - { - return " → T::gridwise_gemm member: [✗] (missing or wrong type)\n"; - } - - msg += " → T::gridwise_gemm member: [✓]\n"; - using GG = decltype(T::gridwise_gemm); - - if constexpr(requires(GG t) { t.k1; }) - { - using K1Type = decltype(std::declval().k1); - constexpr bool convertible = std::convertible_to; - msg += " → gridwise_gemm.k1: " + std::string(CHECK_MARK(convertible)) + - (convertible ? 
"" : std::string(detail::get_type_info())) + "\n"; - } - else - { - msg += " → gridwise_gemm.k1: [✗] (missing member)\n"; - } - - if constexpr(requires(GG t) { t.xdl_params; }) - { - msg += " → gridwise_gemm.xdl_params member: [✓]\n"; - msg += detail::diagnose_xdl_params().xdl_params)>(); - } - else - { - msg += " → gridwise_gemm.xdl_params: [✗] (missing member)\n"; - } - - return msg; -} - -template -consteval auto detailed_diagnostic_SpecifiesBlockTransfer() -> std::string -{ - std::string msg; - - constexpr bool has_transfer = requires { T::transfer; }; - msg += " → T::transfer member: " + std::string(CHECK_MARK(has_transfer)) + "\n"; - - if constexpr(!has_transfer) - { - return msg; - } - - constexpr bool has_a = requires { - { T::transfer.a.block_transfer } -> BlockTransferDescriptor; - }; - msg += " → T::transfer.a: " + std::string(CHECK_MARK(has_a)) + "\n"; - if constexpr(!has_a) - { - msg += " → T::transfer.a.block_transfer: [✗] (missing or wrong type)\n"; - } - - constexpr bool has_b = requires { - { T::transfer.b.block_transfer } -> BlockTransferDescriptor; - }; - msg += " → T::transfer.b: " + std::string(CHECK_MARK(has_b)) + "\n"; - if constexpr(!has_b) - { - msg += " → T::transfer.b.block_transfer: [✗] (missing or wrong type)\n"; - } - - constexpr bool has_c = requires { - { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; - }; - msg += " → T::transfer.c: " + std::string(CHECK_MARK(has_c)) + "\n"; - if constexpr(!has_c) - { - msg += " → T::transfer.c.thread_cluster_dims: [✗] (missing or wrong type)\n"; - } - - return msg; -} - -template -consteval auto detailed_diagnostic_SpecifiesBlockTransfer4D() -> std::string -{ - std::string msg; - - constexpr bool has_transfer = requires { T::transfer; }; - msg += " → T::transfer member: " + std::string(CHECK_MARK(has_transfer)) + "\n"; - - if constexpr(!has_transfer) - { - return msg; - } - - constexpr bool has_a = requires { - { T::transfer.a.block_transfer } -> BlockTransferDescriptor4D; - }; - 
msg += " → T::transfer.a: " + std::string(CHECK_MARK(has_a)) + "\n"; - if constexpr(!has_a) - { - msg += " → T::transfer.a.block_transfer: [✗] (missing or wrong type)\n"; - } - else - { - msg += detail::diagnose_block_transfer_4d( - "transfer.a.block_transfer"); - } - - constexpr bool has_b = requires { - { T::transfer.b.block_transfer } -> BlockTransferDescriptor4D; - }; - msg += " → T::transfer.b: " + std::string(CHECK_MARK(has_b)) + "\n"; - if constexpr(!has_b) - { - msg += " → T::transfer.b.block_transfer: [✗] (missing or wrong type)\n"; - } - else - { - msg += detail::diagnose_block_transfer_4d( - "transfer.b.block_transfer"); - } - - constexpr bool has_c = requires { - { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; - }; - msg += " → T::transfer.c: " + std::string(CHECK_MARK(has_c)) + "\n"; - if constexpr(!has_c) - { - msg += " → T::transfer.c.thread_cluster_dims: [✗] (missing or wrong type)\n"; - } - - return msg; -} - -template -consteval auto detailed_diagnostic_SpecifiesThreadClusterAccessOrder() -> std::string -{ - std::string msg; - - constexpr bool has_transfer = requires { T::transfer; }; - if constexpr(!has_transfer) - { - return " → T::transfer member: [✗] (missing member)\n"; - } - - constexpr bool has_a = requires { T::transfer.a; }; - constexpr bool has_b = requires { T::transfer.b; }; - - if constexpr(has_a && requires { T::transfer.a.block_transfer_access_order; }) - { - msg += - detail::diagnose_access_order( - "transfer.a.block_transfer_access_order"); - } - else if constexpr(has_a) - { - msg += " → T::transfer.a.block_transfer_access_order: [✗] (missing member)\n"; - } - - if constexpr(has_b && requires { T::transfer.b.block_transfer_access_order; }) - { - msg += - detail::diagnose_access_order( - "transfer.b.block_transfer_access_order"); - } - else if constexpr(has_b) - { - msg += " → T::transfer.b.block_transfer_access_order: [✗] (missing member)\n"; - } - - return msg; -} - -template -consteval auto 
detailed_diagnostic_SpecifiesSourceAccessOrder() -> std::string -{ - std::string msg; - - constexpr bool has_transfer = requires { T::transfer; }; - if constexpr(!has_transfer) - { - return " → T::transfer member: [✗] (missing member)\n"; - } - - constexpr bool has_a = requires { T::transfer.a; }; - constexpr bool has_b = requires { T::transfer.b; }; - - if constexpr(has_a && requires { T::transfer.a.src_access_order; }) - { - msg += detail::diagnose_access_order( - "transfer.a.src_access_order"); - } - else if constexpr(has_a) - { - msg += " → T::transfer.a.src_access_order: [✗] (missing member)\n"; - } - - if constexpr(has_b && requires { T::transfer.b.src_access_order; }) - { - msg += detail::diagnose_access_order( - "transfer.b.src_access_order"); - } - else if constexpr(has_b) - { - msg += " → T::transfer.b.src_access_order: [✗] (missing member)\n"; - } - - return msg; -} - -template -consteval auto detailed_diagnostic_SpecifiesBlockGemm() -> std::string -{ - std::string msg; - - if constexpr(!requires { - { T::block_gemm_pipeline } -> BlockGemmPipelineDescriptor; - }) - { - return " → T::block_gemm_pipeline: [✗] (missing or wrong type)\n"; - } - - msg += " → T::block_gemm_pipeline member: [✓]\n"; - - if constexpr(requires { T::block_gemm_pipeline.pipeline_version; }) - { - using PipelineType = decltype(T::block_gemm_pipeline.pipeline_version); - constexpr bool convertible = std::convertible_to; - msg += " → block_gemm_pipeline.pipeline_version: " + - std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(detail::get_type_info())) + "\n"; - } - else - { - msg += " → block_gemm_pipeline.pipeline_version: [✗] (missing member)\n"; - } - - if constexpr(requires { T::block_gemm_pipeline.scheduler; }) - { - using SchedulerType = decltype(T::block_gemm_pipeline.scheduler); - constexpr bool convertible = std::convertible_to; - msg += " → block_gemm_pipeline.scheduler: " + std::string(CHECK_MARK(convertible)) + - (convertible ? 
"" : std::string(detail::get_type_info())) + "\n"; - } - else - { - msg += " → block_gemm_pipeline.scheduler: [✗] (missing member)\n"; - } - - return msg; -} - -template -consteval auto detailed_diagnostic_SpecifiesFwdConvSpecialization() -> std::string -{ - if constexpr(requires { T::fwd_specialization; }) - { - using FwdSpecType = decltype(T::fwd_specialization); - constexpr bool convertible = std::convertible_to; - return " → T::fwd_specialization: " + std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(detail::get_type_info())) + "\n"; - } - else - { - return " → T::fwd_specialization: [✗] (missing member)\n"; - } -} - -template -consteval auto detailed_diagnostic_SpecifiesBwdWeightConvSpecialization() -> std::string -{ - if constexpr(requires { T::bwd_weight_specialization; }) - { - using BwdSpecType = decltype(T::bwd_weight_specialization); - constexpr bool convertible = std::convertible_to; - return " → T::bwd_weight_specialization: " + std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(detail::get_type_info())) + "\n"; - } - else - { - return " → T::bwd_weight_specialization: [✗] (missing member)\n"; - } -} - -template -consteval auto detailed_diagnostic_SpecifiesGemmSpecialization() -> std::string -{ - if constexpr(requires { T::gemm_specialization; }) - { - using GemmSpecType = decltype(T::gemm_specialization); - constexpr bool convertible = std::convertible_to; - return " → T::gemm_specialization: " + std::string(CHECK_MARK(convertible)) + - (convertible ? 
"" : std::string(detail::get_type_info())) + "\n"; - } - else - { - return " → T::gemm_specialization: [✗] (missing member)\n"; - } -} - -template -consteval auto detailed_diagnostic_SpecifiesNumPrefetchStages() -> std::string -{ - if constexpr(requires { T::num_gemm_k_prefetch_stages; }) - { - using NumPrefetchType = decltype(T::num_gemm_k_prefetch_stages); - constexpr bool convertible = std::convertible_to; - return " → T::num_gemm_k_prefetch_stages: " + std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(detail::get_type_info())) + "\n"; - } - else - { - return " → T::num_gemm_k_prefetch_stages: [✗] (missing member)\n"; - } -} - -template -consteval auto detailed_diagnostic_SpecifiesNumGroupsToMerge() -> std::string -{ - if constexpr(requires { T::num_conv_groups_to_merge; }) - { - using NumGroupsType = decltype(T::num_conv_groups_to_merge); - constexpr bool convertible = std::convertible_to; - return " → T::num_conv_groups_to_merge: " + std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(detail::get_type_info())) + "\n"; - } - else - { - return " → T::num_conv_groups_to_merge: [✗] (missing member)\n"; - } -} - -template -consteval auto detailed_diagnostic_SpecifiesLoopScheduler() -> std::string -{ - if constexpr(requires { T::loop_scheduler; }) - { - using LoopSchedulerType = decltype(T::loop_scheduler); - constexpr bool convertible = std::convertible_to; - return " → T::loop_scheduler: " + std::string(CHECK_MARK(convertible)) + - (convertible ? 
"" : std::string(detail::get_type_info())) + "\n"; - } - else - { - return " → T::loop_scheduler: [✗] (missing member)\n"; - } -} - -template -consteval auto detailed_diagnostic_SpecifiesLargeTensorSupport() -> std::string -{ - std::string msg; - if constexpr(requires { T::specialization; }) - { - using SpecType = decltype(T::specialization); - constexpr bool convertible = std::convertible_to; - msg += " → T::specialization: " + std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(detail::get_type_info())) + "\n"; - - if constexpr(convertible) - { - constexpr bool is_large_tensor = - (T::specialization == ConvAlgorithmSpecialization::LARGE_TENSOR); - msg += " → specialization == LARGE_TENSOR: " + - std::string(CHECK_MARK(is_large_tensor)) + "\n"; - } - } - else - { - msg += " → T::specialization: [✗] (missing member)\n"; - } - - return msg; -} - -template -consteval auto detailed_diagnostic_SpecifiesReferenceAlgorithm() -> std::string -{ - std::string msg; - if constexpr(requires { T::specialization; }) - { - using SpecType = decltype(T::specialization); - constexpr bool convertible = std::convertible_to; - msg += " → T::specialization: " + std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(detail::get_type_info())) + "\n"; - - if constexpr(convertible) - { - constexpr bool is_reference = - (T::specialization == ConvAlgorithmSpecialization::REFERENCE); - msg += " → specialization == REFERENCE: " + std::string(CHECK_MARK(is_reference)) + - "\n"; - } - } - else - { - msg += " → T::specialization: [✗] (missing member)\n"; - } - - return msg; -} - -template -consteval auto detailed_diagnostic_SpecifiesTwoStageSupport() -> std::string -{ - std::string msg; - if constexpr(requires { T::specialization; }) - { - using SpecType = decltype(T::specialization); - constexpr bool convertible = std::convertible_to; - msg += " → T::specialization: " + std::string(CHECK_MARK(convertible)) + - (convertible ? 
"" : std::string(detail::get_type_info())) + "\n"; - - if constexpr(convertible) - { - constexpr bool is_two_stage = - (T::specialization == ConvAlgorithmSpecialization::TWO_STAGE); - msg += " → specialization == TWO_STAGE: " + std::string(CHECK_MARK(is_two_stage)) + - "\n"; - } - } - else - { - msg += " → T::specialization: [✗] (missing member)\n"; - } - - return msg; -} - -template -consteval auto detailed_diagnostic_SpecifiesMultipleDSupport() -> std::string -{ - std::string msg; - if constexpr(requires { T::specialization; }) - { - using SpecType = decltype(T::specialization); - constexpr bool convertible = std::convertible_to; - msg += " → T::specialization: " + std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(detail::get_type_info())) + "\n"; - - if constexpr(convertible) - { - constexpr bool is_multiple_d = - (T::specialization == ConvAlgorithmSpecialization::MULTIPLE_D); - msg += - " → specialization == MULTIPLE_D: " + std::string(CHECK_MARK(is_multiple_d)) + - "\n"; - } - } - else - { - msg += " → T::specialization: [✗] (missing member)\n"; - } - - return msg; -} - -template -consteval auto detailed_diagnostic_SpecifiesGenericInstance() -> std::string -{ - std::string msg; - if constexpr(requires { T::specialization; }) - { - msg += " → T::specialization: [✗] (member should NOT exist for generic instance)\n"; - msg += " → This concept requires the absence of the specialization member\n"; - } - return msg; -} - -template -consteval auto detailed_diagnostic_SpecifiesTransposeTransfer() -> std::string -{ - std::string msg; - - if constexpr(requires { T::max_transpose_transfer_src_scalar_per_vector; }) - { - using SrcType = decltype(T::max_transpose_transfer_src_scalar_per_vector); - constexpr bool convertible = std::convertible_to; - msg += " → T::max_transpose_transfer_src_scalar_per_vector: " + - std::string(CHECK_MARK(convertible)) + - (convertible ? 
"" : std::string(detail::get_type_info())) + "\n"; - } - else - { - msg += " → T::max_transpose_transfer_src_scalar_per_vector: [✗] (missing member)\n"; - } - - if constexpr(requires { T::max_transpose_transfer_dst_scalar_per_vector; }) - { - using DstType = decltype(T::max_transpose_transfer_dst_scalar_per_vector); - constexpr bool convertible = std::convertible_to; - msg += " → T::max_transpose_transfer_dst_scalar_per_vector: " + - std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(detail::get_type_info())) + "\n"; - } - else - { - msg += " → T::max_transpose_transfer_dst_scalar_per_vector: [✗] (missing member)\n"; - } - - return msg; -} - -template -consteval auto detailed_diagnostic_TransposeTransferWellDefinedIfProvided() -> std::string -{ - std::string msg; - - constexpr bool has_src = requires { T::max_transpose_transfer_src_scalar_per_vector; }; - constexpr bool has_dst = requires { T::max_transpose_transfer_dst_scalar_per_vector; }; - constexpr bool has_transpose_transfer = has_src || has_dst; - - if constexpr(!has_transpose_transfer) - { - msg += " → Transpose transfer fields not provided: [✓] (optional, not required)\n"; - } - else - { - msg += " → Transpose transfer fields provided, checking if well-defined:\n"; - - if constexpr(has_src) - { - using SrcType = decltype(T::max_transpose_transfer_src_scalar_per_vector); - constexpr bool src_convertible = std::convertible_to; - msg += " → T::max_transpose_transfer_src_scalar_per_vector: " + - std::string(CHECK_MARK(src_convertible)) + - (src_convertible ? 
"" : std::string(detail::get_type_info())) + "\n"; - } - else - { - msg += " → T::max_transpose_transfer_src_scalar_per_vector: [✗] (missing, but " - "dst is provided)\n"; - } - - if constexpr(has_dst) - { - using DstType = decltype(T::max_transpose_transfer_dst_scalar_per_vector); - constexpr bool dst_convertible = std::convertible_to; - msg += " → T::max_transpose_transfer_dst_scalar_per_vector: " + - std::string(CHECK_MARK(dst_convertible)) + - (dst_convertible ? "" : std::string(detail::get_type_info())) + "\n"; - } - else - { - msg += " → T::max_transpose_transfer_dst_scalar_per_vector: [✗] (missing, but " - "src is provided)\n"; - } - } - - return msg; -} - -template -consteval auto detailed_diagnostic_SpecifiesGemmBatchOptions() -> std::string -{ - if constexpr(requires { T::num_conv_groups_to_merge; }) - { - using NumGroupsType = decltype(T::num_conv_groups_to_merge); - constexpr bool convertible = std::convertible_to; - return " → T::num_conv_groups_to_merge: " + std::string(CHECK_MARK(convertible)) + - (convertible ? 
"" : std::string(detail::get_type_info())) + "\n"; - } - else - { - return " → T::num_conv_groups_to_merge: [✗] (missing member)\n"; - } -} - -template -consteval auto detailed_diagnostic_SpecifiesGridwiseWmmaGemm() -> std::string -{ - std::string msg; - constexpr bool has_gridwise_gemm = requires(T t) { - { t.gridwise_gemm } -> GridwiseWmmaGemmDescriptor; - }; - msg += " → T::gridwise_gemm member: " + std::string(CHECK_MARK(has_gridwise_gemm)) + "\n"; - - if constexpr(!has_gridwise_gemm) - { - return msg; - } - - using GG = decltype(T::gridwise_gemm); - constexpr bool has_k1 = requires(GG t) { - { t.k1 } -> std::convertible_to; - }; - constexpr bool has_m_per_wmma = requires(GG t) { - { t.m_per_wmma } -> std::convertible_to; - }; - constexpr bool has_n_per_wmma = requires(GG t) { - { t.n_per_wmma } -> std::convertible_to; - }; - constexpr bool has_m_wmma_per_wave = requires(GG t) { - { t.m_wmma_per_wave } -> std::convertible_to; - }; - constexpr bool has_n_wmma_per_wave = requires(GG t) { - { t.n_wmma_per_wave } -> std::convertible_to; - }; - - msg += " → gridwise_gemm.k1: " + std::string(CHECK_MARK(has_k1)) + - (has_k1 ? "\n" : " (missing or wrong type)\n"); - msg += " → gridwise_gemm.m_per_wmma: " + std::string(CHECK_MARK(has_m_per_wmma)) + - (has_m_per_wmma ? "\n" : " (missing or wrong type)\n"); - msg += " → gridwise_gemm.n_per_wmma: " + std::string(CHECK_MARK(has_n_per_wmma)) + - (has_n_per_wmma ? "\n" : " (missing or wrong type)\n"); - msg += - " → gridwise_gemm.m_wmma_per_wave: " + std::string(CHECK_MARK(has_m_wmma_per_wave)) + - (has_m_wmma_per_wave ? "\n" : " (missing or wrong type)\n"); - msg += - " → gridwise_gemm.n_wmma_per_wave: " + std::string(CHECK_MARK(has_n_wmma_per_wave)) + - (has_n_wmma_per_wave ? 
"\n" : " (missing or wrong type)\n"); - - return msg; -} - -// Tile-specific diagnostics -template -consteval auto detailed_diagnostic_SpecifiesTileThreadBlock() -> std::string -{ - if constexpr(!requires { - { T::thread_block } -> TileThreadBlockDescriptor; - }) - { - return " → T::thread_block member: [✗] (missing or wrong type)\n"; - } - else - { - using TB = decltype(T::thread_block); - std::string msg = " → T::thread_block member: [✓]\n"; - - constexpr bool has_tile_m = requires(TB t) { - { t.tile_size.m } -> std::convertible_to; - }; - constexpr bool has_tile_n = requires(TB t) { - { t.tile_size.n } -> std::convertible_to; - }; - constexpr bool has_tile_k = requires(TB t) { - { t.tile_size.k } -> std::convertible_to; - }; - - msg += " → thread_block.tile_size.m: " + std::string(CHECK_MARK(has_tile_m)) + - (has_tile_m ? "\n" : " (missing or wrong type)\n"); - msg += " → thread_block.tile_size.n: " + std::string(CHECK_MARK(has_tile_n)) + - (has_tile_n ? "\n" : " (missing or wrong type)\n"); - msg += " → thread_block.tile_size.k: " + std::string(CHECK_MARK(has_tile_k)) + - (has_tile_k ? "\n" : " (missing or wrong type)\n"); - - return msg; - } -} - -template -consteval auto detailed_diagnostic_SpecifiesTileTransfer() -> std::string -{ - std::string msg; - constexpr bool has_transfer = requires { T::transfer; }; - msg += " → T::transfer member: " + std::string(CHECK_MARK(has_transfer)) + "\n"; - - if constexpr(!has_transfer) - { - return msg; - } - - constexpr bool has_a_scalar = requires { - { T::transfer.a_scalar_per_vector } -> std::convertible_to; - }; - constexpr bool has_b_scalar = requires { - { T::transfer.b_scalar_per_vector } -> std::convertible_to; - }; - constexpr bool has_c_scalar = requires { - { T::transfer.c_scalar_per_vector } -> std::convertible_to; - }; - - msg += " → transfer.a_scalar_per_vector: " + std::string(CHECK_MARK(has_a_scalar)) + - (has_a_scalar ? 
"\n" : " (missing or wrong type)\n"); - msg += " → transfer.b_scalar_per_vector: " + std::string(CHECK_MARK(has_b_scalar)) + - (has_b_scalar ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.c_scalar_per_vector: " + std::string(CHECK_MARK(has_c_scalar)) + - (has_c_scalar ? "\n" : " (missing or wrong type)\n"); - - return msg; -} - -template -consteval auto detailed_diagnostic_SpecifiesTileBlockGemm() -> std::string -{ - std::string msg; - constexpr bool has_block_gemm = requires { - { T::block_gemm } -> TileBlockGemmDescriptor; - }; - msg += " → T::block_gemm member: " + std::string(CHECK_MARK(has_block_gemm)) + "\n"; - - if constexpr(!has_block_gemm) - { - return msg; - } - - using BG = decltype(T::block_gemm); - constexpr bool has_warps_m = requires(BG t) { - { t.warps.m } -> std::convertible_to; - }; - constexpr bool has_warps_n = requires(BG t) { - { t.warps.n } -> std::convertible_to; - }; - constexpr bool has_warps_k = requires(BG t) { - { t.warps.k } -> std::convertible_to; - }; - constexpr bool has_warp_tile_m = requires(BG t) { - { t.warp_tile.m } -> std::convertible_to; - }; - constexpr bool has_warp_tile_n = requires(BG t) { - { t.warp_tile.n } -> std::convertible_to; - }; - constexpr bool has_warp_tile_k = requires(BG t) { - { t.warp_tile.k } -> std::convertible_to; - }; - constexpr bool has_double_smem = requires(BG t) { - { t.double_smem_buffer } -> std::convertible_to; - }; - constexpr bool has_num_wave_groups = requires(BG t) { - { t.num_wave_groups } -> std::convertible_to; - }; - constexpr bool has_pipeline = requires(BG t) { - { t.pipeline_version } -> std::convertible_to; - }; - constexpr bool has_scheduler = requires(BG t) { - { t.scheduler } -> std::convertible_to; - }; - - msg += " → block_gemm.warps.m: " + std::string(CHECK_MARK(has_warps_m)) + - (has_warps_m ? "\n" : " (missing or wrong type)\n"); - msg += " → block_gemm.warps.n: " + std::string(CHECK_MARK(has_warps_n)) + - (has_warps_n ? 
"\n" : " (missing or wrong type)\n"); - msg += " → block_gemm.warps.k: " + std::string(CHECK_MARK(has_warps_k)) + - (has_warps_k ? "\n" : " (missing or wrong type)\n"); - msg += " → block_gemm.warp_tile.m: " + std::string(CHECK_MARK(has_warp_tile_m)) + - (has_warp_tile_m ? "\n" : " (missing or wrong type)\n"); - msg += " → block_gemm.warp_tile.n: " + std::string(CHECK_MARK(has_warp_tile_n)) + - (has_warp_tile_n ? "\n" : " (missing or wrong type)\n"); - msg += " → block_gemm.warp_tile.k: " + std::string(CHECK_MARK(has_warp_tile_k)) + - (has_warp_tile_k ? "\n" : " (missing or wrong type)\n"); - msg += " → block_gemm.double_smem_buffer: " + std::string(CHECK_MARK(has_double_smem)) + - (has_double_smem ? "\n" : " (missing or wrong type)\n"); - msg += " → block_gemm.num_wave_groups: " + std::string(CHECK_MARK(has_num_wave_groups)) + - (has_num_wave_groups ? "\n" : " (missing or wrong type)\n"); - msg += " → block_gemm.pipeline_version: " + std::string(CHECK_MARK(has_pipeline)) + - (has_pipeline ? "\n" : " (missing or wrong type)\n"); - msg += " → block_gemm.scheduler: " + std::string(CHECK_MARK(has_scheduler)) + - (has_scheduler ? 
"\n" : " (missing or wrong type)\n"); - - return msg; -} - -template -consteval auto detailed_diagnostic_SpecifiesTileOptimizations() -> std::string -{ - std::string msg; - constexpr bool has_optimizations = requires { - { T::optimizations } -> TileOptimizationsDescriptor; - }; - msg += " → T::optimizations member: " + std::string(CHECK_MARK(has_optimizations)) + "\n"; - - if constexpr(!has_optimizations) - { - return msg; - } - - using OPT = decltype(T::optimizations); - constexpr bool has_num_groups = requires(OPT t) { - { t.num_groups_to_merge } -> std::convertible_to; - }; - constexpr bool has_split_image = requires(OPT t) { - { t.split_image } -> std::convertible_to; - }; - constexpr bool has_explicit_gemm = requires(OPT t) { - { t.explicit_gemm } -> std::convertible_to; - }; - - msg += " → optimizations.num_groups_to_merge: " + std::string(CHECK_MARK(has_num_groups)) + - (has_num_groups ? "\n" : " (missing or wrong type)\n"); - msg += " → optimizations.split_image: " + std::string(CHECK_MARK(has_split_image)) + - (has_split_image ? "\n" : " (missing or wrong type)\n"); - msg += " → optimizations.explicit_gemm: " + std::string(CHECK_MARK(has_explicit_gemm)) + - (has_explicit_gemm ? 
"\n" : " (missing or wrong type)\n"); - - return msg; -} - -// DL-specific diagnostics -template -consteval auto detailed_diagnostic_SpecifiesDlThreadConfig() -> std::string -{ - std::string msg; - constexpr bool has_thread_config = requires { - { T::thread_config } -> DlThreadConfigDescriptor; - }; - msg += " → T::thread_config member: " + std::string(CHECK_MARK(has_thread_config)) + "\n"; - - if constexpr(!has_thread_config) - { - return msg; - } - - using TC = decltype(T::thread_config); - constexpr bool has_k0 = requires(TC t) { - { t.k0_per_block } -> std::convertible_to; - }; - constexpr bool has_k1 = requires(TC t) { - { t.k1 } -> std::convertible_to; - }; - constexpr bool has_m1 = requires(TC t) { - { t.m1_per_thread } -> std::convertible_to; - }; - constexpr bool has_n1 = requires(TC t) { - { t.n1_per_thread } -> std::convertible_to; - }; - constexpr bool has_k = requires(TC t) { - { t.k_per_thread } -> std::convertible_to; - }; - - msg += " → thread_config.k0_per_block: " + std::string(CHECK_MARK(has_k0)) + - (has_k0 ? "\n" : " (missing or wrong type)\n"); - msg += " → thread_config.k1: " + std::string(CHECK_MARK(has_k1)) + - (has_k1 ? "\n" : " (missing or wrong type)\n"); - msg += " → thread_config.m1_per_thread: " + std::string(CHECK_MARK(has_m1)) + - (has_m1 ? "\n" : " (missing or wrong type)\n"); - msg += " → thread_config.n1_per_thread: " + std::string(CHECK_MARK(has_n1)) + - (has_n1 ? "\n" : " (missing or wrong type)\n"); - msg += " → thread_config.k_per_thread: " + std::string(CHECK_MARK(has_k)) + - (has_k ? 
"\n" : " (missing or wrong type)\n"); - - return msg; -} - -template -consteval auto detailed_diagnostic_SpecifiesDlThreadCluster() -> std::string -{ - std::string msg; - constexpr bool has_thread_cluster = requires { - { T::thread_cluster } -> DlThreadClusterDescriptor; - }; - msg += - " → T::thread_cluster member: " + std::string(CHECK_MARK(has_thread_cluster)) + "\n"; - - if constexpr(!has_thread_cluster) - { - return msg; - } - - using TC = decltype(T::thread_cluster); - constexpr bool has_m1_xs = requires(TC t) { - { t.m1_xs } -> std::convertible_to>; - }; - constexpr bool has_n1_xs = requires(TC t) { - { t.n1_xs } -> std::convertible_to>; - }; - - msg += " → thread_cluster.m1_xs: " + std::string(CHECK_MARK(has_m1_xs)) + - (has_m1_xs ? "\n" : " (missing or wrong type)\n"); - msg += " → thread_cluster.n1_xs: " + std::string(CHECK_MARK(has_n1_xs)) + - (has_n1_xs ? "\n" : " (missing or wrong type)\n"); - - return msg; -} - -template -consteval auto detailed_diagnostic_SpecifiesDlFwdBlockTransfer() -> std::string -{ - std::string msg; - constexpr bool has_transfer = requires { T::transfer; }; - msg += " → T::transfer member: " + std::string(CHECK_MARK(has_transfer)) + "\n"; - - if constexpr(!has_transfer) - { - return msg; - } - - constexpr bool has_a = requires { - { T::transfer.a } -> DlBlockTransferDescriptor4D; - }; - constexpr bool has_b = requires { - { T::transfer.b } -> DlBlockTransferDescriptor4D; - }; - msg += " → T::transfer.a: " + std::string(CHECK_MARK(has_a)) + "\n"; - - if constexpr(has_a) - { - using ABT = decltype(T::transfer.a); - constexpr bool has_thread_slice = requires(ABT t) { - { t.thread_slice_lengths } -> std::convertible_to>; - }; - constexpr bool has_thread_cluster = requires(ABT t) { - { t.thread_cluster_lengths } -> std::convertible_to>; - }; - constexpr bool has_cluster_arrange = requires(ABT t) { - { t.thread_cluster_arrange_order } -> std::convertible_to>; - }; - constexpr bool has_src_access = requires(ABT t) { - { 
t.src_access_order } -> std::convertible_to>; - }; - constexpr bool has_src_vector = requires(ABT t) { - { t.src_vector_tensor_lengths } -> std::convertible_to>; - }; - constexpr bool has_src_contiguous = requires(ABT t) { - { - t.src_vector_tensor_contiguous_dim_order - } -> std::convertible_to>; - }; - constexpr bool has_dst_vector = requires(ABT t) { - { t.dst_vector_tensor_lengths } -> std::convertible_to>; - }; - - msg += " → transfer.a.thread_slice_lengths (4D): " + - std::string(CHECK_MARK(has_thread_slice)) + - (has_thread_slice ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.thread_cluster_lengths (4D): " + - std::string(CHECK_MARK(has_thread_cluster)) + - (has_thread_cluster ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.thread_cluster_arrange_order (4D): " + - std::string(CHECK_MARK(has_cluster_arrange)) + - (has_cluster_arrange ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.src_access_order (4D): " + - std::string(CHECK_MARK(has_src_access)) + - (has_src_access ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.src_vector_tensor_lengths (4D): " + - std::string(CHECK_MARK(has_src_vector)) + - (has_src_vector ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.src_vector_tensor_contiguous_dim_order (4D): " + - std::string(CHECK_MARK(has_src_contiguous)) + - (has_src_contiguous ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.dst_vector_tensor_lengths (4D): " + - std::string(CHECK_MARK(has_dst_vector)) + - (has_dst_vector ? 
"\n" : " (missing or wrong type)\n"); - } - else - { - msg += " → T::transfer.a (4D): [✗] (missing or wrong type)\n"; - } - - msg += " → T::transfer.b: " + std::string(CHECK_MARK(has_b)) + "\n"; - - if constexpr(has_b) - { - msg += " → T::transfer.b (4D): [✓] (similar fields as transfer.a)\n"; - } - else - { - msg += " → T::transfer.b (4D): [✗] (missing or wrong type)\n"; - } - - return msg; -} - -template -consteval auto detailed_diagnostic_SpecifiesDlBwdBlockTransfer() -> std::string -{ - std::string msg; - constexpr bool has_transfer = requires { T::transfer; }; - msg += " → T::transfer member: " + std::string(CHECK_MARK(has_transfer)) + "\n"; - - if constexpr(!has_transfer) - { - return msg; - } - - constexpr bool has_a = requires { - { T::transfer.a } -> DlBlockTransferDescriptor5D; - }; - constexpr bool has_b = requires { - { T::transfer.b } -> DlBlockTransferDescriptor5D; - }; - msg += " → T::transfer.a: " + std::string(CHECK_MARK(has_a)) + "\n"; - - if constexpr(has_a) - { - using ABT = decltype(T::transfer.a); - constexpr bool has_thread_slice = requires(ABT t) { - { t.thread_slice_lengths } -> std::convertible_to>; - }; - constexpr bool has_thread_cluster = requires(ABT t) { - { t.thread_cluster_lengths } -> std::convertible_to>; - }; - constexpr bool has_cluster_arrange = requires(ABT t) { - { t.thread_cluster_arrange_order } -> std::convertible_to>; - }; - constexpr bool has_src_access = requires(ABT t) { - { t.src_access_order } -> std::convertible_to>; - }; - constexpr bool has_src_vector = requires(ABT t) { - { t.src_vector_tensor_lengths } -> std::convertible_to>; - }; - constexpr bool has_src_contiguous = requires(ABT t) { - { - t.src_vector_tensor_contiguous_dim_order - } -> std::convertible_to>; - }; - constexpr bool has_dst_vector = requires(ABT t) { - { t.dst_vector_tensor_lengths } -> std::convertible_to>; - }; - - msg += " → transfer.a.thread_slice_lengths (5D): " + - std::string(CHECK_MARK(has_thread_slice)) + - (has_thread_slice ? 
"\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.thread_cluster_lengths (5D): " + - std::string(CHECK_MARK(has_thread_cluster)) + - (has_thread_cluster ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.thread_cluster_arrange_order (5D): " + - std::string(CHECK_MARK(has_cluster_arrange)) + - (has_cluster_arrange ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.src_access_order (5D): " + - std::string(CHECK_MARK(has_src_access)) + - (has_src_access ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.src_vector_tensor_lengths (5D): " + - std::string(CHECK_MARK(has_src_vector)) + - (has_src_vector ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.src_vector_tensor_contiguous_dim_order (5D): " + - std::string(CHECK_MARK(has_src_contiguous)) + - (has_src_contiguous ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.a.dst_vector_tensor_lengths (5D): " + - std::string(CHECK_MARK(has_dst_vector)) + - (has_dst_vector ? 
"\n" : " (missing or wrong type)\n"); - } - else - { - msg += " → T::transfer.a (5D): [✗] (missing or wrong type)\n"; - } - - msg += " → T::transfer.b: " + std::string(CHECK_MARK(has_b)) + "\n"; - - if constexpr(has_b) - { - msg += " → T::transfer.b (5D): [✓] (similar fields as transfer.a)\n"; - } - else - { - msg += " → T::transfer.b (5D): [✗] (missing or wrong type)\n"; - } - - return msg; -} - -template -consteval auto detailed_diagnostic_SpecifiesDlEpilogue() -> std::string -{ - std::string msg; - constexpr bool has_transfer = requires { T::transfer; }; - if constexpr(!has_transfer) - { - return " → T::transfer member: [✗] (not found)\n"; - } - - constexpr bool has_c = requires { T::transfer.c; }; - msg += " → T::transfer.c: " + std::string(CHECK_MARK(has_c)) + "\n"; - - if constexpr(has_c && requires { T::transfer.c.src_dst_access_order; }) - { - using C = decltype(T::transfer.c); - constexpr bool has_src_dst_access = requires(C t) { - { t.src_dst_access_order } -> std::convertible_to>; - }; - constexpr bool has_src_dst_vector_dim = requires(C t) { - { t.src_dst_vector_dim } -> std::convertible_to; - }; - constexpr bool has_dst_scalar = requires(C t) { - { t.dst_scalar_per_vector } -> std::convertible_to; - }; - - msg += " → transfer.c.src_dst_access_order: " + - std::string(CHECK_MARK(has_src_dst_access)) + - (has_src_dst_access ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.c.src_dst_vector_dim: " + - std::string(CHECK_MARK(has_src_dst_vector_dim)) + - (has_src_dst_vector_dim ? "\n" : " (missing or wrong type)\n"); - msg += " → transfer.c.dst_scalar_per_vector: " + - std::string(CHECK_MARK(has_dst_scalar)) + - (has_dst_scalar ? 
"\n" : " (missing or wrong type)\n"); - } - else if constexpr(has_c) - { - msg += " → T::transfer.c (DlEpilogue): [✗] (missing required fields)\n"; - } - - return msg; -} - -template -consteval auto detailed_diagnostic_SpecifiesTileConvSpecialization() -> std::string -{ - if constexpr(requires { T::specialization; }) - { - using SpecType = decltype(T::specialization); - constexpr bool convertible = std::convertible_to; - return " → T::specialization: " + std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(detail::get_type_info())) + "\n"; - } - else - { - return " → T::specialization: [✗] (missing member)\n"; - } -} - -template -consteval auto detailed_diagnostic_SpecifiesLdsTransfer() -> std::string -{ - std::string msg; - - constexpr bool has_transfer = requires { T::transfer; }; - msg += " → T::transfer member: " + std::string(CHECK_MARK(has_transfer)) + "\n"; - - if constexpr(!has_transfer) - { - return msg; - } - - constexpr bool has_a = requires { - { T::transfer.a.lds_transfer } -> LdsTransferDescriptor; - }; - msg += " → T::transfer.a: " + std::string(CHECK_MARK(has_a)) + "\n"; - if constexpr(!has_a) - { - msg += " → T::transfer.a.lds_transfer: [✗] (missing or wrong type)\n"; - } - - constexpr bool has_b = requires { - { T::transfer.b.lds_transfer } -> LdsTransferDescriptor; - }; - msg += " → T::transfer.b: " + std::string(CHECK_MARK(has_b)) + "\n"; - if constexpr(!has_b) - { - msg += " → T::transfer.b.lds_transfer: [✗] (missing or wrong type)\n"; - } - - constexpr bool has_c = requires { - { T::transfer.c.epilogue } -> EpilogueDescriptor; - }; - msg += " → T::transfer.c: " + std::string(CHECK_MARK(has_c)) + "\n"; - if constexpr(!has_c) - { - msg += " → T::transfer.c.epilogue: [✗] (missing or wrong type)\n"; - } - - return msg; -} - -template -consteval auto detailed_diagnostic_SpecifieGridwiseGemmPipeline() -> std::string -{ - if constexpr(requires { T::pipeline_version; }) - { - using PipelineType = decltype(T::pipeline_version); - 
constexpr bool convertible = std::convertible_to; - return " → T::pipeline_version: " + std::string(CHECK_MARK(convertible)) + - (convertible ? "" : std::string(detail::get_type_info())) + "\n"; - } - else - { - return " → T::pipeline_version: [✗] (missing member)\n"; - } -} - -} // namespace ck_tile::builder::diagnostics diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index 9b02b4c9fc0..ffd45efe496 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -3,954 +3,123 @@ #pragma once -#include "ck_tile/builder/conv_algorithm_diagnostics.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" namespace ck_tile::builder::factory { -using namespace ck_tile::builder::diagnostics; - -template -struct ReferenceAlgorithm -{ - CHECK_CONCEPT(T, ConvAlgorithmDescriptor) - CHECK_CONCEPT(T, SpecifiesReferenceAlgorithm) - - static constexpr bool c1 = c_ConvAlgorithmDescriptor; - static constexpr bool c2 = c_SpecifiesReferenceAlgorithm; - - static consteval bool is_valid() { return c1 && c2; } - - static consteval auto message() -> std::string - { - return std::string("\n=== Reference Algorithm Diagnostic (closest match) ===\n" - "Concepts for Reference Algorithm:\n") + - DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + - DIAGNOSTIC_LINE(SpecifiesReferenceAlgorithm); - } -}; - -template -struct FwdXdlV3Algorithm -{ - CHECK_CONCEPT(T, ConvAlgorithmDescriptor) - CHECK_CONCEPT(T, SpecifiesThreadBlock) - CHECK_CONCEPT(T, SpecifiesBlockTransfer) - CHECK_CONCEPT(T, SpecifiesLdsTransfer) - CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder) - CHECK_CONCEPT(T, SpecifiesSourceAccessOrder) - CHECK_CONCEPT(T, SpecifiesGridwiseFwdXdlGemm) - CHECK_CONCEPT(T, SpecifiesFwdConvSpecialization) - CHECK_CONCEPT(T, SpecifiesGemmSpecialization) - CHECK_CONCEPT(T, 
SpecifiesBlockGemm) - - static constexpr bool c1 = c_ConvAlgorithmDescriptor; - static constexpr bool c2 = c_SpecifiesThreadBlock; - static constexpr bool c3 = c_SpecifiesBlockTransfer; - static constexpr bool c4 = c_SpecifiesLdsTransfer; - static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder; - static constexpr bool c6 = c_SpecifiesSourceAccessOrder; - static constexpr bool c7 = c_SpecifiesGridwiseFwdXdlGemm; - static constexpr bool c8 = c_SpecifiesFwdConvSpecialization; - static constexpr bool c9 = c_SpecifiesGemmSpecialization; - static constexpr bool c10 = c_SpecifiesBlockGemm; - - static consteval bool is_valid() - { - return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10; - } - - static consteval auto message() -> std::string - { - return std::string("\n=== Forward XDL V3 Algorithm Diagnostic (closest match) ===\n" - "Concepts for FwdXdlV3 Algorithm:\n") + - DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + - DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + - DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + - DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + - DIAGNOSTIC_LINE(SpecifiesGridwiseFwdXdlGemm) + - DIAGNOSTIC_LINE(SpecifiesFwdConvSpecialization) + - DIAGNOSTIC_LINE(SpecifiesGemmSpecialization) + DIAGNOSTIC_LINE(SpecifiesBlockGemm); - } -}; - -template -struct FwdXdlAlgorithmBase -{ - CHECK_CONCEPT(T, ConvAlgorithmDescriptor) - CHECK_CONCEPT(T, SpecifiesThreadBlock) - CHECK_CONCEPT(T, SpecifiesBlockTransfer) - CHECK_CONCEPT(T, SpecifiesLdsTransfer) - CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder) - CHECK_CONCEPT(T, SpecifiesSourceAccessOrder) - CHECK_CONCEPT(T, SpecifiesGridwiseFwdXdlGemm) - CHECK_CONCEPT(T, SpecifiesFwdConvSpecialization) - CHECK_CONCEPT(T, SpecifiesGemmSpecialization) - CHECK_CONCEPT(T, SpecifiesNumPrefetchStages) - CHECK_CONCEPT(T, SpecifiesNumGroupsToMerge) - CHECK_CONCEPT(T, SpecifiesLoopScheduler) - - static constexpr bool c1 = 
c_ConvAlgorithmDescriptor; - static constexpr bool c2 = c_SpecifiesThreadBlock; - static constexpr bool c3 = c_SpecifiesBlockTransfer; - static constexpr bool c4 = c_SpecifiesLdsTransfer; - static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder; - static constexpr bool c6 = c_SpecifiesSourceAccessOrder; - static constexpr bool c7 = c_SpecifiesGridwiseFwdXdlGemm; - static constexpr bool c8 = c_SpecifiesFwdConvSpecialization; - static constexpr bool c9 = c_SpecifiesGemmSpecialization; - static constexpr bool c10 = c_SpecifiesNumPrefetchStages; - static constexpr bool c11 = c_SpecifiesNumGroupsToMerge; - static constexpr bool c12 = c_SpecifiesLoopScheduler; - - static consteval bool is_valid() - { - return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10 && c11 && c12; - } - - static consteval auto message() -> std::string - { - return DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + - DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + - DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + - DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + - DIAGNOSTIC_LINE(SpecifiesGridwiseFwdXdlGemm) + - DIAGNOSTIC_LINE(SpecifiesFwdConvSpecialization) + - DIAGNOSTIC_LINE(SpecifiesGemmSpecialization) + - DIAGNOSTIC_LINE(SpecifiesNumPrefetchStages) + - DIAGNOSTIC_LINE(SpecifiesNumGroupsToMerge) + DIAGNOSTIC_LINE(SpecifiesLoopScheduler); - } -}; - -template -struct FwdXdlAlgorithm : public FwdXdlAlgorithmBase -{ - CHECK_CONCEPT(T, SpecifiesGenericInstance) - - static constexpr bool c13 = c_SpecifiesGenericInstance; - - static consteval bool is_valid() { return c13 && FwdXdlAlgorithmBase::is_valid(); } - - static consteval auto message() -> std::string - { - return std::string("\n=== Forward XDL Algorithm Diagnostic (closest match) ===\n" - "Concepts for FwdXdl Algorithm:\n") + - FwdXdlAlgorithmBase::message() + DIAGNOSTIC_LINE(SpecifiesGenericInstance); - } -}; - -template -struct FwdWmmaAlgorithm -{ - 
CHECK_CONCEPT(T, ConvAlgorithmDescriptor) - CHECK_CONCEPT(T, SpecifiesThreadBlock) - CHECK_CONCEPT(T, SpecifiesBlockTransfer) - CHECK_CONCEPT(T, SpecifiesLdsTransfer) - CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder) - CHECK_CONCEPT(T, SpecifiesSourceAccessOrder) - CHECK_CONCEPT(T, SpecifiesGridwiseWmmaGemm) - CHECK_CONCEPT(T, SpecifiesFwdConvSpecialization) - CHECK_CONCEPT(T, SpecifiesGemmSpecialization) - CHECK_CONCEPT(T, SpecifiesNumPrefetchStages) - CHECK_CONCEPT(T, SpecifiesLoopScheduler) - CHECK_CONCEPT(T, SpecifiesGridwiseGemmPipeline) - - static constexpr bool c1 = c_ConvAlgorithmDescriptor; - static constexpr bool c2 = c_SpecifiesThreadBlock; - static constexpr bool c3 = c_SpecifiesBlockTransfer; - static constexpr bool c4 = c_SpecifiesLdsTransfer; - static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder; - static constexpr bool c6 = c_SpecifiesSourceAccessOrder; - static constexpr bool c7 = c_SpecifiesGridwiseWmmaGemm; - static constexpr bool c8 = c_SpecifiesFwdConvSpecialization; - static constexpr bool c9 = c_SpecifiesGemmSpecialization; - static constexpr bool c10 = c_SpecifiesNumPrefetchStages; - static constexpr bool c11 = c_SpecifiesLoopScheduler; - static constexpr bool c12 = c_SpecifiesGridwiseGemmPipeline; - - static consteval bool is_valid() - { - return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10 && c11 && c12; - } - - static consteval auto message() -> std::string - { - return std::string("\n=== Forward WMMA Algorithm Diagnostic (closest match) ===\n" - "Concepts for FwdWmma Algorithm:\n") + - DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + - DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + - DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + - DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + - DIAGNOSTIC_LINE(SpecifiesGridwiseWmmaGemm) + - DIAGNOSTIC_LINE(SpecifiesFwdConvSpecialization) + - DIAGNOSTIC_LINE(SpecifiesGemmSpecialization) + - 
DIAGNOSTIC_LINE(SpecifiesNumPrefetchStages) + - DIAGNOSTIC_LINE(SpecifiesLoopScheduler) + - DIAGNOSTIC_LINE(SpecifiesGridwiseGemmPipeline); - } -}; - -template -struct FwdDlAlgorithm -{ - CHECK_CONCEPT(T, ConvAlgorithmDescriptor) - CHECK_CONCEPT(T, SpecifiesThreadBlock) - CHECK_CONCEPT(T, SpecifiesFwdConvSpecialization) - CHECK_CONCEPT(T, SpecifiesGemmSpecialization) - CHECK_CONCEPT(T, SpecifiesDlThreadConfig) - CHECK_CONCEPT(T, SpecifiesDlThreadCluster) - CHECK_CONCEPT(T, SpecifiesDlFwdBlockTransfer) - CHECK_CONCEPT(T, SpecifiesDlEpilogue) - - static constexpr bool c1 = c_ConvAlgorithmDescriptor; - static constexpr bool c2 = c_SpecifiesThreadBlock; - static constexpr bool c3 = c_SpecifiesFwdConvSpecialization; - static constexpr bool c4 = c_SpecifiesGemmSpecialization; - static constexpr bool c5 = c_SpecifiesDlThreadConfig; - static constexpr bool c6 = c_SpecifiesDlThreadCluster; - static constexpr bool c7 = c_SpecifiesDlFwdBlockTransfer; - static constexpr bool c8 = c_SpecifiesDlEpilogue; - - static consteval bool is_valid() { return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8; } - - static consteval auto message() -> std::string - { - return std::string("\n=== Forward DL Algorithm Diagnostic (closest match) ===\n" - "Concepts for FwdDl Algorithm:\n") + - DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + - DIAGNOSTIC_LINE(SpecifiesFwdConvSpecialization) + - DIAGNOSTIC_LINE(SpecifiesGemmSpecialization) + - DIAGNOSTIC_LINE(SpecifiesDlThreadConfig) + - DIAGNOSTIC_LINE(SpecifiesDlThreadCluster) + - DIAGNOSTIC_LINE(SpecifiesDlFwdBlockTransfer) + DIAGNOSTIC_LINE(SpecifiesDlEpilogue); - } -}; - -template -struct TileAlgorithm -{ - CHECK_CONCEPT(T, ConvAlgorithmDescriptor) - CHECK_CONCEPT(T, SpecifiesTileThreadBlock) - CHECK_CONCEPT(T, SpecifiesTileTransfer) - CHECK_CONCEPT(T, SpecifiesTileConvSpecialization) - CHECK_CONCEPT(T, SpecifiesTileBlockGemm) - CHECK_CONCEPT(T, SpecifiesTileOptimizations) - - static constexpr bool c1 = 
c_ConvAlgorithmDescriptor; - static constexpr bool c2 = c_SpecifiesTileThreadBlock; - static constexpr bool c3 = c_SpecifiesTileTransfer; - static constexpr bool c4 = c_SpecifiesTileConvSpecialization; - static constexpr bool c5 = c_SpecifiesTileBlockGemm; - static constexpr bool c6 = c_SpecifiesTileOptimizations; - - static consteval bool is_valid() { return c1 && c2 && c3 && c4 && c5 && c6; } - - static consteval auto message() -> std::string - { - return std::string("\n=== CK Tile Algorithm Diagnostic (closest match) ===\n" - "Concepts for CK Tile Conv Algorithm:\n") + - DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + - DIAGNOSTIC_LINE(SpecifiesTileThreadBlock) + DIAGNOSTIC_LINE(SpecifiesTileTransfer) + - DIAGNOSTIC_LINE(SpecifiesTileConvSpecialization) + - DIAGNOSTIC_LINE(SpecifiesTileBlockGemm) + - DIAGNOSTIC_LINE(SpecifiesTileOptimizations); - } -}; - -template -struct LargeTensorAlgorithm : public FwdXdlAlgorithmBase -{ - CHECK_CONCEPT(T, SpecifiesLargeTensorSupport) - - static constexpr bool c13 = c_SpecifiesLargeTensorSupport; - - static consteval bool is_valid() - { - // Note: Check first if the specialization is set. 
- return c13 && FwdXdlAlgorithmBase::is_valid(); - } - - static consteval auto message() -> std::string - { - return std::string( - "\n=== Forward XDL Large Tensor Algorithm Diagnostic (closest match) ===\n" - "Concepts for FwdLargeTensorXdl Algorithm:\n") + - FwdXdlAlgorithmBase::message() + DIAGNOSTIC_LINE(SpecifiesLargeTensorSupport); - } -}; - -template -struct BwdXdlAlgorithmBase -{ - CHECK_CONCEPT(T, ConvAlgorithmDescriptor) - CHECK_CONCEPT(T, SpecifiesThreadBlock) - CHECK_CONCEPT(T, SpecifiesBlockTransfer4D) - CHECK_CONCEPT(T, SpecifiesLdsTransfer) - CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder) - CHECK_CONCEPT(T, SpecifiesSourceAccessOrder) - CHECK_CONCEPT(T, SpecifiesGridwiseBwdXdlGemm) - CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization) - - static constexpr bool c1 = c_ConvAlgorithmDescriptor; - static constexpr bool c2 = c_SpecifiesThreadBlock; - static constexpr bool c3 = c_SpecifiesBlockTransfer4D; - static constexpr bool c4 = c_SpecifiesLdsTransfer; - static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder; - static constexpr bool c6 = c_SpecifiesSourceAccessOrder; - static constexpr bool c7 = c_SpecifiesGridwiseBwdXdlGemm; - static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; - - static consteval bool is_valid() { return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8; } - - static consteval auto message() -> std::string - { - return DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + - DIAGNOSTIC_LINE(SpecifiesBlockTransfer4D) + DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + - DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + - DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + - DIAGNOSTIC_LINE(SpecifiesGridwiseBwdXdlGemm) + - DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization); - } -}; - -template -struct BwdXdlAlgorithm : public BwdXdlAlgorithmBase -{ - CHECK_CONCEPT(T, SpecifiesTransposeTransfer) - CHECK_CONCEPT(T, SpecifiesGenericInstance) - - static constexpr bool c9 = 
c_SpecifiesTransposeTransfer; - static constexpr bool c10 = c_SpecifiesGenericInstance; - - static consteval bool is_valid() { return c9 && c10 && BwdXdlAlgorithmBase::is_valid(); } - - static consteval auto message() -> std::string - { - return std::string("\n=== Backward XDL Algorithm Diagnostic (closest match) ===\n" - "Concepts for BwdXdl Algorithm:\n") + - BwdXdlAlgorithmBase::message() + DIAGNOSTIC_LINE(SpecifiesTransposeTransfer) + - DIAGNOSTIC_LINE(SpecifiesGenericInstance); - } -}; - -template -struct BwdMultiDXdlAlgorithm : public BwdXdlAlgorithmBase -{ - CHECK_CONCEPT(T, SpecifiesMultipleDSupport) - - static constexpr bool c9 = c_SpecifiesMultipleDSupport; - - static consteval bool is_valid() { return c9 && BwdXdlAlgorithmBase::is_valid(); } - - static consteval auto message() -> std::string - { - return std::string("\n=== Backward XDL Algorithm Diagnostic (closest match) ===\n" - "Concepts for BwdXdl Algorithm:\n") + - BwdXdlAlgorithmBase::message() + DIAGNOSTIC_LINE(SpecifiesMultipleDSupport); - } -}; - -template -struct BwdXdlV3AlgorithmBase -{ - CHECK_CONCEPT(T, ConvAlgorithmDescriptor) - CHECK_CONCEPT(T, SpecifiesThreadBlock) - CHECK_CONCEPT(T, SpecifiesBlockTransfer) - CHECK_CONCEPT(T, SpecifiesLdsTransfer) - CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder) - CHECK_CONCEPT(T, SpecifiesSourceAccessOrder) - CHECK_CONCEPT(T, SpecifiesGridwiseBwdXdlGemm) - CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization) - CHECK_CONCEPT(T, SpecifiesBlockGemm) - - static constexpr bool c1 = c_ConvAlgorithmDescriptor; - static constexpr bool c2 = c_SpecifiesThreadBlock; - static constexpr bool c3 = c_SpecifiesBlockTransfer; - static constexpr bool c4 = c_SpecifiesLdsTransfer; - static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder; - static constexpr bool c6 = c_SpecifiesSourceAccessOrder; - static constexpr bool c7 = c_SpecifiesGridwiseBwdXdlGemm; - static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; - static constexpr bool c9 = 
c_SpecifiesBlockGemm; - - static consteval bool is_valid() { return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9; } - - static consteval auto message() -> std::string - { - return DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + - DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + - DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + - DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + - DIAGNOSTIC_LINE(SpecifiesGridwiseBwdXdlGemm) + - DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) + - DIAGNOSTIC_LINE(SpecifiesBlockGemm); - } -}; - -template -struct BwdXdlV3Algorithm : public BwdXdlV3AlgorithmBase -{ - CHECK_CONCEPT(T, SpecifiesGenericInstance) - - static constexpr bool c10 = c_SpecifiesGenericInstance; - - static consteval bool is_valid() { return c10 && BwdXdlV3AlgorithmBase::is_valid(); } - - static consteval auto message() -> std::string - { - return std::string("\n=== Backward XDL V3 Algorithm Diagnostic (closest match) ===\n" - "Concepts for BwdXdlV3 Algorithm:\n") + - BwdXdlV3AlgorithmBase::message() + DIAGNOSTIC_LINE(SpecifiesGenericInstance); - } -}; - +// Base algorithm concepts template -struct BwdTwoStageXdlAlgorithm : public BwdXdlV3AlgorithmBase -{ - CHECK_CONCEPT(T, SpecifiesTransposeTransfer) - CHECK_CONCEPT(T, SpecifiesGemmBatchOptions) - CHECK_CONCEPT(T, SpecifiesTwoStageSupport) - - static constexpr bool c10 = c_SpecifiesTransposeTransfer; - static constexpr bool c11 = c_SpecifiesGemmBatchOptions; - static constexpr bool c12 = c_SpecifiesTwoStageSupport; - - static consteval bool is_valid() - { - return c10 && c11 && c12 && BwdXdlV3AlgorithmBase::is_valid(); - } - - static consteval auto message() -> std::string - { - return std::string("\n=== Backward two stage XDL Algorithm Diagnostic (closest match) ===\n" - "Concepts for BwdXdlV3 Algorithm:\n") + - BwdXdlV3AlgorithmBase::message() + DIAGNOSTIC_LINE(SpecifiesTransposeTransfer) + - 
DIAGNOSTIC_LINE(SpecifiesGemmBatchOptions) + - DIAGNOSTIC_LINE(SpecifiesTwoStageSupport); - } -}; - -template -struct BwdWmmaAlgorithmBase -{ - CHECK_CONCEPT(T, ConvAlgorithmDescriptor) - CHECK_CONCEPT(T, SpecifiesThreadBlock) - CHECK_CONCEPT(T, SpecifiesBlockTransfer) - CHECK_CONCEPT(T, SpecifiesLdsTransfer) - CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder) - CHECK_CONCEPT(T, SpecifiesSourceAccessOrder) - CHECK_CONCEPT(T, SpecifiesGridwiseWmmaGemm) - CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization) - - static constexpr bool c1 = c_ConvAlgorithmDescriptor; - static constexpr bool c2 = c_SpecifiesThreadBlock; - static constexpr bool c3 = c_SpecifiesBlockTransfer; - static constexpr bool c4 = c_SpecifiesLdsTransfer; - static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder; - static constexpr bool c6 = c_SpecifiesSourceAccessOrder; - static constexpr bool c7 = c_SpecifiesGridwiseWmmaGemm; - static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; - - static consteval bool is_valid() { return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8; } - - static consteval auto message() -> std::string - { - return DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + - DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + - DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + - DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + - DIAGNOSTIC_LINE(SpecifiesGridwiseWmmaGemm) + - DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization); - } -}; - -template -struct BwdWmmaAlgorithm : public BwdWmmaAlgorithmBase -{ - CHECK_CONCEPT(T, SpecifiesNumPrefetchStages) - CHECK_CONCEPT(T, SpecifiesLoopScheduler) - CHECK_CONCEPT(T, SpecifiesGridwiseGemmPipeline) - CHECK_CONCEPT(T, SpecifiesGenericInstance) - - static constexpr bool c9 = c_SpecifiesNumPrefetchStages; - static constexpr bool c10 = c_SpecifiesLoopScheduler; - static constexpr bool c11 = c_SpecifiesGridwiseGemmPipeline; - static constexpr bool c12 = 
c_SpecifiesGenericInstance; - - static consteval bool is_valid() - { - return c9 && c10 && c11 && c12 && BwdWmmaAlgorithmBase::is_valid(); - } - - static consteval auto message() -> std::string - { - return std::string("\n=== Backward WMMA Algorithm Diagnostic (closest match) ===\n" - "Concepts for BwdWmma Algorithm:\n") + - BwdWmmaAlgorithmBase::message() + DIAGNOSTIC_LINE(SpecifiesNumPrefetchStages) + - DIAGNOSTIC_LINE(SpecifiesLoopScheduler) + - DIAGNOSTIC_LINE(SpecifiesGridwiseGemmPipeline) + - DIAGNOSTIC_LINE(SpecifiesGenericInstance); - } -}; +concept FwdXdlAlgorithmBase = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesBlockTransfer && + SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && + SpecifiesSourceAccessOrder && SpecifiesGridwiseFwdXdlGemm && + SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && + SpecifiesNumPrefetchStages && SpecifiesNumGroupsToMerge && SpecifiesLoopScheduler; template -struct BwdWmmaV3AlgorithmBase -{ - CHECK_CONCEPT(T, ConvAlgorithmDescriptor) - CHECK_CONCEPT(T, SpecifiesThreadBlock) - CHECK_CONCEPT(T, SpecifiesBlockTransfer) - CHECK_CONCEPT(T, SpecifiesLdsTransfer) - CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder) - CHECK_CONCEPT(T, SpecifiesSourceAccessOrder) - CHECK_CONCEPT(T, SpecifiesGridwiseWmmaGemm) - CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization) - CHECK_CONCEPT(T, SpecifiesBlockGemm) - - static constexpr bool c1 = c_ConvAlgorithmDescriptor; - static constexpr bool c2 = c_SpecifiesThreadBlock; - static constexpr bool c3 = c_SpecifiesBlockTransfer; - static constexpr bool c4 = c_SpecifiesLdsTransfer; - static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder; - static constexpr bool c6 = c_SpecifiesSourceAccessOrder; - static constexpr bool c7 = c_SpecifiesGridwiseWmmaGemm; - static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; - static constexpr bool c9 = c_SpecifiesBlockGemm; - - static consteval bool is_valid() { return c1 && c2 && c3 && c4 && c5 && c6 
&& c7 && c8 && c9; } - - static consteval auto message() -> std::string - { - return DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + - DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + - DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + - DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + - DIAGNOSTIC_LINE(SpecifiesGridwiseWmmaGemm) + - DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) + - DIAGNOSTIC_LINE(SpecifiesBlockGemm); - } -}; +concept BwdXdlAlgorithmBase = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesBlockTransfer4D && + SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && + SpecifiesSourceAccessOrder && SpecifiesGridwiseBwdXdlGemm && + SpecifiesBwdWeightConvSpecialization; template -struct BwdMultiDWmmaV3Algorithm : public BwdWmmaV3AlgorithmBase -{ - CHECK_CONCEPT(T, SpecifiesMultipleDSupport) - - static constexpr bool c10 = c_SpecifiesMultipleDSupport; - - static consteval bool is_valid() { return c10 && BwdWmmaAlgorithmBase::is_valid(); } - - static consteval auto message() -> std::string - { - return std::string("\n=== Backward WMMA Algorithm Diagnostic (closest match) ===\n" - "Concepts for BwdMultiDWmma Algorithm:\n") + - BwdWmmaAlgorithmBase::message() + DIAGNOSTIC_LINE(SpecifiesMultipleDSupport); - } -}; - +concept BwdXdlV3AlgorithmBase = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesBlockTransfer && + SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && + SpecifiesSourceAccessOrder && SpecifiesGridwiseBwdXdlGemm && + SpecifiesBwdWeightConvSpecialization && SpecifiesBlockGemm; template -struct BwdWmmaV3Algorithm : public BwdWmmaV3AlgorithmBase -{ - CHECK_CONCEPT(T, SpecifiesTransposeTransfer) - CHECK_CONCEPT(T, SpecifiesGenericInstance) - - static constexpr bool c10 = c_SpecifiesTransposeTransfer; - static constexpr bool c11 = c_SpecifiesGenericInstance; - - static consteval bool is_valid() { return c10 && c11 && 
BwdWmmaV3AlgorithmBase::is_valid(); } - - static consteval auto message() -> std::string - { - return std::string("\n=== Backward WMMA V3 Algorithm Diagnostic (closest match) ===\n" - "Concepts for BwdWmmaV3 Algorithm:\n") + - BwdWmmaV3AlgorithmBase::message() + DIAGNOSTIC_LINE(SpecifiesTransposeTransfer) + - DIAGNOSTIC_LINE(SpecifiesGenericInstance); - } -}; +concept BwdWmmaAlgorithmBase = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesBlockTransfer && + SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && + SpecifiesSourceAccessOrder && SpecifiesGridwiseWmmaGemm && + SpecifiesBwdWeightConvSpecialization; template -struct BwdTwoStageWmmaV3Algorithm : public BwdWmmaV3AlgorithmBase -{ - CHECK_CONCEPT(T, SpecifiesTransposeTransfer) - CHECK_CONCEPT(T, SpecifiesTwoStageSupport) - CHECK_CONCEPT(T, SpecifiesGemmBatchOptions) - - static constexpr bool c10 = c_SpecifiesTransposeTransfer; - static constexpr bool c11 = c_SpecifiesTwoStageSupport; - static constexpr bool c12 = c_SpecifiesGemmBatchOptions; - - static consteval bool is_valid() - { - return c10 && c11 && c12 && BwdWmmaV3AlgorithmBase::is_valid(); - } - - static consteval auto message() -> std::string - { - return std::string( - "\n=== Backward Two Stage WMMA V3 Algorithm Diagnostic (closest match) ===\n" - "Concepts for BwdTwoStageWmmaV3 Algorithm:\n") + - BwdWmmaV3AlgorithmBase::message() + DIAGNOSTIC_LINE(SpecifiesTransposeTransfer) + - DIAGNOSTIC_LINE(SpecifiesGemmBatchOptions) + - DIAGNOSTIC_LINE(SpecifiesTwoStageSupport); - } -}; +concept BwdWmmaV3AlgorithmBase = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesBlockTransfer && + SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && + SpecifiesSourceAccessOrder && SpecifiesGridwiseWmmaGemm && + SpecifiesBwdWeightConvSpecialization && SpecifiesBlockGemm; +// Reference algorithm concept template -struct BwdDlAlgorithm -{ - CHECK_CONCEPT(T, ConvAlgorithmDescriptor) - CHECK_CONCEPT(T, SpecifiesThreadBlock) - 
CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization) - CHECK_CONCEPT(T, SpecifiesDlThreadConfig) - CHECK_CONCEPT(T, SpecifiesDlThreadCluster) - CHECK_CONCEPT(T, SpecifiesDlBwdBlockTransfer) - CHECK_CONCEPT(T, SpecifiesDlEpilogue) - - static constexpr bool c1 = c_ConvAlgorithmDescriptor; - static constexpr bool c2 = c_SpecifiesThreadBlock; - static constexpr bool c3 = c_SpecifiesBwdWeightConvSpecialization; - static constexpr bool c4 = c_SpecifiesDlThreadConfig; - static constexpr bool c5 = c_SpecifiesDlThreadCluster; - static constexpr bool c6 = c_SpecifiesDlBwdBlockTransfer; - static constexpr bool c7 = c_SpecifiesDlEpilogue; - - static consteval bool is_valid() { return c1 && c2 && c3 && c4 && c5 && c6 && c7; } - - static consteval auto message() -> std::string - { - return std::string("\n=== Backward DL Algorithm Diagnostic (closest match) ===\n" - "Concepts for BwdDl Algorithm:\n") + - DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + - DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) + - DIAGNOSTIC_LINE(SpecifiesDlThreadConfig) + - DIAGNOSTIC_LINE(SpecifiesDlThreadCluster) + - DIAGNOSTIC_LINE(SpecifiesDlBwdBlockTransfer) + DIAGNOSTIC_LINE(SpecifiesDlEpilogue); - } -}; +concept ReferenceAlgorithm = ConvAlgorithmDescriptor && SpecifiesReferenceAlgorithm; +// Tile-based algorithm concept template -consteval int count_matches_fwd_xdl_v3() -{ - using Alg = FwdXdlV3Algorithm; - return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + - Alg::c10; -} +concept TileAlgorithm = ConvAlgorithmDescriptor && SpecifiesTileThreadBlock && + SpecifiesTileTransfer && SpecifiesTileConvSpecialization && + SpecifiesTileBlockGemm && SpecifiesTileOptimizations; +// FWD XDL algorithm concepts template -consteval int count_matches_fwd_xdl() -{ - using Alg = FwdXdlAlgorithm; - return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + - Alg::c10 + Alg::c11 + Alg::c12 + 
Alg::c13; -} +concept FwdXdlAlgorithm = FwdXdlAlgorithmBase && SpecifiesGenericInstance; template -consteval int count_matches_fwd_wmma() -{ - using Alg = FwdWmmaAlgorithm; - return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + - Alg::c10 + Alg::c11; -} +concept LargeTensorAlgorithm = FwdXdlAlgorithmBase && SpecifiesLargeTensorSupport; template -consteval int count_matches_fwd_dl() -{ - using Alg = FwdDlAlgorithm; - return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8; -} +concept FwdXdlV3Algorithm = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesBlockTransfer && + SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && + SpecifiesSourceAccessOrder && SpecifiesGridwiseFwdXdlGemm && + SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && SpecifiesBlockGemm; +// FWD WMMA algorithm concepts template -consteval int count_matches_bwd_xdl() -{ - using Alg = BwdXdlAlgorithm; - return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9; -} +concept FwdWmmaAlgorithm = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesBlockTransfer && + SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && + SpecifiesSourceAccessOrder && SpecifiesGridwiseWmmaGemm && + SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && + SpecifiesNumPrefetchStages && SpecifiesLoopScheduler && SpecifiesGridwiseGemmPipeline; +// FWD DL algorithms template -consteval int count_matches_bwd_multi_d_xdl() -{ - using Alg = BwdMultiDXdlAlgorithm; - return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9; -} +concept FwdDlAlgorithm = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesFwdConvSpecialization && + SpecifiesGemmSpecialization && SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && + SpecifiesDlFwdBlockTransfer && SpecifiesDlEpilogue; +// BWD weight XDL algorithm concepts template 
-consteval int count_matches_bwd_xdl_v3() -{ - using Alg = BwdXdlV3Algorithm; - return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9; -} +concept BwdXdlAlgorithm = + BwdXdlAlgorithmBase && SpecifiesTransposeTransfer && SpecifiesGenericInstance; template -consteval int count_matches_bwd_two_stage_xdl() -{ - using Alg = BwdTwoStageXdlAlgorithm; - return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + - Alg::c10 + Alg::c11 + Alg::c12; -} +concept BwdMultiDXdlAlgorithm = BwdXdlAlgorithmBase && SpecifiesMultipleDSupport; template -consteval int count_matches_bwd_wmma() -{ - using Alg = BwdWmmaAlgorithm; - return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + - Alg::c10 + Alg::c11 + Alg::c12; -} +concept BwdXdlV3Algorithm = BwdXdlV3AlgorithmBase && SpecifiesGenericInstance; template -consteval int count_matches_bwd_multi_d_wmma() -{ - using Alg = BwdMultiDWmmaV3Algorithm; - return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + - Alg::c10 + Alg::c11 + Alg::c12; -} +concept BwdTwoStageXdlAlgorithm = BwdXdlV3AlgorithmBase && SpecifiesTransposeTransfer && + SpecifiesGemmBatchOptions && SpecifiesTwoStageSupport; +// BWD weight WMMA algorithm concepts template -consteval int count_matches_bwd_wmma_v3() -{ - using Alg = BwdWmmaV3Algorithm; - return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + - Alg::c10 + Alg::c11; -} +concept BwdWmmaAlgorithm = + BwdWmmaAlgorithmBase && SpecifiesNumPrefetchStages && SpecifiesLoopScheduler && + SpecifiesGridwiseGemmPipeline && SpecifiesGenericInstance; template -consteval int count_matches_bwd_two_stage_wmma_v3() -{ - using Alg = BwdTwoStageWmmaV3Algorithm; - return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + - Alg::c10 + Alg::c11 + Alg::c12; -} +concept BwdMultiDWmmaV3Algorithm = 
BwdWmmaV3AlgorithmBase && SpecifiesMultipleDSupport; template -consteval int count_matches_bwd_dl() -{ - using Alg = BwdDlAlgorithm; - return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7; -} +concept BwdWmmaV3Algorithm = + BwdWmmaV3AlgorithmBase && SpecifiesTransposeTransfer && SpecifiesGenericInstance; template -consteval int count_matches_large_tensor() -{ - using Alg = LargeTensorAlgorithm; - return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + - Alg::c10 + Alg::c11 + Alg::c12 + Alg::c13; -} +concept BwdTwoStageWmmaV3Algorithm = BwdWmmaV3AlgorithmBase && SpecifiesTransposeTransfer && + SpecifiesGemmBatchOptions && SpecifiesTwoStageSupport; +// BWD weigth DL algorithms template -consteval int count_matches_tile() -{ - using Alg = TileAlgorithm; - return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6; -} - -template -consteval void diagnose_fwd_algorithm_signature() -{ - // Find closest matching variant - constexpr int xdl_v3_matches = count_matches_fwd_xdl_v3(); - constexpr int xdl_matches = count_matches_fwd_xdl(); - constexpr int wmma_matches = count_matches_fwd_wmma(); - constexpr int dl_matches = count_matches_fwd_dl(); - constexpr int large_tensor_matches = count_matches_large_tensor(); - constexpr int tile_matches = count_matches_tile(); - - // Check whether we have XDL or WMMA algorithm - if constexpr(SpecifiesGridwiseFwdXdlGemm) - { - constexpr int max_1 = xdl_v3_matches > xdl_matches ? xdl_v3_matches : xdl_matches; - constexpr int max_2 = max_1 > dl_matches ? max_1 : dl_matches; - constexpr int max_matches = large_tensor_matches > max_2 ? 
large_tensor_matches : max_2; - - if constexpr(max_matches == xdl_v3_matches) - { - using Alg = FwdXdlV3Algorithm; - static_assert(Alg::is_valid(), Alg::message()); - } - else if constexpr(max_matches == xdl_matches) - { - using Alg = FwdXdlAlgorithm; - static_assert(Alg::is_valid(), Alg::message()); - } - else if constexpr(max_matches == dl_matches) - { - using Alg = FwdDlAlgorithm; - static_assert(Alg::is_valid(), Alg::message()); - } - else if constexpr(max_matches == large_tensor_matches) - { - using Alg = LargeTensorAlgorithm; - static_assert(Alg::is_valid(), Alg::message()); - } - } - else if constexpr(SpecifiesGridwiseWmmaGemm) - { - using Alg = FwdWmmaAlgorithm; - static_assert(Alg::is_valid(), Alg::message()); - } - else - { - // Find maximum matches across all variants - constexpr int max_1 = xdl_v3_matches > xdl_matches ? xdl_v3_matches : xdl_matches; - constexpr int max_2 = wmma_matches > dl_matches ? wmma_matches : dl_matches; - constexpr int max_3 = max_1 > max_2 ? max_1 : max_2; - constexpr int max_4 = max_3 > large_tensor_matches ? max_3 : large_tensor_matches; - constexpr int max_matches = max_4 > tile_matches ? max_4 : tile_matches; - - // If we cannot match with neither WMMA nor XDL, try all algorithms for diagnostics - // and see whichi is the closest match. 
- if constexpr(max_matches == xdl_v3_matches) - { - using Alg = FwdXdlV3Algorithm; - static_assert(Alg::is_valid(), Alg::message()); - } - else if constexpr(max_matches == xdl_matches) - { - using Alg = FwdXdlAlgorithm; - static_assert(Alg::is_valid(), Alg::message()); - } - else if constexpr(max_matches == wmma_matches) - { - using Alg = FwdWmmaAlgorithm; - static_assert(Alg::is_valid(), Alg::message()); - } - else if constexpr(max_matches == dl_matches) - { - using Alg = FwdDlAlgorithm; - static_assert(Alg::is_valid(), Alg::message()); - } - else if constexpr(max_matches == large_tensor_matches) - { - using Alg = LargeTensorAlgorithm; - static_assert(Alg::is_valid(), Alg::message()); - } - else if constexpr(max_matches == tile_matches) - { - using Alg = TileAlgorithm; - static_assert(Alg::is_valid(), Alg::message()); - } - else - { - // This should never happen - static_assert(false, - "Internal Error: No matching algorithm variant found for diagnostics."); - } - } -} - -template -consteval void diagnose_bwd_weight_algorithm_signature() -{ - constexpr int xdl_matches = count_matches_bwd_xdl(); - constexpr int xdl_v3_matches = count_matches_fwd_xdl_v3(); - constexpr int two_stage_xdl_matches = count_matches_bwd_two_stage_xdl(); - constexpr int dl_matches = count_matches_bwd_dl(); - constexpr int multi_d_xdl_matches = count_matches_bwd_multi_d_xdl(); - constexpr int wmma_v3_matches = count_matches_bwd_wmma_v3(); - constexpr int two_stage_wmma_v3_matches = count_matches_bwd_two_stage_wmma_v3(); - constexpr int wmma_matches = count_matches_bwd_wmma(); - constexpr int multi_d_wmma_matches = count_matches_bwd_multi_d_wmma(); - - // Check whether we have XDL or WMMA algorithm - if constexpr(SpecifiesGridwiseBwdXdlGemm) - { - constexpr int max1 = xdl_v3_matches > xdl_matches ? xdl_v3_matches : xdl_matches; - constexpr int max2 = max1 > two_stage_xdl_matches ? max1 : two_stage_xdl_matches; - constexpr int max3 = max2 > dl_matches ? 
max2 : dl_matches; - constexpr int max_matches = max3 > multi_d_xdl_matches ? max3 : multi_d_xdl_matches; - - if constexpr(max_matches == xdl_matches) - { - using Alg = BwdXdlAlgorithm; - static_assert(Alg::is_valid(), Alg::message()); - } - else if constexpr(max_matches == xdl_v3_matches) - { - using Alg = BwdXdlV3Algorithm; - static_assert(Alg::is_valid(), Alg::message()); - } - else if constexpr(max_matches == two_stage_xdl_matches) - { - using Alg = BwdTwoStageXdlAlgorithm; - static_assert(Alg::is_valid(), Alg::message()); - } - else if constexpr(max_matches == dl_matches) - { - using Alg = BwdDlAlgorithm; - static_assert(Alg::is_valid(), Alg::message()); - } - else if constexpr(max_matches == multi_d_xdl_matches) - { - using Alg = BwdMultiDXdlAlgorithm; - static_assert(Alg::is_valid(), Alg::message()); - } - } - else if constexpr(SpecifiesGridwiseWmmaGemm) - { - constexpr int max_1 = wmma_v3_matches > two_stage_wmma_v3_matches - ? wmma_v3_matches - : two_stage_wmma_v3_matches; - constexpr int max_2 = max_1 > wmma_matches ? max_1 : wmma_matches; - constexpr int max_matches = multi_d_wmma_matches > max_2 ? multi_d_wmma_matches : max_2; - - if constexpr(max_matches == wmma_v3_matches) - { - using Alg = BwdWmmaV3Algorithm; - static_assert(Alg::is_valid(), Alg::message()); - } - else if constexpr(max_matches == two_stage_wmma_v3_matches) - { - using Alg = BwdTwoStageWmmaV3Algorithm; - static_assert(Alg::is_valid(), Alg::message()); - } - else if constexpr(max_matches == wmma_matches) - { - using Alg = BwdWmmaAlgorithm; - static_assert(Alg::is_valid(), Alg::message()); - } - else if constexpr(max_matches == multi_d_wmma_matches) - { - using Alg = BwdMultiDWmmaV3Algorithm; - static_assert(Alg::is_valid(), Alg::message()); - } - } - else - { - // If we cannot match with neither WMMA nor XDL, try all algorithms for diagnostics - // and see which is the closest match. - constexpr int max1 = xdl_v3_matches > xdl_matches ? 
xdl_v3_matches : xdl_matches; - constexpr int max2 = max1 > two_stage_xdl_matches ? max1 : two_stage_xdl_matches; - constexpr int max3 = max2 > dl_matches ? max2 : dl_matches; - constexpr int max_matches = max3 > multi_d_xdl_matches ? max3 : multi_d_xdl_matches; - - if constexpr(max_matches == xdl_matches) - { - using Alg = BwdXdlAlgorithm; - static_assert(Alg::is_valid(), Alg::message()); - } - else if constexpr(max_matches == xdl_v3_matches) - { - using Alg = BwdXdlV3Algorithm; - static_assert(Alg::is_valid(), Alg::message()); - } - else if constexpr(max_matches == two_stage_xdl_matches) - { - using Alg = BwdTwoStageXdlAlgorithm; - static_assert(Alg::is_valid(), Alg::message()); - } - else if constexpr(max_matches == dl_matches) - { - using Alg = BwdDlAlgorithm; - static_assert(Alg::is_valid(), Alg::message()); - } - else if constexpr(max_matches == multi_d_xdl_matches) - { - using Alg = BwdMultiDXdlAlgorithm; - static_assert(Alg::is_valid(), Alg::message()); - } - else - { - // This should never happen - static_assert(false, - "Internal Error: No matching algorithm variant found for diagnostics."); - } - } -} +concept BwdDlAlgorithm = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && + SpecifiesBwdWeightConvSpecialization && SpecifiesDlThreadConfig && + SpecifiesDlThreadCluster && SpecifiesDlBwdBlockTransfer && SpecifiesDlEpilogue; } // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 892ad832d8b..319293cff14 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -107,41 +107,45 @@ constexpr auto make_conv_instance() using AlgoType = std::remove_const_t; // Reference algorithm supports all directions - if constexpr(ReferenceAlgorithm::is_valid()) + if constexpr(ReferenceAlgorithm) { return 
typename ReferenceFactory::Instance{}; } // CK Tile supports common factory for each direction - if constexpr(TileAlgorithm::is_valid()) + if constexpr(TileAlgorithm) { return typename ConvTileFactory::Instance{}; } // Forward direction (supports most algorithm variants) else if constexpr(ConvDirectionIsForward) { - if constexpr(FwdXdlV3Algorithm::is_valid()) + if constexpr(FwdXdlV3Algorithm) { return typename ConvFwdXdlV3Factory::Instance{}; } - else if constexpr(FwdXdlAlgorithm::is_valid()) + else if constexpr(FwdXdlAlgorithm) { return typename ConvFwdXdlFactory::Instance{}; } - else if constexpr(FwdWmmaAlgorithm::is_valid()) + else if constexpr(FwdWmmaAlgorithm) { return typename ConvFwdWmmaFactory::Instance{}; } - else if constexpr(FwdDlAlgorithm::is_valid()) + else if constexpr(FwdDlAlgorithm) { return typename ConvFwdDlFactory::Instance{}; } - else if constexpr(LargeTensorAlgorithm::is_valid()) + else if constexpr(LargeTensorAlgorithm) { return typename ConvFwdLargeTensorFactory::Instance{}; } else { - diagnose_fwd_algorithm_signature(); + static_assert( + false, + "No suitable forward convolution kernel factory found for the provided ALGORITHM. 
" + "The ALGORITHM must satisfy requirements for one of: Reference, Tile, XDL V3, XDL, " + "WMMA, DL (NHWC layout), or Large Tensor variant."); } } // Backward data direction (will expand with more algorithms in the future) @@ -155,49 +159,54 @@ constexpr auto make_conv_instance() // Backward weight direction (will expand with more algorithms in the future) else if constexpr(ConvDirectionIsBackwardWeight) { - if constexpr(BwdXdlAlgorithm::is_valid()) + if constexpr(BwdXdlAlgorithm) { return typename ConvBwdWeightXdlFactory::Instance{}; } - else if constexpr(BwdXdlV3Algorithm::is_valid()) + else if constexpr(BwdXdlV3Algorithm) { return typename ConvBwdWeightXdlV3Factory::Instance{}; } - else if constexpr(BwdTwoStageXdlAlgorithm::is_valid()) + else if constexpr(BwdTwoStageXdlAlgorithm) { return typename ConvBwdWeightTwoStageXdlFactory::Instance{}; } - else if constexpr(BwdDlAlgorithm::is_valid()) + else if constexpr(BwdDlAlgorithm) { return typename ConvBwdWeightDlFactory::Instance{}; } - else if constexpr(BwdMultiDXdlAlgorithm::is_valid()) + else if constexpr(BwdMultiDXdlAlgorithm) { return typename ConvBwdWeightMultiDXdlFactory::Instance{}; } - else if constexpr(BwdWmmaV3Algorithm::is_valid()) + else if constexpr(BwdWmmaV3Algorithm) { return typename ConvBwdWeightWmmaV3Factory::Instance{}; } - else if constexpr(BwdTwoStageWmmaV3Algorithm::is_valid()) + else if constexpr(BwdTwoStageWmmaV3Algorithm) { return typename ConvBwdWeightTwoStageWmmaV3Factory:: Instance{}; } - else if constexpr(BwdWmmaAlgorithm::is_valid()) + else if constexpr(BwdWmmaAlgorithm) { return typename ConvBwdWeightWmmaFactory::Instance{}; } - else if constexpr(BwdMultiDWmmaV3Algorithm::is_valid()) + else if constexpr(BwdMultiDWmmaV3Algorithm) { return typename ConvBwdWeightMultiDWmmaV3Factory:: Instance{}; } else { - diagnose_bwd_weight_algorithm_signature(); + static_assert( + false, + "No suitable backward weight convolution kernel factory found for the provided " + "ALGORITHM. 
The ALGORITHM must satisfy requirements for one of: Reference, Tile, " + "XDL, XDL V3, Two-Stage XDL, DL, Multi-D XDL, WMMA V3, Two-Stage " + "WMMA V3, WMMA, or Multi-D WMMA V3 variant."); } } else From 5f639559a109a3842c8e35b20aa74e1a72e649f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Mon, 5 Jan 2026 09:52:46 -0500 Subject: [PATCH 51/81] WIP: Unify warp GEMM and thread distribution descriptions. --- .../builder/conv_algorithm_concepts.hpp | 104 +++----- .../builder/factory/conv_algorithms.hpp | 48 ++-- .../builder/include/ck_tile/builder/types.hpp | 6 + ...nv_bwd_weight_multi_d_wmma_cshuffle_v3.cpp | 2 +- ..._bwd_weight_two_stage_wmma_cshuffle_v3.cpp | 4 +- ...conv_bwd_weight_two_stage_xdl_cshuffle.cpp | 4 +- ...test_ckb_conv_bwd_weight_wmma_cshuffle.cpp | 2 +- ...t_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp | 2 +- ...st_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp | 2 +- .../conv/ck/test_ckb_conv_fwd_1d_bf16.cpp | 2 +- .../test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp | 2 +- .../conv/ck/test_ckb_conv_fwd_2d_bf16.cpp | 4 +- .../conv/ck/test_ckb_conv_fwd_2d_fp16.cpp | 2 +- .../conv/ck/test_ckb_conv_fwd_2d_fp32.cpp | 2 +- .../conv/ck/test_ckb_conv_fwd_3d_bf16.cpp | 2 +- .../conv/ck/test_ckb_conv_fwd_3d_fp16.cpp | 2 +- .../conv/ck/test_ckb_conv_fwd_3d_fp32.cpp | 2 +- .../test/impl/conv_algorithm_types.hpp | 248 +++++++----------- .../test/utils/ckb_conv_test_configs.hpp | 182 +++++++------ .../test/utils/conv_algorithm_type_utils.hpp | 131 ++++----- 20 files changed, 327 insertions(+), 426 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 2b0f63296be..cbc277a8814 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -30,16 +30,17 @@ concept ThreadBlockDescriptor = requires(T t) { // Concept for parameters that 
describe a gridwise XDL GEMM problem. template -concept GridwiseXdlGemmDescriptor = requires(T t) { - { t.m_per_xdl } -> SizeType; - { t.n_per_xdl } -> SizeType; - { t.m_xdl_per_wave } -> SizeType; - { t.n_xdl_per_wave } -> SizeType; +concept WarpGemmDescriptor = requires(T t) { + { t.matrix_instruction } -> std::convertible_to; + { t.gemm_m_per_instruction } -> SizeType; + { t.gemm_n_per_instruction } -> SizeType; + { t.gemm_m_iters_per_wave } -> SizeType; + { t.gemm_n_iters_per_wave } -> SizeType; }; // Concept for parameter that describe block GEMM problem. template -concept BlockGemmPipelineDescriptor = requires(T t) { +concept GemmPipelineDescriptor = requires(T t) { { t.pipeline_version } -> std::convertible_to; { t.scheduler } -> std::convertible_to; }; @@ -56,14 +57,14 @@ concept GridwiseWmmaGemmDescriptor = requires(T t) { // Concept for vectorized data transfer for convolution input tensors. template -concept BlockTransferDescriptor = requires(T t) { +concept InputTileThreadDistributionDescriptor3D = requires(T t) { { t.k0 } -> SizeType; { t.m_n } -> SizeType; { t.k1 } -> SizeType; }; template -concept BlockTransferDescriptor4D = requires(T t) { +concept InputTileThreadDistributionDescriptor4D = requires(T t) { { t.k0 } -> SizeType; { t.m_n } -> SizeType; { t.k1 } -> SizeType; @@ -72,16 +73,17 @@ concept BlockTransferDescriptor4D = requires(T t) { // Concept for thread cluster dimensions for GEMM output tensor. template -concept ThreadClusterDescriptor = requires(T t) { - { t.m_block } -> SizeType; - { t.m_wave_per_xdl } -> SizeType; - { t.n_block } -> SizeType; - { t.n_wave_per_xdl } -> SizeType; +concept OutputTileThreadDistributionDescriptor = requires(T t) { + { t.gemm_m_block_size } -> SizeType; + { t.gemm_m_per_block } -> SizeType; + { t.gemm_n_block_size } -> SizeType; + { t.gemm_n_per_block } -> SizeType; }; // Concept for the LDS transfer for the convolution input tensors. 
template -concept LdsTransferDescriptor = requires(T t) { +concept LdsInputTransferDescriptor = requires(T t) { + { t.global_memory_vector_load_size } -> SizeType; { t.src_vector_dim } -> SizeType; { t.src_scalar_per_vector } -> SizeType; { t.lds_dst_scalar_per_vector } -> SizeType; @@ -168,54 +170,27 @@ concept SpecifiesTileThreadBlock = requires { { T::thread_block } -> TileThreadBlockDescriptor; }; -// Concept to check if a struct specifies gridwise XDL GEMM info. +// Concept to check if a struct specifies warp GEMM info. template -concept GridwiseFwdXdlGemmDescriptor = requires(T t) { - { t.ak1 } -> SizeType; - { t.bk1 } -> SizeType; - { t.xdl_params } -> GridwiseXdlGemmDescriptor; -}; - -// Concept to check if a struct specifies gridwise XDL GEMM info. -template -concept GridwiseBwdXdlGemmDescriptor = requires(T t) { - { t.k1 } -> SizeType; - { t.xdl_params } -> GridwiseXdlGemmDescriptor; -}; - -// Concept to check if a struct specifies gridwise XDL GEMM info. -template -concept SpecifiesGridwiseFwdXdlGemm = requires(T t) { - { t.gridwise_gemm } -> GridwiseFwdXdlGemmDescriptor; -}; - -// Concept to check if a struct specifies gridwise XDL GEMM info. -template -concept SpecifiesGridwiseBwdXdlGemm = requires(T t) { - { t.gridwise_gemm } -> GridwiseBwdXdlGemmDescriptor; -}; - -// Concept to check if a struct specifies gridwise WMMA GEMM info. -template -concept SpecifiesGridwiseWmmaGemm = requires(T t) { - { t.gridwise_gemm } -> GridwiseWmmaGemmDescriptor; +concept SpecifiesWarpGemm = requires(T t) { + { t.warp_gemm } -> WarpGemmDescriptor; }; // Concept to check if a struct specifies convolution input and output block transfer info. 
template -concept SpecifiesBlockTransfer = requires(T t) { - { T::transfer.a.block_transfer } -> BlockTransferDescriptor; - { T::transfer.b.block_transfer } -> BlockTransferDescriptor; - { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; +concept SpecifiesThreadDistribution3D = requires(T t) { + { T::transfer.a.thread_distribution } -> InputTileThreadDistributionDescriptor3D; + { T::transfer.b.thread_distribution } -> InputTileThreadDistributionDescriptor3D; + { T::transfer.c.thread_distribution } -> OutputTileThreadDistributionDescriptor; }; // Concept to check if a struct specifies convolution input and output block transfer info // for 4D thread slices. template -concept SpecifiesBlockTransfer4D = requires(T t) { - { T::transfer.a.block_transfer } -> BlockTransferDescriptor4D; - { T::transfer.b.block_transfer } -> BlockTransferDescriptor4D; - { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; +concept SpecifiesThreadDistribution4D = requires(T t) { + { T::transfer.a.thread_distribution } -> InputTileThreadDistributionDescriptor4D; + { T::transfer.b.thread_distribution } -> InputTileThreadDistributionDescriptor4D; + { T::transfer.c.thread_distribution } -> OutputTileThreadDistributionDescriptor; }; // Concept to check if a struct specifies convolution scalar per vector infor for A, B and C. @@ -229,8 +204,8 @@ concept SpecifiesTileTransfer = requires(T t) { // Concept to check if a struct specifies LDS transfer info for tensors A, B, and C. template concept SpecifiesLdsTransfer = requires(T t) { - { T::transfer.a.lds_transfer } -> LdsTransferDescriptor; - { T::transfer.b.lds_transfer } -> LdsTransferDescriptor; + { T::transfer.a.lds_transfer_params } -> LdsInputTransferDescriptor; + { T::transfer.b.lds_transfer_params } -> LdsInputTransferDescriptor; { T::transfer.c.epilogue } -> EpilogueDescriptor; }; @@ -250,13 +225,8 @@ concept SpecifiesSourceAccessOrder = requires(T t) { // Concept to check if struct specifies block GEMM. 
template -concept SpecifiesBlockGemm = requires { - { T::block_gemm_pipeline } -> BlockGemmPipelineDescriptor; -}; - -template -concept SpecifiesGridwiseGemmPipeline = requires { - { T::pipeline_version } -> std::convertible_to; +concept SpecifiesGemmPipeline = requires { + { T::gemm_pipeline } -> GemmPipelineDescriptor; }; // Concept to check if struct specifies block GEMM (CK Tile). @@ -370,6 +340,18 @@ concept SpecifiesMultipleDSupport = requires { requires T::specialization == ConvAlgorithmSpecialization::MULTIPLE_D; }; +template +concept SpecifiesXdl = requires { + { T::warp_gemm.matrix_instruction } -> std::convertible_to; + requires T::warp_gemm.matrix_instruction == MatrixInstructionType::XDL; +}; + +template +concept SpecifiesWmma = requires { + { T::warp_gemm.matrix_instruction } -> std::convertible_to; + requires T::warp_gemm.matrix_instruction == MatrixInstructionType::WMMA; +}; + /******************************************** */ /* DL-specific descriptors and requirements */ /******************************************** */ diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index ffd45efe496..ff7e54546d7 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -10,38 +10,40 @@ namespace ck_tile::builder::factory { // Base algorithm concepts template concept FwdXdlAlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesBlockTransfer && + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesThreadDistribution3D && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesGridwiseFwdXdlGemm && + SpecifiesSourceAccessOrder && SpecifiesWarpGemm && SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && - SpecifiesNumPrefetchStages && SpecifiesNumGroupsToMerge && 
SpecifiesLoopScheduler; + SpecifiesNumPrefetchStages && SpecifiesNumGroupsToMerge && SpecifiesLoopScheduler && + SpecifiesXdl; template concept BwdXdlAlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesBlockTransfer4D && + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesThreadDistribution4D && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesGridwiseBwdXdlGemm && - SpecifiesBwdWeightConvSpecialization; + SpecifiesSourceAccessOrder && SpecifiesWarpGemm && + SpecifiesBwdWeightConvSpecialization && SpecifiesXdl; template concept BwdXdlV3AlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesBlockTransfer && + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesThreadDistribution3D && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesGridwiseBwdXdlGemm && - SpecifiesBwdWeightConvSpecialization && SpecifiesBlockGemm; + SpecifiesSourceAccessOrder && SpecifiesWarpGemm && + SpecifiesBwdWeightConvSpecialization && SpecifiesGemmPipeline && SpecifiesXdl; + template concept BwdWmmaAlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesBlockTransfer && + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesThreadDistribution3D && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesGridwiseWmmaGemm && - SpecifiesBwdWeightConvSpecialization; + SpecifiesSourceAccessOrder && SpecifiesWarpGemm && + SpecifiesBwdWeightConvSpecialization && SpecifiesWmma; template concept BwdWmmaV3AlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesBlockTransfer && + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesThreadDistribution3D && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesGridwiseWmmaGemm && - SpecifiesBwdWeightConvSpecialization && 
SpecifiesBlockGemm; + SpecifiesSourceAccessOrder && SpecifiesWarpGemm && + SpecifiesBwdWeightConvSpecialization && SpecifiesGemmPipeline && SpecifiesWmma; // Reference algorithm concept template @@ -62,19 +64,19 @@ concept LargeTensorAlgorithm = FwdXdlAlgorithmBase && SpecifiesLargeTensorSup template concept FwdXdlV3Algorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesBlockTransfer && + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesThreadDistribution3D && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesGridwiseFwdXdlGemm && - SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && SpecifiesBlockGemm; + SpecifiesSourceAccessOrder && SpecifiesWarpGemm && + SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && SpecifiesGemmPipeline && SpecifiesXdl; // FWD WMMA algorithm concepts template concept FwdWmmaAlgorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesBlockTransfer && + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesThreadDistribution3D && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesGridwiseWmmaGemm && + SpecifiesSourceAccessOrder && SpecifiesWarpGemm && SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && - SpecifiesNumPrefetchStages && SpecifiesLoopScheduler && SpecifiesGridwiseGemmPipeline; + SpecifiesNumPrefetchStages && SpecifiesLoopScheduler && SpecifiesGemmPipeline && SpecifiesWmma; // FWD DL algorithms template @@ -102,7 +104,7 @@ concept BwdTwoStageXdlAlgorithm = BwdXdlV3AlgorithmBase && SpecifiesTranspose template concept BwdWmmaAlgorithm = BwdWmmaAlgorithmBase && SpecifiesNumPrefetchStages && SpecifiesLoopScheduler && - SpecifiesGridwiseGemmPipeline && SpecifiesGenericInstance; + SpecifiesGemmPipeline && SpecifiesGenericInstance; template concept BwdMultiDWmmaV3Algorithm = BwdWmmaV3AlgorithmBase && SpecifiesMultipleDSupport; @@ -115,7 +117,7 @@ 
template concept BwdTwoStageWmmaV3Algorithm = BwdWmmaV3AlgorithmBase && SpecifiesTransposeTransfer && SpecifiesGemmBatchOptions && SpecifiesTwoStageSupport; -// BWD weigth DL algorithms +// BWD weight DL algorithms template concept BwdDlAlgorithm = ConvAlgorithmDescriptor && SpecifiesThreadBlock && diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index d96080f23db..b31f64a61d3 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -238,6 +238,12 @@ enum class ConvAlgorithmSpecialization MULTIPLE_D }; +enum class MatrixInstructionType +{ + XDL, + WMMA +}; + // toString methods for enum classes inline std::string_view toString(DataType dt) { diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle_v3.cpp index 404d1dbacdb..6f9f086fb75 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle_v3.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle_v3.cpp @@ -24,7 +24,7 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeightMultiple .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) - .with_block_gemm(cku::BlockGemmDesc_v1_intrawave); + .with_gemm_pipeline(cku::BlockGemmDesc_v1_intrawave); using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp index 782f33f8450..9310f5a9a69 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp +++ 
b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp @@ -19,12 +19,12 @@ constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 2, .weight = {.config = {.layout = GKYXC}}, .output = {.config = {.layout = NGKHW}}}; -constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3{} +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_CShuffle{} .with_thread_block(cku::ThreadBlock_64_32x32x32) .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) - .with_block_gemm(cku::BlockGemmDesc_v1_intrawave) + .with_gemm_pipeline(cku::BlockGemmDesc_v1_intrawave) .with_num_conv_groups_to_merge(2) .with_transpose_params(2, 2); diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp index a2a877dbcd4..bda064c2918 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp @@ -19,12 +19,12 @@ constexpr auto SIGNATURE = .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; -constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle{} +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_CShuffle{} .with_thread_block(cku::ThreadBlock_64_32x32x32) .with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave) .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) - .with_block_gemm(cku::BlockGemmDesc_v2_intrawave) + .with_gemm_pipeline(cku::BlockGemmDesc_v2_intrawave) .with_num_conv_groups_to_merge(2) .with_transpose_params(2, 4); diff 
--git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp index 0981ea6c11b..d5051a50c8e 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp @@ -25,7 +25,7 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CS .with_transfer(cku::BwdTransfer_4x64x1) .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) .with_prefetch_config(1, ckb::PipelineScheduler::DEFAULT) - .with_gridwise_gemm_pipeline(ckb::PipelineVersion::V1); + .with_gemm_pipeline(ckb::PipelineVersion::V1); using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp index 60f7d5bd643..b91d8de810a 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp @@ -25,7 +25,7 @@ constexpr auto ALGORITHM = .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) .with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0) - .with_block_gemm(cku::BlockGemmDesc_v1_intrawave) + .with_gemm_pipeline(cku::BlockGemmDesc_v1_intrawave) .with_transpose_params(4, 4); using Builder = ckb::ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp index 4ad97209e5e..ccdad77e392 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp @@ -25,7 +25,7 @@ constexpr auto 
ALGORITHM = .with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave) .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) .with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0) - .with_block_gemm(cku::BlockGemmDesc_v2_intrawave); + .with_gemm_pipeline(cku::BlockGemmDesc_v2_intrawave); using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp index 8d85370b268..b7ec4cdac09 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp @@ -35,7 +35,7 @@ TEST(FwdConvInstances, .with_transfer(Transfer_4x64x1) .with_fwd_specializations(ConvSpecialization::FILTER_1X1_STRIDE1_PAD0, GemmSpecialization::MNKPadding) - .with_block_gemm(BlockGemmDesc_v2_intrawave); + .with_gemm_pipeline(BlockGemmDesc_v2_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp index e4675c89a7a..59b43191aa4 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp @@ -35,7 +35,7 @@ TEST(FwdConvInstances, .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_prefetch_config(1, PipelineScheduler::DEFAULT) .with_num_conv_groups_to_merge(2) - .with_gridwise_gemm_pipeline(PipelineVersion::V1); + .with_gemm_pipeline(PipelineVersion::V1); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp index 610e2fad5fe..e63aa41e059 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp @@ -31,7 +31,7 @@ 
TEST(FwdConvInstances, .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(Transfer_4x64x1) .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_block_gemm(BlockGemmDesc_v1_intrawave); + .with_gemm_pipeline(BlockGemmDesc_v1_intrawave); using Builder = ConvBuilder; @@ -69,7 +69,7 @@ TEST(FwdConvInstances, .with_transfer(Transfer_4x64x1) .with_fwd_specializations(ConvSpecialization::FILTER_3x3, GemmSpecialization::MNKPadding) - .with_block_gemm(BlockGemmDesc_v5_intrawave); + .with_gemm_pipeline(BlockGemmDesc_v5_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp index 8ae88a09174..152409396e8 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp @@ -26,7 +26,7 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xd .with_transfer(cku::Transfer_4x64x1) .with_fwd_specializations(ckb::ConvSpecialization::DEFAULT, ckb::GemmSpecialization::MNKPadding) - .with_block_gemm(cku::BlockGemmDesc_v3_intrawave); + .with_gemm_pipeline(cku::BlockGemmDesc_v3_intrawave); using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp index bb35c53ba06..67462426f61 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp @@ -31,7 +31,7 @@ TEST(FwdConvInstances, .with_transfer(Transfer_4x64x1) .with_fwd_specializations(ConvSpecialization::FILTER_1X1_STRIDE1_PAD0, GemmSpecialization::MNKPadding) - .with_block_gemm(BlockGemmDesc_v4_intrawave); + .with_gemm_pipeline(BlockGemmDesc_v4_intrawave); using Builder = ConvBuilder; diff --git 
a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp index 9e6ca00e581..016e972f3cd 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp @@ -31,7 +31,7 @@ TEST(FwdConvInstances, .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(Transfer_4x64x1) .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_block_gemm(BlockGemmDesc_v3_intrawave); + .with_gemm_pipeline(BlockGemmDesc_v3_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp index 56d4b8be590..00da06d41aa 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp @@ -32,7 +32,7 @@ TEST(FwdConvInstances, .with_transfer(Transfer_4x64x1) .with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0, GemmSpecialization::MNKPadding) - .with_block_gemm(BlockGemmDesc_v4_intrawave); + .with_gemm_pipeline(BlockGemmDesc_v4_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp index df8339241bc..825b9a03330 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp @@ -32,7 +32,7 @@ TEST(FwdConvInstances, .with_transfer(Transfer_4x64x1) .with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0, GemmSpecialization::MNKPadding) - .with_block_gemm(BlockGemmDesc_v1_intrawave); + .with_gemm_pipeline(BlockGemmDesc_v1_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp 
b/experimental/builder/test/impl/conv_algorithm_types.hpp index 797f440bdce..0c665c83210 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -28,53 +28,26 @@ struct ThreadBlock }; static_assert(ckb::ThreadBlockDescriptor); -struct XdlParams +struct WarpGemmParams { - size_t m_per_xdl = 0; - size_t n_per_xdl = 0; - size_t m_xdl_per_wave = 0; - size_t n_xdl_per_wave = 0; + MatrixInstructionType matrix_instruction; + size_t gemm_m_per_instruction = 0; + size_t gemm_n_per_instruction = 0; + size_t gemm_m_iters_per_wave = 0; + size_t gemm_n_iters_per_wave = 0; }; -static_assert(ckb::GridwiseXdlGemmDescriptor); +static_assert(ckb::WarpGemmDescriptor); -// Describe gridwise XDL GEMM parameters. -struct GridwiseFwdXdlGemm -{ - // NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!! - size_t ak1 = 0; - size_t bk1 = 0; - XdlParams xdl_params; -}; -static_assert(ckb::GridwiseFwdXdlGemmDescriptor); - -struct GridwiseBwdXdlGemm -{ - size_t k1 = 0; - XdlParams xdl_params; -}; -static_assert(ckb::GridwiseBwdXdlGemmDescriptor); - -// Describe gridwise WMMA GEMM parameters. -struct GridwiseWmmaGemm -{ - size_t k1 = 0; - size_t m_per_wmma = 0; - size_t n_per_wmma = 0; - size_t m_wmma_per_wave = 0; - size_t n_wmma_per_wave = 0; -}; -static_assert(ckb::GridwiseWmmaGemmDescriptor); - -struct BlockGemmPipeline +struct GemmPipeline { PipelineVersion pipeline_version; - PipelineScheduler scheduler; + PipelineScheduler scheduler{PipelineScheduler::DEFAULT}; }; -static_assert(ckb::BlockGemmPipelineDescriptor); +static_assert(ckb::GemmPipelineDescriptor); -// Describe Aand B block transfer thread cluster lengths. +// Describe input tensor thread cluster lengths. 
template -struct BlockTransfer +struct InputDataThreadDistribution { size_t k0; size_t m_n; @@ -84,34 +57,35 @@ struct BlockTransfer // Specialization for ThreadSliceLength == 3 template <> -struct BlockTransfer<3> +struct InputDataThreadDistribution<3> { size_t k0; size_t m_n; size_t k1; }; -static_assert(ckb::BlockTransferDescriptor>); -static_assert(ckb::BlockTransferDescriptor>); +static_assert(ckb::InputTileThreadDistributionDescriptor3D>); +static_assert(ckb::InputTileThreadDistributionDescriptor4D>); // Describe C block transfer thread cluster lengths. -struct ThreadCluster +struct OutputDataThreadDistribution { - size_t m_block; - size_t m_wave_per_xdl; - size_t n_block; - size_t n_wave_per_xdl; + size_t gemm_m_block_size; + size_t gemm_m_per_block; + size_t gemm_n_block_size; + size_t gemm_n_per_block; }; -static_assert(ThreadClusterDescriptor); +static_assert(OutputTileThreadDistributionDescriptor); -struct LdsTransfer +struct LdsInputTransferParams { + size_t global_memory_vector_load_size; size_t src_vector_dim; size_t src_scalar_per_vector; size_t lds_dst_scalar_per_vector; bool is_direct_load; bool lds_padding; }; -static_assert(LdsTransferDescriptor); +static_assert(LdsInputTransferDescriptor); struct Epilogue { @@ -130,26 +104,26 @@ static_assert(AccessOrderDescriptor>); static_assert(AccessOrderDescriptor>); template -struct InputTransfer +struct InputTileTransfer { - BlockTransfer block_transfer; - LdsTransfer lds_transfer; - AccessOrder block_transfer_access_order; + InputDataThreadDistribution thread_distribution; + LdsInputTransferParams lds_transfer_params; + AccessOrder thread_distribution_access_order; AccessOrder src_access_order; }; -struct OutputTransfer +struct OutputTileTransfer { - ThreadCluster thread_cluster_dims; + OutputDataThreadDistribution thread_distribution; Epilogue epilogue; }; template -struct Transfer +struct InputOutputTileTransfer { - InputTransfer a; - InputTransfer b; - OutputTransfer c; + InputTileTransfer a; + 
InputTileTransfer b; + OutputTileTransfer c; }; // DL-specific descriptors @@ -199,25 +173,15 @@ struct ThreadBlock_ ThreadBlock thread_block; }; -struct FwdXdlGemm_ -{ - GridwiseFwdXdlGemm gridwise_gemm; -}; - -struct BwdXdlGemm_ -{ - GridwiseBwdXdlGemm gridwise_gemm; -}; - -struct WmmaGemm_ +struct WarpGemm_ { - GridwiseWmmaGemm gridwise_gemm; + WarpGemmParams warp_gemm; }; template -struct Transfer_ +struct InputOutputTileTransfer_ { - Transfer transfer; + InputOutputTileTransfer transfer; }; struct ConvSpecializationFwd_ @@ -248,14 +212,9 @@ struct GemmBatchOptions_ size_t num_conv_groups_to_merge{1}; }; -struct BlockGemm_ -{ - BlockGemmPipeline block_gemm_pipeline; -}; - -struct GridGemm_ +struct GemmPipeline_ { - PipelineVersion pipeline_version; + GemmPipeline gemm_pipeline; }; struct DlThreadConfig_ @@ -386,30 +345,16 @@ struct ConvAlgorithmTemplate : Components... constexpr auto with_gemm_config(const GemmConfig& gemm) const { auto result = *this; - if constexpr(std::is_base_of_v) - { - result.gridwise_gemm = gemm; - } - else if constexpr(std::is_base_of_v) - { - result.gridwise_gemm = gemm; - } - else if constexpr(std::is_base_of_v) - { - result.gridwise_gemm = gemm; - } - else - { - static_assert(false, "Unrecognized GemmConfig type"); - } + static_assert(std::is_base_of_v); + result.warp_gemm = gemm; return result; } template constexpr auto with_transfer(const T& t) const { - static_assert(std::is_base_of_v, ConvAlgorithmTemplate> || - std::is_base_of_v, ConvAlgorithmTemplate>); + static_assert(std::is_base_of_v, ConvAlgorithmTemplate> || + std::is_base_of_v, ConvAlgorithmTemplate>); auto result = *this; result.transfer = t; return result; @@ -460,20 +405,20 @@ struct ConvAlgorithmTemplate : Components... 
return result; } - template - constexpr auto with_block_gemm(const BG& bg) const + template + constexpr auto with_gemm_pipeline(const PL& pl) const { - static_assert(std::is_base_of_v); + static_assert(std::is_base_of_v); auto result = *this; - result.block_gemm_pipeline = bg; + result.gemm_pipeline = pl; return result; } - constexpr auto with_gridwise_gemm_pipeline(const PipelineVersion plv) const + constexpr auto with_gemm_pipeline(const PipelineVersion plv) const { - static_assert(std::is_base_of_v); + static_assert(std::is_base_of_v); auto result = *this; - result.pipeline_version = plv; + result.gemm_pipeline.pipeline_version = plv; return result; } @@ -555,25 +500,25 @@ struct ConvAlgorithmTemplate : Components... using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = ConvAlgorithmTemplate, + WarpGemm_, + InputOutputTileTransfer_<>, ConvSpecializationFwd_, Prefetch_, GemmBatchOptions_>; using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = ConvAlgorithmTemplate, + WarpGemm_, + InputOutputTileTransfer_<>, ConvSpecializationFwd_, - BlockGemm_>; + GemmPipeline_>; using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle = ConvAlgorithmTemplate, + WarpGemm_, + InputOutputTileTransfer_<>, ConvSpecializationFwd_, - GridGemm_, + GemmPipeline_, Prefetch_, GemmBatchOptions_>; @@ -586,8 +531,8 @@ using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = ConvAlgorithmTemplate, + WarpGemm_, + InputOutputTileTransfer_<>, ConvSpecializationFwd_, Prefetch_, GemmBatchOptions_, @@ -612,27 +557,44 @@ struct ConvAlgorithm_Reference // Bwd weight algorithm types using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle = ConvAlgorithmTemplate, + WarpGemm_, + InputOutputTileTransfer_<4>, ConvSpecializationBwdWeight_, TransposeParams_>; -using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle = +using 
ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle = ConvAlgorithmTemplate, + WarpGemm_, + InputOutputTileTransfer_<4>, ConvSpecializationBwdWeight_, - BlockGemm_, + GemmPipeline_, + Prefetch_>; + +// Covers both XDL and WMMA variants +using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_CShuffle = + ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_, + GemmPipeline_, TransposeParams_, GemmBatchOptions_, TwoStageSpecialization_>; using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 = ConvAlgorithmTemplate, + WarpGemm_, + InputOutputTileTransfer_<>, + ConvSpecializationBwdWeight_, + GemmPipeline_>; + +using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 = + ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, - BlockGemm_>; + GemmPipeline_, + TransposeParams_>; using ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl = ConvAlgorithmTemplate, + WarpGemm_, + InputOutputTileTransfer_<4>, ConvSpecializationBwdWeight_, MultipleDSpecialization_>; -using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 = - ConvAlgorithmTemplate, - ConvSpecializationBwdWeight_, - BlockGemm_, - TransposeParams_>; - -using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 = - ConvAlgorithmTemplate, - ConvSpecializationBwdWeight_, - BlockGemm_, - TransposeParams_, - GemmBatchOptions_, - TwoStageSpecialization_>; - -using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle = - ConvAlgorithmTemplate, - ConvSpecializationBwdWeight_, - GridGemm_, - Prefetch_>; - using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3 = ConvAlgorithmTemplate, + WarpGemm_, + InputOutputTileTransfer_<>, ConvSpecializationBwdWeight_, - BlockGemm_, + GemmPipeline_, MultipleDSpecialization_>; } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index 3b83ead2d0d..fd09c810d04 100644 --- 
a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -12,6 +12,7 @@ namespace ck_tile::builder::test_utils { using namespace ck_tile::builder; using namespace test; +// Test configs for DL algorithms constexpr DlThreadConfig DlThreadConfig_16x2x4x4x1{ .k0_per_block = 16, .k1 = 2, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1}; @@ -50,165 +51,176 @@ constexpr DlTransfer<5> DlTransfer5D{.a = DlBlockTransfer_1x8x1x1x1, .src_dst_vector_dim = 5, .dst_scalar_per_vector = 1}}; -constexpr Transfer<> Transfer_4x64x1{ +// XLD/WMMA test configs +constexpr InputOutputTileTransfer<> Transfer_4x64x1{ .a = { - .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, + .thread_distribution = {.k0 = 4, .m_n = 64, .k1 = 1}, + .lds_transfer_params = {.global_memory_vector_load_size = 8, + .src_vector_dim = 2, .src_scalar_per_vector = 2, .lds_dst_scalar_per_vector = 8, .is_direct_load = false, .lds_padding = false}, - .block_transfer_access_order = {1, 0, 2}, + .thread_distribution_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .b = { - .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, + .thread_distribution = {.k0 = 4, .m_n = 64, .k1 = 1}, + .lds_transfer_params = {.global_memory_vector_load_size = 8, + .src_vector_dim = 2, .src_scalar_per_vector = 8, .lds_dst_scalar_per_vector = 8, .is_direct_load = false, .lds_padding = false}, - .block_transfer_access_order = {1, 0, 2}, + .thread_distribution_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .c = { - .thread_cluster_dims = - {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, + .thread_distribution = + {.gemm_m_block_size = 1, .gemm_m_per_block = 32, .gemm_n_block_size = 1, .gemm_n_per_block = 8}, .epilogue = {.m_xdl_per_wave_per_shuffle = 1, .n_per_wave_per_shuffle = 1, .scalar_per_vector = 4}, }, }; -constexpr Transfer<4> 
BwdTransfer_4x64x1{ +constexpr InputOutputTileTransfer<4> BwdTransfer_4x64x1{ .a = { - .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1}, - .lds_transfer = {.src_vector_dim = 2, + .thread_distribution = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1}, + .lds_transfer_params = {.global_memory_vector_load_size = 8, + .src_vector_dim = 2, .src_scalar_per_vector = 2, .lds_dst_scalar_per_vector = 4, .is_direct_load = false, .lds_padding = true}, - .block_transfer_access_order = {0, 3, 1, 2}, + .thread_distribution_access_order = {0, 3, 1, 2}, .src_access_order = {0, 2, 1, 3}, }, .b = { - .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1}, - .lds_transfer = {.src_vector_dim = 2, + .thread_distribution = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1}, + .lds_transfer_params = {.global_memory_vector_load_size = 8, + .src_vector_dim = 2, .src_scalar_per_vector = 2, .lds_dst_scalar_per_vector = 4, .is_direct_load = false, .lds_padding = true}, - .block_transfer_access_order = {0, 3, 1, 2}, + .thread_distribution_access_order = {0, 3, 1, 2}, .src_access_order = {0, 2, 1, 3}, }, .c = { - .thread_cluster_dims = - {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, + .thread_distribution = + {.gemm_m_block_size = 1, .gemm_m_per_block = 32, .gemm_n_block_size = 1, .gemm_n_per_block = 8}, .epilogue = {.m_xdl_per_wave_per_shuffle = 1, .n_per_wave_per_shuffle = 1, .scalar_per_vector = 8}, }, }; -constexpr Transfer<> BwdTransfer_4x8x1_4x16x1_v3{ +constexpr InputOutputTileTransfer<> BwdTransfer_4x8x1_4x16x1_v3{ .a = { - .block_transfer = {.k0 = 4, .m_n = 8, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 1, + .thread_distribution = {.k0 = 4, .m_n = 8, .k1 = 1}, + .lds_transfer_params = {.global_memory_vector_load_size = 8, + .src_vector_dim = 1, .src_scalar_per_vector = 2, .lds_dst_scalar_per_vector = 2, .is_direct_load = false, .lds_padding = false}, - .block_transfer_access_order = {2, 0, 1}, + .thread_distribution_access_order 
= {2, 0, 1}, .src_access_order = {1, 0, 2}, }, .b = { - .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 1, + .thread_distribution = {.k0 = 4, .m_n = 16, .k1 = 1}, + .lds_transfer_params = {.global_memory_vector_load_size = 8, + .src_vector_dim = 1, .src_scalar_per_vector = 2, .lds_dst_scalar_per_vector = 2, .is_direct_load = false, .lds_padding = false}, - .block_transfer_access_order = {2, 0, 1}, + .thread_distribution_access_order = {2, 0, 1}, .src_access_order = {1, 0, 2}, }, .c = { - .thread_cluster_dims = - {.m_block = 1, .m_wave_per_xdl = 8, .n_block = 1, .n_wave_per_xdl = 8}, + .thread_distribution = + {.gemm_m_block_size = 1, .gemm_m_per_block = 8, .gemm_n_block_size = 1, .gemm_n_per_block = 8}, .epilogue = {.m_xdl_per_wave_per_shuffle = 1, .n_per_wave_per_shuffle = 1, .scalar_per_vector = 2}, }, }; -constexpr Transfer<> Transfer_4x64x1_fp8{ +constexpr InputOutputTileTransfer<> Transfer_4x64x1_fp8{ .a = { - .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, + .thread_distribution = {.k0 = 4, .m_n = 64, .k1 = 1}, + .lds_transfer_params = {.global_memory_vector_load_size = 8, + .src_vector_dim = 2, .src_scalar_per_vector = 8, .lds_dst_scalar_per_vector = 8, .is_direct_load = false, .lds_padding = true}, - .block_transfer_access_order = {1, 0, 2}, + .thread_distribution_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .b = { - .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, + .thread_distribution = {.k0 = 4, .m_n = 64, .k1 = 1}, + .lds_transfer_params = {.global_memory_vector_load_size = 8, + .src_vector_dim = 2, .src_scalar_per_vector = 8, .lds_dst_scalar_per_vector = 8, .is_direct_load = false, .lds_padding = true}, - .block_transfer_access_order = {1, 0, 2}, + .thread_distribution_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .c = { - .thread_cluster_dims = - {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, 
.n_wave_per_xdl = 8}, + .thread_distribution = + {.gemm_m_block_size = 1, .gemm_m_per_block = 32, .gemm_n_block_size = 1, .gemm_n_per_block = 8}, .epilogue = {.m_xdl_per_wave_per_shuffle = 1, .n_per_wave_per_shuffle = 1, .scalar_per_vector = 8}, }, }; -constexpr Transfer<> Transfer_4x16x1{ +constexpr InputOutputTileTransfer<> Transfer_4x16x1{ .a = { - .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, + .thread_distribution = {.k0 = 4, .m_n = 16, .k1 = 1}, + .lds_transfer_params = {.global_memory_vector_load_size = 8, + .src_vector_dim = 2, .src_scalar_per_vector = 8, .lds_dst_scalar_per_vector = 8, .is_direct_load = false, .lds_padding = true}, - .block_transfer_access_order = {1, 0, 2}, + .thread_distribution_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .b = { - .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, + .thread_distribution = {.k0 = 4, .m_n = 16, .k1 = 1}, + .lds_transfer_params = {.global_memory_vector_load_size = 8, + .src_vector_dim = 2, .src_scalar_per_vector = 8, .lds_dst_scalar_per_vector = 8, .is_direct_load = false, .lds_padding = true}, - .block_transfer_access_order = {1, 0, 2}, + .thread_distribution_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .c = { - .thread_cluster_dims = - {.m_block = 1, .m_wave_per_xdl = 16, .n_block = 1, .n_wave_per_xdl = 4}, + .thread_distribution = + {.gemm_m_block_size = 1, .gemm_m_per_block = 16, .gemm_n_block_size = 1, .gemm_n_per_block = 4}, .epilogue = {.m_xdl_per_wave_per_shuffle = 1, .n_per_wave_per_shuffle = 1, .scalar_per_vector = 8}, @@ -216,72 +228,72 @@ constexpr Transfer<> Transfer_4x16x1{ }, }; -constexpr Transfer<> Transfer_4x32x1{ +constexpr InputOutputTileTransfer<> Transfer_4x32x1{ .a = { - .block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, + .thread_distribution = {.k0 = 4, .m_n = 32, .k1 = 1}, + .lds_transfer_params = {.global_memory_vector_load_size = 8, 
+ .src_vector_dim = 2, .src_scalar_per_vector = 16, .lds_dst_scalar_per_vector = 16, .is_direct_load = false, .lds_padding = true}, - .block_transfer_access_order = {1, 0, 2}, + .thread_distribution_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .b = { - .block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, + .thread_distribution = {.k0 = 4, .m_n = 32, .k1 = 1}, + .lds_transfer_params = {.global_memory_vector_load_size = 8, + .src_vector_dim = 2, .src_scalar_per_vector = 16, .lds_dst_scalar_per_vector = 16, .is_direct_load = false, .lds_padding = true}, - .block_transfer_access_order = {1, 0, 2}, + .thread_distribution_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .c = { - .thread_cluster_dims = - {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 4}, + .thread_distribution = + {.gemm_m_block_size = 1, .gemm_m_per_block = 32, .gemm_n_block_size = 1, .gemm_n_per_block = 4}, .epilogue = {.m_xdl_per_wave_per_shuffle = 1, .n_per_wave_per_shuffle = 1, .scalar_per_vector = 8}, }, }; -constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_4x4_per_wave{ - .k1 = 8, - .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}}; +constexpr WarpGemmParams BwdGemmParams_Xdl_4x4_per_wave{ + .matrix_instruction = MatrixInstructionType::XDL, + .gemm_m_per_instruction = 32, .gemm_n_per_instruction = 32, .gemm_m_iters_per_wave = 4, .gemm_n_iters_per_wave = 4}; -constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_1x1_per_wave{ - .k1 = 8, - .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 1, .n_xdl_per_wave = 1}}; +constexpr WarpGemmParams BwdGemmParams_Xdl_1x1_per_wave{ + .matrix_instruction = MatrixInstructionType::XDL, + .gemm_m_per_instruction = 32, .gemm_n_per_instruction = 32, .gemm_m_iters_per_wave = 1, .gemm_n_iters_per_wave = 1}; -constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_4x4_per_wave{ - .ak1 = 8, - .bk1 = 8, - .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 
32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}}; +constexpr WarpGemmParams FwdGemmParams_Xdl_4x4_per_wave{ + .matrix_instruction = MatrixInstructionType::XDL, + .gemm_m_per_instruction = 32, .gemm_n_per_instruction = 32, .gemm_m_iters_per_wave = 4, .gemm_n_iters_per_wave = 4}; -constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_4x2_per_wave{ - .ak1 = 8, - .bk1 = 8, - .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2}}; +constexpr WarpGemmParams FwdGemmParams_Xdl_4x2_per_wave{ + .matrix_instruction = MatrixInstructionType::XDL, + .gemm_m_per_instruction = 32, .gemm_n_per_instruction = 32, .gemm_m_iters_per_wave = 4, .gemm_n_iters_per_wave = 2}; -constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_2x2_per_wave{ - .ak1 = 8, - .bk1 = 8, - .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}}; +constexpr WarpGemmParams FwdGemmParams_Xdl_2x2_per_wave{ + .matrix_instruction = MatrixInstructionType::XDL, + .gemm_m_per_instruction = 32, .gemm_n_per_instruction = 32, .gemm_m_iters_per_wave = 2, .gemm_n_iters_per_wave = 2}; -constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_2x1_per_wave{ - .ak1 = 8, - .bk1 = 8, - .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1}}; +constexpr WarpGemmParams FwdGemmParams_Xdl_2x1_per_wave{ + .matrix_instruction = MatrixInstructionType::XDL, + .gemm_m_per_instruction = 32, .gemm_n_per_instruction = 32, .gemm_m_iters_per_wave = 2, .gemm_n_iters_per_wave = 1}; -constexpr GridwiseWmmaGemm GemmParams_Wmma_2x1_per_wave{ - .k1 = 8, .m_per_wmma = 32, .n_per_wmma = 32, .m_wmma_per_wave = 2, .n_wmma_per_wave = 1}; +constexpr WarpGemmParams GemmParams_Wmma_2x1_per_wave{ + .matrix_instruction = MatrixInstructionType::WMMA, + .gemm_m_per_instruction = 32, .gemm_n_per_instruction = 32, .gemm_m_iters_per_wave = 2, .gemm_n_iters_per_wave = 1}; -constexpr GridwiseWmmaGemm GemmParams_Wmma_16x16_2x1_per_wave{ - .k1 = 8, .m_per_wmma = 16, .n_per_wmma = 16, 
.m_wmma_per_wave = 2, .n_wmma_per_wave = 1}; +constexpr WarpGemmParams GemmParams_Wmma_16x16_2x1_per_wave{ + .matrix_instruction = MatrixInstructionType::WMMA, + .gemm_m_per_instruction = 16, .gemm_n_per_instruction = 16, .gemm_m_iters_per_wave = 2, .gemm_n_iters_per_wave = 1}; constexpr ThreadBlock ThreadBlock_256_256x256x32{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; @@ -310,19 +322,19 @@ constexpr ThreadBlock ThreadBlock_128_128x128x32{.block_size = 128, constexpr ThreadBlock ThreadBlock_128_64x64x64{.block_size = 128, .tile_size = {.m = 64, .n = 64, .k = 64}}; -constexpr BlockGemmPipeline BlockGemmDesc_v1_intrawave = { +constexpr GemmPipeline BlockGemmDesc_v1_intrawave = { .pipeline_version = PipelineVersion::V1, .scheduler = PipelineScheduler::INTRAWAVE}; -constexpr BlockGemmPipeline BlockGemmDesc_v2_intrawave = { +constexpr GemmPipeline BlockGemmDesc_v2_intrawave = { .pipeline_version = PipelineVersion::V2, .scheduler = PipelineScheduler::INTRAWAVE}; -constexpr BlockGemmPipeline BlockGemmDesc_v3_intrawave = { +constexpr GemmPipeline BlockGemmDesc_v3_intrawave = { .pipeline_version = PipelineVersion::V3, .scheduler = PipelineScheduler::INTRAWAVE}; -constexpr BlockGemmPipeline BlockGemmDesc_v4_intrawave = { +constexpr GemmPipeline BlockGemmDesc_v4_intrawave = { .pipeline_version = PipelineVersion::V4, .scheduler = PipelineScheduler::INTRAWAVE}; -constexpr BlockGemmPipeline BlockGemmDesc_v5_intrawave = { +constexpr GemmPipeline BlockGemmDesc_v5_intrawave = { .pipeline_version = PipelineVersion::V5, .scheduler = PipelineScheduler::INTRAWAVE}; } // namespace ck_tile::builder::test_utils diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index b705ab1e471..ff478172275 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -86,34 +86,16 @@ inline std::string 
to_string(ThreadBlock t) } template <> -inline std::string to_string(GridwiseBwdXdlGemm t) +inline std::string to_string(WarpGemmParams t) { std::ostringstream oss; - oss << t.k1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl << "," - << t.xdl_params.m_xdl_per_wave << "," << t.xdl_params.n_xdl_per_wave; + oss << t.gemm_m_per_instruction << "," << t.gemm_n_per_instruction << "," + << t.gemm_m_iters_per_wave << "," << t.gemm_n_iters_per_wave; return oss.str(); } template <> -inline std::string to_string(GridwiseFwdXdlGemm t) -{ - std::ostringstream oss; - oss << t.ak1 << "," << t.bk1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl - << "," << t.xdl_params.m_xdl_per_wave << "," << t.xdl_params.n_xdl_per_wave; - return oss.str(); -} - -template <> -inline std::string to_string(GridwiseWmmaGemm t) -{ - std::ostringstream oss; - oss << t.k1 << "," << t.m_per_wmma << "," << t.n_per_wmma << "," << t.m_wmma_per_wave << "," - << t.n_wmma_per_wave; - return oss.str(); -} - -template <> -inline std::string to_string(BlockGemmPipeline t) +inline std::string to_string(GemmPipeline t) { std::ostringstream oss; oss << to_string(t.scheduler) << "," << to_string(t.pipeline_version); @@ -121,7 +103,7 @@ inline std::string to_string(BlockGemmPipeline t) } template -inline std::string to_string(BlockTransfer t) +inline std::string to_string(InputDataThreadDistribution t) { if constexpr(ThreadSliceDim == 4) { @@ -138,17 +120,17 @@ inline std::string to_string(BlockTransfer t) } template <> -inline std::string to_string(ThreadCluster t) +inline std::string to_string(OutputDataThreadDistribution t) { return array_to_seq( - std::array{t.m_block, t.m_wave_per_xdl, t.n_block, t.n_wave_per_xdl}); + std::array{t.gemm_m_block_size, t.gemm_m_per_block, t.gemm_n_block_size, t.gemm_n_per_block}); } template <> -inline std::string to_string(LdsTransfer t) +inline std::string to_string(LdsInputTransferParams t) { std::ostringstream oss; - oss << t.src_vector_dim 
<< "," << t.src_scalar_per_vector << "," << t.lds_dst_scalar_per_vector + oss << t.global_memory_vector_load_size << "," << t.src_vector_dim << "," << t.src_scalar_per_vector << "," << t.lds_dst_scalar_per_vector << "," << (t.lds_padding ? "true" : "false") << "," << (t.is_direct_load ? "true" : "false"); return oss.str(); @@ -161,27 +143,27 @@ inline std::string to_string(AccessOrder t) } template -inline std::string to_string(InputTransfer t) +inline std::string to_string(InputTileTransfer t) { std::ostringstream oss; - oss << to_string(t.block_transfer) << "," << to_string(t.block_transfer_access_order) << "," - << to_string(t.src_access_order) << "," << t.lds_transfer.src_vector_dim << "," - << t.lds_transfer.src_scalar_per_vector << "," << t.lds_transfer.lds_dst_scalar_per_vector - << "," << (t.lds_transfer.lds_padding ? "true" : "false"); + oss << to_string(t.thread_distribution) << "," << to_string(t.thread_distribution_access_order) << "," + << to_string(t.src_access_order) << "," << t.lds_transfer_params.src_vector_dim << "," + << t.lds_transfer_params.src_scalar_per_vector << "," << t.lds_transfer_params.lds_dst_scalar_per_vector + << "," << (t.lds_transfer_params.lds_padding ? 
"true" : "false"); return oss.str(); } template <> -inline std::string to_string(OutputTransfer t) +inline std::string to_string(OutputTileTransfer t) { std::ostringstream oss; oss << t.epilogue.m_xdl_per_wave_per_shuffle << "," << t.epilogue.n_per_wave_per_shuffle << "," - << to_string(t.thread_cluster_dims) << "," << t.epilogue.scalar_per_vector; + << to_string(t.thread_distribution) << "," << t.epilogue.scalar_per_vector; return oss.str(); } template -inline std::string to_string(Transfer t) +inline std::string to_string(InputOutputTileTransfer t) { std::ostringstream oss; oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c); @@ -271,25 +253,14 @@ inline std::string to_string(ThreadBlock_ t) } template <> -inline std::string to_string(FwdXdlGemm_ t) +inline std::string to_string(WarpGemm_ t) { - return to_string(t.gridwise_gemm); + return to_string(t.warp_gemm); } -template <> -inline std::string to_string(BwdXdlGemm_ t) -{ - return to_string(t.gridwise_gemm); -} - -template <> -inline std::string to_string(WmmaGemm_ t) -{ - return to_string(t.gridwise_gemm); -} template -inline std::string to_string(Transfer_ t) +inline std::string to_string(InputOutputTileTransfer_ t) { return to_string(t.transfer); } @@ -319,9 +290,9 @@ inline std::string to_string(Prefetch_ t) } template <> -inline std::string to_string(BlockGemm_ t) +inline std::string to_string(GemmPipeline_ t) { - return to_string(t.block_gemm_pipeline); + return to_string(t.gemm_pipeline); } template <> @@ -355,8 +326,8 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); return oss.str(); } @@ -365,8 +336,8 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << 
to_string(static_cast>(t)); return oss.str(); } @@ -375,8 +346,8 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); return oss.str(); } @@ -397,8 +368,8 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); return oss.str(); } @@ -407,8 +378,8 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); return oss.str(); } @@ -417,8 +388,8 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); return oss.str(); } @@ -427,8 +398,8 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); return oss.str(); } @@ -437,8 +408,8 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); return oss.str(); } @@ -447,30 +418,22 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); return oss.str(); } +// Covers both XDL and WMMA versions template <> -inline std::string to_string( - 
ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 t) +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_CShuffle t) { std::ostringstream oss; - oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); return oss.str(); } -template <> -inline std::string to_string( - ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle t) -{ - std::ostringstream oss; - oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); - return oss.str(); -} template <> inline std::string to_string( @@ -489,8 +452,8 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); return oss.str(); } From 37e9547a2932aaf327e577350e647af37ca92619 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Wed, 7 Jan 2026 03:56:50 -0500 Subject: [PATCH 52/81] Fix ref algorithm dispatching. 
--- .../builder/include/ck_tile/builder/factory/conv_dispatcher.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 319293cff14..e235db4bb09 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -112,7 +112,7 @@ constexpr auto make_conv_instance() return typename ReferenceFactory::Instance{}; } // CK Tile supports common factory for each direction - if constexpr(TileAlgorithm) + else if constexpr(TileAlgorithm) { return typename ConvTileFactory::Instance{}; } From 82803221c3335a2efa2aba0915573e8bb9d14656 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Wed, 7 Jan 2026 03:57:02 -0500 Subject: [PATCH 53/81] Fix smoke tests. --- experimental/builder/test/test_conv_description.cpp | 2 +- experimental/builder/test/unit_conv_tuning_params.cpp | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index 498de9a42fd..7204b09157c 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -161,7 +161,7 @@ struct DefaultAlgorithm ckb::ConvSpecialization fwd_specialization = ckb::ConvSpecialization::DEFAULT; ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default; - ckb::test::BlockGemm block_gemm{.pipeline_version = ckb::PipelineVersion::V4, + ckb::test::BlockGemmPipeline block_gemm_pipeline{.pipeline_version = ckb::PipelineVersion::V4, .scheduler = ckb::PipelineScheduler::INTRAWAVE}; }; static_assert(ckb::ConvAlgorithmDescriptor); diff --git a/experimental/builder/test/unit_conv_tuning_params.cpp b/experimental/builder/test/unit_conv_tuning_params.cpp index 
ee1388a77f7..90057429309 100644 --- a/experimental/builder/test/unit_conv_tuning_params.cpp +++ b/experimental/builder/test/unit_conv_tuning_params.cpp @@ -19,7 +19,7 @@ TEST(ConvTuningParams, AssignsBlockGemmParams) { ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V3; ckb::PipelineScheduler scheduler = ckb::PipelineScheduler::INTRAWAVE; - } block_gemm; + } block_gemm_pipeline; } kAlgorithm; constexpr auto block_gemm = SetBlockGemm(); @@ -42,10 +42,7 @@ TEST(ConvTuningParams, AssignsGridwiseGemmPipelineVersion) { constexpr struct Algorithm { - struct GridwiseGemm - { - ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V4; - } gridwise_gemm; + ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V4; } kAlgorithm; constexpr auto pipeline_version = SetGridwiseGemmPipelineVersion(); From 00f45cca2e089b67160708ab73d1dfeddadcd537 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Wed, 7 Jan 2026 03:58:09 -0500 Subject: [PATCH 54/81] clang-format --- experimental/builder/test/test_conv_description.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index 7204b09157c..9e8008ccf02 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -162,7 +162,8 @@ struct DefaultAlgorithm ckb::ConvSpecialization fwd_specialization = ckb::ConvSpecialization::DEFAULT; ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default; ckb::test::BlockGemmPipeline block_gemm_pipeline{.pipeline_version = ckb::PipelineVersion::V4, - .scheduler = ckb::PipelineScheduler::INTRAWAVE}; + .scheduler = + ckb::PipelineScheduler::INTRAWAVE}; }; static_assert(ckb::ConvAlgorithmDescriptor); From c5cdd51ce4cf5749ca158d466cf3647283d7c4be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Wed, 7 Jan 2026 05:41:22 -0500 
Subject: [PATCH 55/81] Fix factory for regular WMMA conv bwd weight. --- .../builder/factory/conv_bwd_weight_wmma_factory.hpp | 8 ++++---- .../conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp | 2 +- experimental/builder/test/impl/conv_algorithm_types.hpp | 2 +- .../builder/test/utils/conv_algorithm_type_utils.hpp | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp index 817432081b0..32161a234ae 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp @@ -51,13 +51,13 @@ struct ConvBwdWeightWmmaFactory static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); static_assert(InputVectorTransferLimits, "Invalid B block transfer config"); static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); - static_assert(AccessOrderLimits4D, + static_assert(AccessOrderLimits3D, "Invalid A thread cluster access order"); - static_assert(AccessOrderLimits4D, + static_assert(AccessOrderLimits3D, "Invalid B thread cluster access order"); - static_assert(AccessOrderLimits4D, + static_assert(AccessOrderLimits3D, "Invalid A source access order"); - static_assert(AccessOrderLimits4D, + static_assert(AccessOrderLimits3D, "Invalid B source access order"); // The forward convolution kernel class instance. 
diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp index 0981ea6c11b..ff350ac8049 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp @@ -22,7 +22,7 @@ constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 3, constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle{} .with_thread_block(cku::ThreadBlock_64_32x32x32) .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) - .with_transfer(cku::BwdTransfer_4x64x1) + .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) .with_prefetch_config(1, ckb::PipelineScheduler::DEFAULT) .with_gridwise_gemm_pipeline(ckb::PipelineVersion::V1); diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 797f440bdce..27ba1ec3b60 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -669,7 +669,7 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 = using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle = ConvAlgorithmTemplate, + Transfer_<>, ConvSpecializationBwdWeight_, GridGemm_, Prefetch_>; diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index b705ab1e471..d80d6a1b8cf 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -428,7 +428,7 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + << "," << to_string(static_cast>(t)); return oss.str(); } From 
16b86803eac6bec1c1963da36e4ebf2f47eae0d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Thu, 8 Jan 2026 05:22:17 -0500 Subject: [PATCH 56/81] Clarify builder Readme. --- experimental/builder/README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/experimental/builder/README.md b/experimental/builder/README.md index 7a93c395c07..1156de0e9ce 100644 --- a/experimental/builder/README.md +++ b/experimental/builder/README.md @@ -45,10 +45,10 @@ cmake .. ``` -Note: When compiling e.g. for the `gfx942` architecture, the WMMA builders are not automatically included in the tests -since `gfx9` architectures do not support WMMA. Hence, to compile also the WMMA builders, add e.g. -`gfx1121` to the list of supported architectures or add flag `-D CK_USE_WMMA=ON`. One still needs -a Navi card to execute the Builder tests that use the GPU. +Note: The tests for WMMA builders are only built when `CK_USE_WMMA` is enabled. Add e.g. +`gfx1121` or any of the other `gfx11`/`gfx12` architectures to the GPU targets. Alternatively, +one can add flag `-D CK_USE_WMMA=ON` to build the tests. For the end-to-end tests that use +the instances from builder, one needs an actual Navi card. ## Building and Testing From fd8edf9d3fc2f0139892d61d8f5d02bb9f79a6f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Thu, 8 Jan 2026 05:25:38 -0500 Subject: [PATCH 57/81] Remove obsolete test file. --- .../test/test_concept_diagnostics_sync.cpp | 418 ------------------ 1 file changed, 418 deletions(-) delete mode 100644 experimental/builder/test/test_concept_diagnostics_sync.cpp diff --git a/experimental/builder/test/test_concept_diagnostics_sync.cpp b/experimental/builder/test/test_concept_diagnostics_sync.cpp deleted file mode 100644 index ca08ae92ed7..00000000000 --- a/experimental/builder/test/test_concept_diagnostics_sync.cpp +++ /dev/null @@ -1,418 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -/** - * @file test_concept_diagnostics_sync.cpp - * @brief Unit tests to ensure concepts and their diagnostics remain in sync - * - * This test suite verifies that: - * 1. Valid types satisfy their corresponding concepts - * 2. Invalid types (missing members) do not satisfy concepts - * 3. Diagnostic messages correctly identify missing requirements - * 4. Existing test types from conv_algorithm_types.hpp satisfy their concepts - */ - -#include -#include - -#include "ck_tile/builder/conv_algorithm_concepts.hpp" -#include "ck_tile/builder/conv_algorithm_diagnostics.hpp" -#include "ck_tile/builder/types.hpp" -#include "experimental/builder/test/impl/conv_algorithm_types.hpp" - -namespace ck_tile::builder::test { - -using ck_tile::builder::AccessOrderDescriptor; -using ck_tile::builder::BlockGemmDescriptor; -using ck_tile::builder::BlockTransferDescriptor; -using ck_tile::builder::ConvAlgorithmDescriptor; -using ck_tile::builder::DlBlockTransferDescriptor; -using ck_tile::builder::DlEpilogueDescriptor; -using ck_tile::builder::DlThreadClusterDescriptor; -using ck_tile::builder::DlThreadConfigDescriptor; -using ck_tile::builder::EpilogueDescriptor; -using ck_tile::builder::GridwiseWmmaGemmDescriptor; -using ck_tile::builder::GridwiseXdlGemmDescriptor; -using ck_tile::builder::LdsTransferDescriptor; -using ck_tile::builder::SpecifiesBlockGemm; -using ck_tile::builder::SpecifiesBwdWeightConvSpecialization; -using ck_tile::builder::SpecifiesDlThreadCluster; -using ck_tile::builder::SpecifiesDlThreadConfig; -using ck_tile::builder::SpecifiesFwdConvSpecialization; -using ck_tile::builder::SpecifiesGemmSpecialization; -using ck_tile::builder::SpecifiesGridwiseBwdXdlGemm; -using ck_tile::builder::SpecifiesGridwiseFwdXdlGemm; -using ck_tile::builder::SpecifiesLoopScheduler; -using ck_tile::builder::SpecifiesNumPrefetchStages; -using ck_tile::builder::SpecifiesThreadBlock; -using ck_tile::builder::SpecifiesTileBlockGemm; -using 
ck_tile::builder::SpecifiesTileConvSpecialization; -using ck_tile::builder::SpecifiesTileOptimizations; -using ck_tile::builder::SpecifiesTileThreadBlock; -using ck_tile::builder::SpecifiesTileTransfer; -using ck_tile::builder::ThreadBlockDescriptor; -using ck_tile::builder::ThreadClusterDescriptor; -using ck_tile::builder::TileBlockGemmDescriptor; -using ck_tile::builder::TileOptimizationsDescriptor; -using ck_tile::builder::TileThreadBlockDescriptor; -using ck_tile::builder::TileTransferDescriptor; - -// Helper to check if a string contains a substring -bool contains(const std::string& str, const std::string& substr) -{ - return str.find(substr) != std::string::npos; -} - -// ============================================================================= -// BASIC DESCRIPTOR CONCEPTS TESTS -// ============================================================================= - -TEST(ConceptDiagnosticsSync, ThreadBlockDescriptor_Valid) -{ - // The ThreadBlock type from conv_algorithm_types.hpp should satisfy the concept - static_assert(ThreadBlockDescriptor); -} - -TEST(ConceptDiagnosticsSync, GridwiseXdlGemmDescriptor_Valid) -{ - // The XdlParams type should satisfy the concept - static_assert(GridwiseXdlGemmDescriptor); -} - -TEST(ConceptDiagnosticsSync, BlockTransferDescriptor_Valid) -{ - // The BlockTransfer type should satisfy the concept - static_assert(BlockTransferDescriptor); -} - -TEST(ConceptDiagnosticsSync, ThreadClusterDescriptor_Valid) -{ - // The ThreadCluster type should satisfy the concept - static_assert(ThreadClusterDescriptor); -} - -TEST(ConceptDiagnosticsSync, LdsTransferDescriptor_Valid) -{ - // The LdsTransfer type should satisfy the concept - static_assert(LdsTransferDescriptor); -} - -TEST(ConceptDiagnosticsSync, EpilogueDescriptor_Valid) -{ - // The Epilogue type should satisfy the concept - static_assert(EpilogueDescriptor); -} - -TEST(ConceptDiagnosticsSync, AccessOrderDescriptor_Valid) -{ - // The AccessOrder type should satisfy the concept 
- static_assert(AccessOrderDescriptor); -} - -TEST(ConceptDiagnosticsSync, BlockGemmDescriptor_Valid) -{ - // The BlockGemm type should satisfy the concept - static_assert(BlockGemmDescriptor); -} - -TEST(ConceptDiagnosticsSync, GridwiseWmmaGemmDescriptor_Valid) -{ - // The GridwiseWmmaGemm type should satisfy the concept - static_assert(GridwiseWmmaGemmDescriptor); -} - -// ============================================================================= -// HIGH-LEVEL "SPECIFIES" CONCEPTS TESTS -// ============================================================================= - -TEST(ConceptDiagnosticsSync, SpecifiesThreadBlock_Valid) -{ - static_assert(SpecifiesThreadBlock); -} - -TEST(ConceptDiagnosticsSync, SpecifiesGridwiseFwdXdlGemm_Valid) -{ - static_assert(SpecifiesGridwiseFwdXdlGemm); -} - -TEST(ConceptDiagnosticsSync, SpecifiesGridwiseBwdXdlGemm_Valid) -{ - static_assert(SpecifiesGridwiseBwdXdlGemm); -} - -TEST(ConceptDiagnosticsSync, SpecifiesBlockGemm_Valid) -{ - static_assert(SpecifiesBlockGemm); -} - -TEST(ConceptDiagnosticsSync, SpecifiesFwdConvSpecialization_Valid) -{ - static_assert(SpecifiesFwdConvSpecialization); -} - -TEST(ConceptDiagnosticsSync, SpecifiesBwdWeightConvSpecialization_Valid) -{ - static_assert(SpecifiesBwdWeightConvSpecialization); -} - -TEST(ConceptDiagnosticsSync, SpecifiesGemmSpecialization_Valid) -{ - static_assert(SpecifiesGemmSpecialization); -} - -TEST(ConceptDiagnosticsSync, SpecifiesNumPrefetchStages_Valid) -{ - static_assert(SpecifiesNumPrefetchStages); -} - -TEST(ConceptDiagnosticsSync, SpecifiesLoopScheduler_Valid) -{ - static_assert(SpecifiesLoopScheduler); -} - -// ============================================================================= -// TILE-SPECIFIC CONCEPTS TESTS -// ============================================================================= - -TEST(ConceptDiagnosticsSync, TileThreadBlockDescriptor_Valid) -{ - static_assert(TileThreadBlockDescriptor); -} - -TEST(ConceptDiagnosticsSync, 
TileTransferDescriptor_Valid) -{ - static_assert(TileTransferDescriptor); -} - -TEST(ConceptDiagnosticsSync, TileBlockGemmDescriptor_Valid) -{ - static_assert(TileBlockGemmDescriptor); -} - -TEST(ConceptDiagnosticsSync, TileOptimizationsDescriptor_Valid) -{ - static_assert(TileOptimizationsDescriptor); -} - -TEST(ConceptDiagnosticsSync, SpecifiesTileThreadBlock_Valid) -{ - static_assert(SpecifiesTileThreadBlock); -} - -TEST(ConceptDiagnosticsSync, SpecifiesTileTransfer_Valid) -{ - static_assert(SpecifiesTileTransfer); -} - -TEST(ConceptDiagnosticsSync, SpecifiesTileBlockGemm_Valid) -{ - static_assert(SpecifiesTileBlockGemm); -} - -TEST(ConceptDiagnosticsSync, SpecifiesTileOptimizations_Valid) -{ - static_assert(SpecifiesTileOptimizations); -} - -TEST(ConceptDiagnosticsSync, SpecifiesTileConvSpecialization_Valid) -{ - static_assert(SpecifiesTileConvSpecialization); -} - -// ============================================================================= -// DL-SPECIFIC CONCEPTS TESTS -// ============================================================================= - -TEST(ConceptDiagnosticsSync, DlThreadConfigDescriptor_Valid) -{ - static_assert(DlThreadConfigDescriptor); -} - -TEST(ConceptDiagnosticsSync, DlThreadClusterDescriptor_Valid) -{ - static_assert(DlThreadClusterDescriptor); -} - -TEST(ConceptDiagnosticsSync, DlBlockTransferDescriptor_Valid) -{ - static_assert(DlBlockTransferDescriptor); -} - -TEST(ConceptDiagnosticsSync, DlEpilogueDescriptor_Valid) -{ - static_assert(DlEpilogueDescriptor); -} - -TEST(ConceptDiagnosticsSync, SpecifiesDlThreadConfig_Valid) -{ - static_assert(SpecifiesDlThreadConfig); -} - -TEST(ConceptDiagnosticsSync, SpecifiesDlThreadCluster_Valid) -{ - static_assert(SpecifiesDlThreadCluster); -} - -// ============================================================================= -// INVALID TYPE TESTS - Test that concepts correctly reject invalid types -// ============================================================================= - 
-namespace invalid_types { - -// Test ThreadBlockDescriptor with missing members -struct MissingBlockSize -{ - struct - { - size_t m, n, k; - } tile_size; -}; - -struct MissingTileSizeM -{ - size_t block_size; - struct - { - size_t n, k; - } tile_size; -}; - -// Test GridwiseXdlGemmDescriptor with missing members -struct MissingMPerXdl -{ - size_t n_per_xdl; - size_t m_xdl_per_wave; - size_t n_xdl_per_wave; -}; - -// Test BlockTransferDescriptor with missing members -struct MissingK0 -{ - size_t m_n; - size_t k1; -}; - -// Test LdsTransferDescriptor with missing members -struct MissingSrcVectorDim -{ - size_t src_scalar_per_vector; - size_t lds_dst_scalar_per_vector; - bool is_direct_load; - bool lds_padding; -}; - -} // namespace invalid_types - -TEST(ConceptDiagnosticsSync, ThreadBlockDescriptor_Invalid) -{ - static_assert(!ThreadBlockDescriptor); - static_assert(!ThreadBlockDescriptor); -} - -TEST(ConceptDiagnosticsSync, GridwiseXdlGemmDescriptor_Invalid) -{ - static_assert(!GridwiseXdlGemmDescriptor); -} - -TEST(ConceptDiagnosticsSync, BlockTransferDescriptor_Invalid) -{ - static_assert(!BlockTransferDescriptor); -} - -TEST(ConceptDiagnosticsSync, LdsTransferDescriptor_Invalid) -{ - static_assert(!LdsTransferDescriptor); -} - -// ============================================================================= -// COMPREHENSIVE ALGORITHM TYPE TESTS -// ============================================================================= - -TEST(ConceptDiagnosticsSync, CompleteAlgorithmTypes) -{ - // Test that complete algorithm types satisfy their concepts - static_assert( - ConvAlgorithmDescriptor); - static_assert( - ConvAlgorithmDescriptor); - static_assert( - ConvAlgorithmDescriptor); - static_assert(ConvAlgorithmDescriptor); - static_assert(ConvAlgorithmDescriptor); - - // Test specific requirements for each algorithm type - static_assert(SpecifiesThreadBlock); - static_assert( - SpecifiesGridwiseFwdXdlGemm); - static_assert( - SpecifiesFwdConvSpecialization); - 
static_assert( - SpecifiesNumPrefetchStages); - - static_assert(SpecifiesTileThreadBlock); - static_assert(SpecifiesTileBlockGemm); - static_assert(SpecifiesTileOptimizations); -} - -// ============================================================================= -// DIAGNOSTIC MESSAGE TESTS -// ============================================================================= - -TEST(ConceptDiagnosticsSync, DiagnosticMessages) -{ - // Test that diagnostics can be called (even if messages may be empty at compile-time) - // The key is that the diagnostic functions exist and compile - std::string diag1 = ck_tile::builder::diagnostics::detailed_diagnostic_SpecifiesThreadBlock< - invalid_types::MissingBlockSize>(); - std::string diag2 = - ck_tile::builder::diagnostics::detailed_diagnostic_SpecifiesGridwiseFwdXdlGemm< - invalid_types::MissingMPerXdl>(); - - // These may be empty depending on the implementation, but they should compile - EXPECT_TRUE(diag1.empty() || contains(diag1, "thread_block") || contains(diag1, "missing")); - EXPECT_TRUE(diag2.empty() || contains(diag2, "gridwise_gemm") || contains(diag2, "missing")); -} - -// ============================================================================= -// CONCEPT COMPLETENESS TESTS -// ============================================================================= - -/** - * @brief Verify that all concepts defined in conv_algorithm_concepts.hpp have tests - * - * This test serves as documentation of which concepts are tested. If new concepts - * are added, this test should be updated to include them. 
- */ -TEST(ConceptDiagnosticsSync, ConceptCoverage) -{ - // Basic Descriptor Concepts - verify they all exist and can be instantiated - EXPECT_TRUE((ThreadBlockDescriptor)); - EXPECT_TRUE((GridwiseXdlGemmDescriptor)); - EXPECT_TRUE((BlockGemmDescriptor)); - EXPECT_TRUE((GridwiseWmmaGemmDescriptor)); - EXPECT_TRUE((BlockTransferDescriptor)); - EXPECT_TRUE((ThreadClusterDescriptor)); - EXPECT_TRUE((LdsTransferDescriptor)); - EXPECT_TRUE((EpilogueDescriptor)); - EXPECT_TRUE((AccessOrderDescriptor)); - - // Tile Descriptor Concepts - EXPECT_TRUE((TileThreadBlockDescriptor)); - EXPECT_TRUE((TileTransferDescriptor)); - EXPECT_TRUE((TileBlockGemmDescriptor)); - EXPECT_TRUE((TileOptimizationsDescriptor)); - - // DL Descriptor Concepts - EXPECT_TRUE((DlThreadConfigDescriptor)); - EXPECT_TRUE((DlThreadClusterDescriptor)); - EXPECT_TRUE((DlBlockTransferDescriptor)); - EXPECT_TRUE((DlEpilogueDescriptor)); -} - -} // namespace ck_tile::builder::test - -int main(int argc, char** argv) -{ - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} From 18c2631c79c3ec540d538125a58196a2dce7a96c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Thu, 8 Jan 2026 06:51:57 -0500 Subject: [PATCH 58/81] Fix test after merge. 
--- .../builder/test/conv/ck/unit_instance_to_conv_traits.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/experimental/builder/test/conv/ck/unit_instance_to_conv_traits.cpp b/experimental/builder/test/conv/ck/unit_instance_to_conv_traits.cpp index de2a4fdd142..81a9175dd0b 100644 --- a/experimental/builder/test/conv/ck/unit_instance_to_conv_traits.cpp +++ b/experimental/builder/test/conv/ck/unit_instance_to_conv_traits.cpp @@ -230,7 +230,7 @@ TEST(InstanceToConvTraits, ExtractsDefaultSpecialization) using Traits = ck_tile::reflect::conv::ConvTraits; - EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT); + EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT); } TEST(InstanceToConvTraits, ExtractsFilter1x1Pad0Specialization) @@ -290,7 +290,7 @@ TEST(InstanceToConvTraits, ExtractsFilter1x1Pad0Specialization) using Traits = ck_tile::reflect::conv::ConvTraits; EXPECT_EQ(Traits::conv_specialization, - ck_tile::builder::ConvFwdSpecialization::FILTER_1X1_PAD0); + ck_tile::builder::ConvSpecialization::FILTER_1X1_PAD0); } // ============================================================================ From 0336ac573ea969caefd8bddf5d131791a5fa7409 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Thu, 8 Jan 2026 06:57:34 -0500 Subject: [PATCH 59/81] clang-format --- .../builder/test/conv/ck/unit_instance_to_conv_traits.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/experimental/builder/test/conv/ck/unit_instance_to_conv_traits.cpp b/experimental/builder/test/conv/ck/unit_instance_to_conv_traits.cpp index 81a9175dd0b..9d6fab19d13 100644 --- a/experimental/builder/test/conv/ck/unit_instance_to_conv_traits.cpp +++ b/experimental/builder/test/conv/ck/unit_instance_to_conv_traits.cpp @@ -289,8 +289,7 @@ TEST(InstanceToConvTraits, ExtractsFilter1x1Pad0Specialization) using Traits = ck_tile::reflect::conv::ConvTraits; - 
EXPECT_EQ(Traits::conv_specialization, - ck_tile::builder::ConvSpecialization::FILTER_1X1_PAD0); + EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::FILTER_1X1_PAD0); } // ============================================================================ From 3f0bac4e7b5e16170104be483663e1f26e1f6bf7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Thu, 8 Jan 2026 09:18:37 -0500 Subject: [PATCH 60/81] Fix conv algorithm types after refactoring. --- .../test/impl/conv_algorithm_types.hpp | 24 ++++--------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index c8cf809d3aa..2e0c79ec5c3 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -565,7 +565,7 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle = using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle = ConvAlgorithmTemplate, + InputOutputTileTransfer_<>, ConvSpecializationBwdWeight_, GemmPipeline_, Prefetch_>; @@ -610,32 +610,16 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle = ConvSpecializationBwdWeight_, MultipleDSpecialization_>; -using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 = - ConvAlgorithmTemplate, - ConvSpecializationBwdWeight_, - BlockGemm_, - TransposeParams_>; - using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 = ConvAlgorithmTemplate, + WarpGemm_, + InputOutputTileTransfer_<>, ConvSpecializationBwdWeight_, - BlockGemm_, + GemmPipeline_, TransposeParams_, GemmBatchOptions_, TwoStageSpecialization_>; -using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle = - ConvAlgorithmTemplate, - ConvSpecializationBwdWeight_, - GridGemm_, - Prefetch_>; - using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3 = ConvAlgorithmTemplate Date: Fri, 9 Jan 2026 
09:17:45 -0500 Subject: [PATCH 61/81] Adapt factories to warp GEMM and transfer parameters refactoring. --- .../builder/conv_algorithm_concepts.hpp | 20 +++++++------- ...onv_bwd_weight_multi_d_wmma_v3_factory.hpp | 17 +++++++----- .../conv_bwd_weight_multi_d_xdl_factory.hpp | 18 ++++++++----- ...v_bwd_weight_two_stage_wmma_v3_factory.hpp | 17 +++++++----- .../conv_bwd_weight_two_stage_xdl_factory.hpp | 18 ++++++++----- .../factory/conv_bwd_weight_wmma_factory.hpp | 17 +++++++----- .../conv_bwd_weight_wmma_v3_factory.hpp | 17 +++++++----- .../factory/conv_bwd_weight_xdl_factory.hpp | 19 ++++++++----- .../conv_bwd_weight_xdl_v3_factory.hpp | 18 ++++++++----- .../factory/conv_fwd_large_tensor_factory.hpp | 15 +++++------ .../builder/factory/conv_fwd_v3_factory.hpp | 21 +++++++-------- .../builder/factory/conv_fwd_wmma_factory.hpp | 17 +++++++----- .../builder/factory/conv_fwd_xdl_factory.hpp | 15 +++++------ .../helpers/ck/conv_block_transfer.hpp | 27 +++++++++++-------- .../factory/helpers/ck/conv_tuning_params.hpp | 2 +- .../test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp | 2 +- 16 files changed, 152 insertions(+), 108 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index cbc277a8814..d036ab7ec68 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -172,8 +172,8 @@ concept SpecifiesTileThreadBlock = requires { // Concept to check if a struct specifies warp GEMM info. template -concept SpecifiesWarpGemm = requires(T t) { - { t.warp_gemm } -> WarpGemmDescriptor; +concept SpecifiesWarpGemm = requires { + { T::warp_gemm } -> WarpGemmDescriptor; }; // Concept to check if a struct specifies convolution input and output block transfer info. 
@@ -212,8 +212,8 @@ concept SpecifiesLdsTransfer = requires(T t) { // Concept to check if a struct specifies thread cluster access order info. template concept SpecifiesThreadClusterAccessOrder = requires(T t) { - { T::transfer.a.block_transfer_access_order } -> AccessOrderDescriptor; - { T::transfer.b.block_transfer_access_order } -> AccessOrderDescriptor; + { T::transfer.a.thread_distribution_access_order } -> AccessOrderDescriptor; + { T::transfer.b.thread_distribution_access_order } -> AccessOrderDescriptor; }; // Concept to check if a struct specifies source access order info. @@ -341,15 +341,15 @@ concept SpecifiesMultipleDSupport = requires { }; template -concept SpecifiesXdl = requires { - { T::warp_gemm.matrix_instruction } -> std::convertible_to; - requires T::warp_gemm.matrix_instruction == MatrixInstructionType::XDL; +concept SpecifiesXdl = requires (T t){ + { t.warp_gemm.matrix_instruction } -> std::convertible_to; + { t.warp_gemm.matrix_instruction == MatrixInstructionType::XDL}; }; template -concept SpecifiesWmma = requires { - { T::warp_gemm.matrix_instruction } -> std::convertible_to; - requires T::warp_gemm.matrix_instruction == MatrixInstructionType::WMMA; +concept SpecifiesWmma = requires (T t){ + { t.warp_gemm.matrix_instruction } -> std::convertible_to; + { t.warp_gemm.matrix_instruction == MatrixInstructionType::WMMA}; }; /******************************************** */ diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp index 6f5c679b595..234ba398296 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp @@ -35,7 +35,7 @@ struct ConvBwdWeightMultiDWmmaV3Factory internal::SetBwdWeightConvSpecialization(); static constexpr auto BLOCK 
= internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto WARP_GEMM = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -57,6 +57,11 @@ struct ConvBwdWeightMultiDWmmaV3Factory static_assert(AccessOrderLimits3D, "Invalid B source access order"); + static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size == + B_BLOCK_TRANSFER.global_memory_vector_load_size, + "A nd B block transfer vector load size need to be the same"); + static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size; + // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< @@ -78,11 +83,11 @@ struct ConvBwdWeightMultiDWmmaV3Factory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.k1, - GRIDWISE_GEMM.m_per_wmma, - GRIDWISE_GEMM.n_per_wmma, - GRIDWISE_GEMM.m_wmma_per_wave, - GRIDWISE_GEMM.n_wmma_per_wave, + GMEM_VECTOR_LOAD_SIZE, + WARP_GEMM.gemm_m_per_instruction, + WARP_GEMM.gemm_n_per_instruction, + WARP_GEMM.gemm_m_iters_per_wave, + WARP_GEMM.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp index 9f76568ca89..843c4e0f90c 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp @@ -35,8 +35,7 @@ struct ConvBwdWeightMultiDXdlFactory internal::SetBwdWeightConvSpecialization(); static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; - static constexpr auto 
XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -53,6 +52,11 @@ struct ConvBwdWeightMultiDXdlFactory static_assert(AccessOrderLimits4D); static_assert(AccessOrderLimits4D); + static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size == + B_BLOCK_TRANSFER.global_memory_vector_load_size, + "A nd B block transfer vector load size need to be the same"); + static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size; + // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< SPATIAL_DIM, @@ -73,11 +77,11 @@ struct ConvBwdWeightMultiDXdlFactory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.k1, - XDL_PARAMS.m_per_xdl, - XDL_PARAMS.n_per_xdl, - XDL_PARAMS.m_xdl_per_wave, - XDL_PARAMS.n_xdl_per_wave, + GMEM_VECTOR_LOAD_SIZE, + XDL_PARAMS.gemm_m_per_instruction, + XDL_PARAMS.gemm_n_per_instruction, + XDL_PARAMS.gemm_m_iters_per_wave, + XDL_PARAMS.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp index 86c48fe3220..48a15a16385 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp @@ -35,7 +35,7 @@ struct ConvBwdWeightTwoStageWmmaV3Factory internal::SetBwdWeightConvSpecialization(); static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto WARP_GEMM = ALGORITHM.warp_gemm; static 
constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -57,6 +57,11 @@ struct ConvBwdWeightTwoStageWmmaV3Factory static_assert(AccessOrderLimits3D, "Invalid B source access order"); + static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size == + B_BLOCK_TRANSFER.global_memory_vector_load_size, + "A nd B block transfer vector load size need to be the same"); + static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size; + // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< @@ -76,11 +81,11 @@ struct ConvBwdWeightTwoStageWmmaV3Factory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.k1, - GRIDWISE_GEMM.m_per_wmma, - GRIDWISE_GEMM.n_per_wmma, - GRIDWISE_GEMM.m_wmma_per_wave, - GRIDWISE_GEMM.n_wmma_per_wave, + GMEM_VECTOR_LOAD_SIZE, + WARP_GEMM.gemm_m_per_instruction, + WARP_GEMM.gemm_n_per_instruction, + WARP_GEMM.gemm_m_iters_per_wave, + WARP_GEMM.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp index 9c37beae46c..5eea36313f7 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp @@ -35,8 +35,7 @@ struct ConvBwdWeightTwoStageXdlFactory internal::SetBwdWeightConvSpecialization(); static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; - static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = 
internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -58,6 +57,11 @@ struct ConvBwdWeightTwoStageXdlFactory static_assert(AccessOrderLimits3D, "Invalid B source access order"); + static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size == + B_BLOCK_TRANSFER.global_memory_vector_load_size, + "A nd B block transfer vector load size need to be the same"); + static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size; + // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< SPATIAL_DIM, @@ -76,11 +80,11 @@ struct ConvBwdWeightTwoStageXdlFactory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.k1, - XDL_PARAMS.m_per_xdl, - XDL_PARAMS.n_per_xdl, - XDL_PARAMS.m_xdl_per_wave, - XDL_PARAMS.n_xdl_per_wave, + GMEM_VECTOR_LOAD_SIZE, + XDL_PARAMS.gemm_m_per_instruction, + XDL_PARAMS.gemm_n_per_instruction, + XDL_PARAMS.gemm_m_iters_per_wave, + XDL_PARAMS.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp index 32161a234ae..8e22958ac13 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp @@ -35,7 +35,7 @@ struct ConvBwdWeightWmmaFactory internal::SetBwdWeightConvSpecialization(); static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto WARP_GEMM = ALGORITHM.warp_gemm; static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION = internal::SetGridwiseGemmPipelineVersion(); static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); @@ -60,6 +60,11 @@ struct 
ConvBwdWeightWmmaFactory static_assert(AccessOrderLimits3D, "Invalid B source access order"); + static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size == + B_BLOCK_TRANSFER.global_memory_vector_load_size, + "A nd B block transfer vector load size need to be the same"); + static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size; + // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle< SPATIAL_DIM, @@ -78,11 +83,11 @@ struct ConvBwdWeightWmmaFactory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.k1, - GRIDWISE_GEMM.m_per_wmma, - GRIDWISE_GEMM.n_per_wmma, - GRIDWISE_GEMM.m_wmma_per_wave, - GRIDWISE_GEMM.n_wmma_per_wave, + GMEM_VECTOR_LOAD_SIZE, + WARP_GEMM.gemm_m_per_instruction, + WARP_GEMM.gemm_n_per_instruction, + WARP_GEMM.gemm_m_iters_per_wave, + WARP_GEMM.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp index baf84402c34..463749958a6 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp @@ -35,7 +35,7 @@ struct ConvBwdWeightWmmaV3Factory internal::SetBwdWeightConvSpecialization(); static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto WARP_GEMM = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -57,6 +57,11 @@ struct ConvBwdWeightWmmaV3Factory static_assert(AccessOrderLimits3D, "Invalid B source access order"); + static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size 
== + B_BLOCK_TRANSFER.global_memory_vector_load_size, + "A nd B block transfer vector load size need to be the same"); + static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size; + // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< SPATIAL_DIM, @@ -75,11 +80,11 @@ struct ConvBwdWeightWmmaV3Factory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.k1, - GRIDWISE_GEMM.m_per_wmma, - GRIDWISE_GEMM.n_per_wmma, - GRIDWISE_GEMM.m_wmma_per_wave, - GRIDWISE_GEMM.n_wmma_per_wave, + GMEM_VECTOR_LOAD_SIZE, + WARP_GEMM.gemm_m_per_instruction, + WARP_GEMM.gemm_n_per_instruction, + WARP_GEMM.gemm_m_iters_per_wave, + WARP_GEMM.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp index 91c19d2bd0d..ba5fdb2c536 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp @@ -35,8 +35,7 @@ struct ConvBwdWeightXdlFactory internal::SetBwdWeightConvSpecialization(); static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; - static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -53,6 +52,12 @@ struct ConvBwdWeightXdlFactory static_assert(AccessOrderLimits4D); static_assert(AccessOrderLimits4D); + static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size == + B_BLOCK_TRANSFER.global_memory_vector_load_size, + "A nd B block transfer vector load size need to be 
the same"); + static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size; + + // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle< SPATIAL_DIM, @@ -71,11 +76,11 @@ struct ConvBwdWeightXdlFactory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.k1, - XDL_PARAMS.m_per_xdl, - XDL_PARAMS.n_per_xdl, - XDL_PARAMS.m_xdl_per_wave, - XDL_PARAMS.n_xdl_per_wave, + GMEM_VECTOR_LOAD_SIZE, + XDL_PARAMS.gemm_m_per_instruction, + XDL_PARAMS.gemm_n_per_instruction, + XDL_PARAMS.gemm_m_iters_per_wave, + XDL_PARAMS.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp index f3edd0e6d93..ab4dbea2f4c 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp @@ -35,8 +35,7 @@ struct ConvBwdWeightXdlV3Factory internal::SetBwdWeightConvSpecialization(); static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; - static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -58,6 +57,11 @@ struct ConvBwdWeightXdlV3Factory static_assert(AccessOrderLimits3D, "Invalid B source access order"); + static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size == + B_BLOCK_TRANSFER.global_memory_vector_load_size, + "A nd B block transfer vector load size need to be the same"); + static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size; + 
// The backward-weight convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< SPATIAL_DIM, @@ -76,11 +80,11 @@ struct ConvBwdWeightXdlV3Factory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.k1, - XDL_PARAMS.m_per_xdl, - XDL_PARAMS.n_per_xdl, - XDL_PARAMS.m_xdl_per_wave, - XDL_PARAMS.n_xdl_per_wave, + GMEM_VECTOR_LOAD_SIZE, + XDL_PARAMS.gemm_m_per_instruction, + XDL_PARAMS.gemm_n_per_instruction, + XDL_PARAMS.gemm_m_iters_per_wave, + XDL_PARAMS.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp index 9cd56ad7ad2..fdb95d602a8 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp @@ -38,8 +38,7 @@ struct ConvFwdLargeTensorFactory static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; - static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -79,12 +78,12 @@ struct ConvFwdLargeTensorFactory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.ak1, - GRIDWISE_GEMM.bk1, - XDL_PARAMS.m_per_xdl, - XDL_PARAMS.n_per_xdl, - XDL_PARAMS.m_xdl_per_wave, - XDL_PARAMS.n_xdl_per_wave, + A_BLOCK_TRANSFER.global_memory_vector_load_size, + B_BLOCK_TRANSFER.global_memory_vector_load_size, + XDL_PARAMS.gemm_m_per_instruction, + XDL_PARAMS.gemm_n_per_instruction, + XDL_PARAMS.gemm_m_iters_per_wave, + 
XDL_PARAMS.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp index 7d889a0c01b..a64929d1581 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp @@ -31,19 +31,18 @@ struct ConvFwdXdlV3Factory using Ops = internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); - static_assert(ALGORITHM.transfer.a.lds_transfer.is_direct_load == - ALGORITHM.transfer.b.lds_transfer.is_direct_load, + static_assert(ALGORITHM.transfer.a.lds_transfer_params.is_direct_load == + ALGORITHM.transfer.b.lds_transfer_params.is_direct_load, "A and B block transfers must both be direct load or not."); - static constexpr bool IS_DIRECT_LOAD = ALGORITHM.transfer.a.lds_transfer.is_direct_load; + static constexpr bool IS_DIRECT_LOAD = ALGORITHM.transfer.a.lds_transfer_params.is_direct_load; static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization(); static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization(); static constexpr internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION, .gemm_spec = GEMM_SPECIALIZATION}; static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; - static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -83,12 +82,12 @@ struct ConvFwdXdlV3Factory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.ak1, - GRIDWISE_GEMM.bk1, - XDL_PARAMS.m_per_xdl, - XDL_PARAMS.n_per_xdl, - XDL_PARAMS.m_xdl_per_wave, - XDL_PARAMS.n_xdl_per_wave, + 
A_BLOCK_TRANSFER.global_memory_vector_load_size, + B_BLOCK_TRANSFER.global_memory_vector_load_size, + XDL_PARAMS.gemm_m_per_instruction, + XDL_PARAMS.gemm_n_per_instruction, + XDL_PARAMS.gemm_m_iters_per_wave, + XDL_PARAMS.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp index 3506f5d1a92..d52f684d8c0 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp @@ -38,7 +38,7 @@ struct ConvFwdWmmaFactory static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto WARP_GEMM = ALGORITHM.warp_gemm; static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION = internal::SetGridwiseGemmPipelineVersion(); static constexpr auto A_BLOCK_TRANSFER = @@ -57,6 +57,11 @@ struct ConvFwdWmmaFactory static_assert(AccessOrderLimits3D); static_assert(AccessOrderLimits3D); + static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size == + B_BLOCK_TRANSFER.global_memory_vector_load_size, + "A and B block transfer vector load sizes need to be the same"); + static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size; + + // The forward convolution kernel class instance. 
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< SPATIAL_DIM, @@ -80,11 +85,11 @@ struct ConvFwdWmmaFactory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.k1, - GRIDWISE_GEMM.m_per_wmma, - GRIDWISE_GEMM.n_per_wmma, - GRIDWISE_GEMM.m_wmma_per_wave, - GRIDWISE_GEMM.n_wmma_per_wave, + GMEM_VECTOR_LOAD_SIZE, + WARP_GEMM.gemm_m_per_instruction, + WARP_GEMM.gemm_n_per_instruction, + WARP_GEMM.gemm_m_iters_per_wave, + WARP_GEMM.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp index 446ceceda25..eb2fdfad4d3 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp @@ -38,8 +38,7 @@ struct ConvFwdXdlFactory static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; - static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -79,12 +78,12 @@ struct ConvFwdXdlFactory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.ak1, - GRIDWISE_GEMM.bk1, - XDL_PARAMS.m_per_xdl, - XDL_PARAMS.n_per_xdl, - XDL_PARAMS.m_xdl_per_wave, - XDL_PARAMS.n_xdl_per_wave, + A_BLOCK_TRANSFER.global_memory_vector_load_size, + B_BLOCK_TRANSFER.global_memory_vector_load_size, + XDL_PARAMS.gemm_m_per_instruction, + XDL_PARAMS.gemm_n_per_instruction, + XDL_PARAMS.gemm_m_iters_per_wave, + XDL_PARAMS.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git 
a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp index 4ef2f533c98..dfe35355adf 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp @@ -15,6 +15,7 @@ struct BlockTransfer ck::Array thread_cluster_dims{}; // k0, m, k1 ck::Array thread_cluster_order{}; ck::Array src_access_order{}; + size_t global_memory_vector_load_size = 0; size_t src_vector_dim = 0; size_t src_scalar_per_vector = 0; size_t lds_dst_scalar_per_vector = 0; @@ -28,6 +29,7 @@ struct BwdBlockTransfer ck::Array thread_cluster_dims{}; ck::Array thread_cluster_order{}; ck::Array src_access_order{}; + size_t global_memory_vector_load_size = 0; size_t src_vector_dim = 0; size_t src_scalar_per_vector = 0; size_t lds_dst_scalar_per_vector = 0; @@ -37,15 +39,16 @@ struct BwdBlockTransfer template constexpr BlockTransfer SetFwdConvBlockTransfer() { - auto& block_xfer = TRANSFER.block_transfer; - auto& block_order = TRANSFER.block_transfer_access_order; + auto& block_xfer = TRANSFER.thread_distribution; + auto& block_order = TRANSFER.thread_distribution_access_order; auto& src_order = TRANSFER.src_access_order; - auto& lds_cfg = TRANSFER.lds_transfer; + auto& lds_cfg = TRANSFER.lds_transfer_params; return BlockTransfer{ .thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1}, .thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2]}, .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]}, + .global_memory_vector_load_size = lds_cfg.global_memory_vector_load_size, .src_vector_dim = lds_cfg.src_vector_dim, .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, @@ -57,10 +60,10 @@ constexpr BlockTransfer 
SetFwdConvBlockTransfer() template constexpr auto SetBwdConvBlockTransfer() { - auto& block_xfer = TRANSFER.block_transfer; - auto& block_order = TRANSFER.block_transfer_access_order; + auto& block_xfer = TRANSFER.thread_distribution; + auto& block_order = TRANSFER.thread_distribution_access_order; auto& src_order = TRANSFER.src_access_order; - auto& lds_cfg = TRANSFER.lds_transfer; + auto& lds_cfg = TRANSFER.lds_transfer_params; constexpr auto array_length = block_order.order.size(); static_assert(block_order.order.size() == src_order.order.size(), @@ -74,6 +77,7 @@ constexpr auto SetBwdConvBlockTransfer() block_order.order[1], block_order.order[2]}, .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]}, + .global_memory_vector_load_size = lds_cfg.global_memory_vector_load_size, .src_vector_dim = lds_cfg.src_vector_dim, .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, @@ -95,6 +99,7 @@ constexpr auto SetBwdConvBlockTransfer() src_order.order[1], src_order.order[2], src_order.order[3]}, + .global_memory_vector_load_size = lds_cfg.global_memory_vector_load_size, .src_vector_dim = lds_cfg.src_vector_dim, .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, @@ -119,17 +124,17 @@ struct CBlockTransfer template constexpr CBlockTransfer SetCBlockTransfer() { - auto& thread_cluster_dims = ALGORITHM.transfer.c.thread_cluster_dims; + auto& thread_cluster_dims = ALGORITHM.transfer.c.thread_distribution; auto& epilogue_config = ALGORITHM.transfer.c.epilogue; return CBlockTransfer{ .m_xdl_per_wave_per_shuffle = epilogue_config.m_xdl_per_wave_per_shuffle, .n_xdl_per_wave_per_shuffle = epilogue_config.n_per_wave_per_shuffle, .thread_cluster_dims = { - thread_cluster_dims.m_block, - thread_cluster_dims.m_wave_per_xdl, - thread_cluster_dims.n_block, - thread_cluster_dims.n_wave_per_xdl, + 
thread_cluster_dims.gemm_m_block_size, + thread_cluster_dims.gemm_m_per_block, + thread_cluster_dims.gemm_n_block_size, + thread_cluster_dims.gemm_n_per_block, }, .scalar_per_vector = epilogue_config.scalar_per_vector, }; diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp index 9ed1eebc3c0..29cf3f8513a 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp @@ -38,7 +38,7 @@ struct BlockGemmSpec template consteval BlockGemmSpec SetBlockGemm() { - constexpr auto& BG = ALGORITHM.block_gemm_pipeline; + constexpr auto& BG = ALGORITHM.gemm_pipeline; ck::BlockGemmPipelineScheduler scheduler; ck::BlockGemmPipelineVersion version; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp index d087024913a..8ee12a46ba9 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp @@ -33,7 +33,7 @@ TEST(FwdConvInstances, .with_gemm_config(GemmParams_Wmma_2x1_per_wave) .with_transfer(Transfer_4x32x1) .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, PipelineScheduler::DEFAULT) + .with_prefetch_config(1, PipelineScheduler::INTRAWAVE) .with_num_conv_groups_to_merge(2) .with_gemm_pipeline(PipelineVersion::V1); From 63fc27b0b1159a62bc2e8a59b12b51fbf6d010da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Fri, 9 Jan 2026 10:05:54 -0500 Subject: [PATCH 62/81] Refactor algorithm specialization and GEMM pipeline definitions. 
--- .../builder/conv_algorithm_concepts.hpp | 36 +++--- .../builder/factory/conv_algorithms.hpp | 25 ++-- .../conv_bwd_weight_two_stage_xdl_factory.hpp | 2 +- .../factory/helpers/ck/conv_tuning_params.hpp | 8 +- .../builder/include/ck_tile/builder/types.hpp | 31 ++++- ...conv_bwd_weight_two_stage_xdl_cshuffle.cpp | 2 +- ...test_ckb_conv_bwd_weight_wmma_cshuffle.cpp | 3 +- .../conv/ck/test_ckb_conv_fwd_1d_fp16.cpp | 1 - .../test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp | 4 +- ...est_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp | 1 - .../test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp | 1 - ...test_ckb_conv_fwd_2d_large_tensor_fp16.cpp | 2 - .../test/impl/conv_algorithm_types.hpp | 116 +++++++----------- .../test/utils/conv_algorithm_type_utils.hpp | 38 +----- 14 files changed, 120 insertions(+), 150 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index d036ab7ec68..946b485bdba 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -282,16 +282,6 @@ concept SpecifiesNumGroupsToMerge = requires { { T::num_conv_groups_to_merge } -> SizeType; }; -template -concept SpecifiesLoopScheduler = requires { - { T::loop_scheduler } -> std::convertible_to; -}; - -template -concept SpecifiesGenericInstance = !requires { - { T::specialization }; -}; - template concept SpecifiesTransposeTransfer = requires { { T::max_transpose_transfer_src_scalar_per_vector } -> SizeType; @@ -308,10 +298,6 @@ template concept TransposeTransferWellDefinedIfProvided = !HasTransposeTransfer || SpecifiesTransposeTransfer; -template -concept SpecifiesGemmBatchOptions = requires { - { T::num_conv_groups_to_merge } -> SizeType; -}; /******************************************** */ /* Algorithm specialization concepts */ @@ -319,25 +305,39 @@ concept SpecifiesGemmBatchOptions = requires { 
template concept SpecifiesLargeTensorSupport = requires { { T::specialization } -> std::convertible_to; - requires T::specialization == ConvAlgorithmSpecialization::LARGE_TENSOR; + requires !!(T::specialization & ConvAlgorithmSpecialization::LARGE_TENSOR); }; template concept SpecifiesReferenceAlgorithm = requires { { T::specialization } -> std::convertible_to; - requires T::specialization == ConvAlgorithmSpecialization::REFERENCE; + requires !!(T::specialization & ConvAlgorithmSpecialization::REFERENCE); }; template concept SpecifiesTwoStageSupport = requires { { T::specialization } -> std::convertible_to; - requires T::specialization == ConvAlgorithmSpecialization::TWO_STAGE; + requires !!(T::specialization & ConvAlgorithmSpecialization::TWO_STAGE); }; template concept SpecifiesMultipleDSupport = requires { { T::specialization } -> std::convertible_to; - requires T::specialization == ConvAlgorithmSpecialization::MULTIPLE_D; + requires !!(T::specialization & ConvAlgorithmSpecialization::MULTIPLE_D); +}; + +template +concept SpecifiesPipelineV3 = requires { + { T::specialization } -> std::convertible_to; + requires !!(T::specialization & ConvAlgorithmSpecialization::PIPELINE_V3); +}; + +template +concept SpecifiesGenericInstance = !requires { + { T::specialization }; +} || requires { + { T::specialization } -> std::convertible_to; + requires !!(T::specialization == ConvAlgorithmSpecialization::NONE); }; template diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index ff7e54546d7..81bafa158dc 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -13,8 +13,7 @@ concept FwdXdlAlgorithmBase = ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesThreadDistribution3D && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && 
SpecifiesSourceAccessOrder && SpecifiesWarpGemm && - SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && - SpecifiesNumPrefetchStages && SpecifiesNumGroupsToMerge && SpecifiesLoopScheduler && + SpecifiesFwdConvSpecialization && SpecifiesGemmPipeline && SpecifiesXdl; template @@ -29,7 +28,8 @@ concept BwdXdlV3AlgorithmBase = ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesThreadDistribution3D && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && SpecifiesWarpGemm && - SpecifiesBwdWeightConvSpecialization && SpecifiesGemmPipeline && SpecifiesXdl; + SpecifiesBwdWeightConvSpecialization && SpecifiesGemmPipeline && SpecifiesXdl && + SpecifiesPipelineV3; template concept BwdWmmaAlgorithmBase = @@ -43,7 +43,8 @@ concept BwdWmmaV3AlgorithmBase = ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesThreadDistribution3D && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && SpecifiesWarpGemm && - SpecifiesBwdWeightConvSpecialization && SpecifiesGemmPipeline && SpecifiesWmma; + SpecifiesBwdWeightConvSpecialization && SpecifiesGemmPipeline && SpecifiesWmma && + SpecifiesPipelineV3; // Reference algorithm concept template @@ -67,7 +68,8 @@ concept FwdXdlV3Algorithm = ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesThreadDistribution3D && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && SpecifiesWarpGemm && - SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && SpecifiesGemmPipeline && SpecifiesXdl; + SpecifiesFwdConvSpecialization && SpecifiesGemmPipeline && SpecifiesXdl && + SpecifiesPipelineV3; // FWD WMMA algorithm concepts template @@ -75,8 +77,7 @@ concept FwdWmmaAlgorithm = ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesThreadDistribution3D && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && SpecifiesWarpGemm && - SpecifiesFwdConvSpecialization 
&& SpecifiesGemmSpecialization && - SpecifiesNumPrefetchStages && SpecifiesLoopScheduler && SpecifiesGemmPipeline && SpecifiesWmma; + SpecifiesFwdConvSpecialization && SpecifiesGemmPipeline && SpecifiesWmma; // FWD DL algorithms template @@ -94,17 +95,15 @@ template concept BwdMultiDXdlAlgorithm = BwdXdlAlgorithmBase && SpecifiesMultipleDSupport; template -concept BwdXdlV3Algorithm = BwdXdlV3AlgorithmBase && SpecifiesGenericInstance; +concept BwdXdlV3Algorithm = BwdXdlV3AlgorithmBase; template -concept BwdTwoStageXdlAlgorithm = BwdXdlV3AlgorithmBase && SpecifiesTransposeTransfer && - SpecifiesGemmBatchOptions && SpecifiesTwoStageSupport; +concept BwdTwoStageXdlAlgorithm = BwdXdlV3AlgorithmBase && SpecifiesTransposeTransfer && SpecifiesTwoStageSupport; // BWD weight WMMA algorithm concepts template concept BwdWmmaAlgorithm = - BwdWmmaAlgorithmBase && SpecifiesNumPrefetchStages && SpecifiesLoopScheduler && - SpecifiesGemmPipeline && SpecifiesGenericInstance; + BwdWmmaAlgorithmBase && SpecifiesNumPrefetchStages && SpecifiesGemmPipeline && SpecifiesGenericInstance; template concept BwdMultiDWmmaV3Algorithm = BwdWmmaV3AlgorithmBase && SpecifiesMultipleDSupport; @@ -115,7 +114,7 @@ concept BwdWmmaV3Algorithm = template concept BwdTwoStageWmmaV3Algorithm = BwdWmmaV3AlgorithmBase && SpecifiesTransposeTransfer && - SpecifiesGemmBatchOptions && SpecifiesTwoStageSupport; + SpecifiesTwoStageSupport; // BWD weight DL algorithms template diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp index 5eea36313f7..764b0617965 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp @@ -105,7 +105,7 @@ struct ConvBwdWeightTwoStageXdlFactory C_BLOCK_TRANSFER.scalar_per_vector, 
BLOCK_GEMM.scheduler, BLOCK_GEMM.pipeline_version, - ALGORITHM.num_conv_groups_to_merge, + BLOCK_GEMM.num_conv_groups_to_merge, typename Types::OutComputeType, typename Types::InComputeType, ALGORITHM.max_transpose_transfer_src_scalar_per_vector, diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp index 29cf3f8513a..b64b7d336e8 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp @@ -31,6 +31,8 @@ ConvSpec(CONV_ENUM, GEMM_ENUM) -> ConvSpec; struct BlockGemmSpec { + size_t num_conv_groups_to_merge{1}; + size_t num_gemm_k_prefetch_stages{1}; ck::BlockGemmPipelineVersion pipeline_version; ck::BlockGemmPipelineScheduler scheduler; }; @@ -63,7 +65,11 @@ consteval BlockGemmSpec SetBlockGemm() default: throw "Unknown PipelineVersion"; } - return BlockGemmSpec{.pipeline_version = version, .scheduler = scheduler}; + return BlockGemmSpec{ + .num_conv_groups_to_merge = BG.num_conv_groups_to_merge, + .num_gemm_k_prefetch_stages = BG.num_gemm_k_prefetch_stages, + .pipeline_version = version, + .scheduler = scheduler}; } template diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index 91952c21a94..af9cc44115b 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -232,12 +232,35 @@ enum class PipelineScheduler enum class ConvAlgorithmSpecialization { - LARGE_TENSOR, - REFERENCE, // GPU reference implementation for validation, - TWO_STAGE, - MULTIPLE_D + NONE = 0, + LARGE_TENSOR = 1 << 0, + REFERENCE = 1 << 1, // GPU reference implementation for validation, + TWO_STAGE = 1 << 2, + MULTIPLE_D = 1 << 3, + PIPELINE_V3 = 1 << 4 }; +constexpr 
ConvAlgorithmSpecialization operator|(ConvAlgorithmSpecialization lhs, + ConvAlgorithmSpecialization rhs) +{ + using T = std::underlying_type_t; + return static_cast(static_cast(lhs) | static_cast(rhs)); +} + +constexpr ConvAlgorithmSpecialization operator&(ConvAlgorithmSpecialization lhs, + ConvAlgorithmSpecialization rhs) +{ + using T = std::underlying_type_t; + return static_cast(static_cast(lhs) & static_cast(rhs)); +} + +// Enable direct boolean conversion for flag checks +constexpr bool operator!(ConvAlgorithmSpecialization spec) +{ + using T = std::underlying_type_t; + return static_cast(spec) == 0; +} + enum class MatrixInstructionType { XDL, diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp index bda064c2918..57b85fa8c78 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp @@ -19,7 +19,7 @@ constexpr auto SIGNATURE = .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; -constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_CShuffle{} +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_CShuffle_V3{} .with_thread_block(cku::ThreadBlock_64_32x32x32) .with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave) .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp index a559d3ee47c..562599da936 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp @@ -24,8 +24,7 @@ constexpr auto ALGORITHM = 
cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CS .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) - .with_prefetch_config(1, ckb::PipelineScheduler::DEFAULT) - .with_gemm_pipeline(ckb::PipelineVersion::V1); + .with_gemm_pipeline(ckb::PipelineVersion::V1, ckb::PipelineScheduler::DEFAULT); using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp index d3ace110c4b..abeca068e71 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp @@ -31,7 +31,6 @@ TEST(FwdConvInstances, .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) .with_transfer(Transfer_4x16x1) .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, PipelineScheduler::DEFAULT) .with_num_conv_groups_to_merge(2); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp index 8ee12a46ba9..7ef33469a0d 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp @@ -33,9 +33,9 @@ TEST(FwdConvInstances, .with_gemm_config(GemmParams_Wmma_2x1_per_wave) .with_transfer(Transfer_4x32x1) .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, PipelineScheduler::INTRAWAVE) .with_num_conv_groups_to_merge(2) - .with_gemm_pipeline(PipelineVersion::V1); + .with_with_num_gemm_k_prefetch_stages(3) + .with_gemm_pipeline(PipelineVersion::V1, PipelineScheduler::INTRAWAVE); using Builder = ConvBuilder; diff --git 
a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp index 23edef54369..395640dcf1f 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp @@ -36,7 +36,6 @@ TEST(FwdConvInstances, .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) .with_transfer(Transfer_4x16x1) .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, PipelineScheduler::DEFAULT) .with_num_conv_groups_to_merge(1); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp index b117e693fe3..85a87fc3a28 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp @@ -31,7 +31,6 @@ TEST(FwdConvInstances, .with_gemm_config(FwdGemmParams_Xdl_4x2_per_wave) .with_transfer(Transfer_4x64x1_fp8) .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, PipelineScheduler::DEFAULT) .with_num_conv_groups_to_merge(1); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp index 97bc0a00e5d..a4dd6171d88 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp @@ -30,7 +30,6 @@ TEST(FwdConvInstances, .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) .with_transfer(Transfer_4x16x1) .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, PipelineScheduler::DEFAULT) 
.with_num_conv_groups_to_merge(1); using Builder = ConvBuilder; @@ -67,7 +66,6 @@ TEST( .with_transfer(Transfer_4x16x1) .with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, PipelineScheduler::DEFAULT) .with_num_conv_groups_to_merge(1); using Builder = ConvBuilder; diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 2e0c79ec5c3..8c1c3f898ce 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -40,6 +40,8 @@ static_assert(ckb::WarpGemmDescriptor); struct GemmPipeline { + size_t num_gemm_k_prefetch_stages{1}; + size_t num_conv_groups_to_merge{1}; PipelineVersion pipeline_version; PipelineScheduler scheduler{PipelineScheduler::DEFAULT}; }; @@ -195,23 +197,12 @@ struct ConvSpecializationBwdWeight_ ConvSpecialization bwd_weight_specialization; }; -struct Prefetch_ -{ - size_t num_gemm_k_prefetch_stages; - PipelineScheduler loop_scheduler; -}; - struct TransposeParams_ { size_t max_transpose_transfer_src_scalar_per_vector{1}; size_t max_transpose_transfer_dst_scalar_per_vector{1}; }; -struct GemmBatchOptions_ -{ - size_t num_conv_groups_to_merge{1}; -}; - struct GemmPipeline_ { GemmPipeline gemm_pipeline; @@ -241,22 +232,10 @@ struct DlTransfer_ DlTransfer transfer; }; -struct TwoStageSpecialization_ -{ - static constexpr ConvAlgorithmSpecialization specialization = - ConvAlgorithmSpecialization::TWO_STAGE; -}; - -struct MultipleDSpecialization_ +template +struct AlgorithmSpecialization_ { - static constexpr ConvAlgorithmSpecialization specialization = - ConvAlgorithmSpecialization::MULTIPLE_D; -}; - -struct LargeTensorSpecialization_ -{ - static constexpr ConvAlgorithmSpecialization specialization = - ConvAlgorithmSpecialization::LARGE_TENSOR; + static constexpr ConvAlgorithmSpecialization specialization = Specialization; }; // Specify thread 
block dimensions for a GEMM (CK Tile). @@ -378,15 +357,6 @@ struct ConvAlgorithmTemplate : Components... return result; } - constexpr auto with_prefetch_config(size_t k_prefetch_stages, PipelineScheduler scheduler) const - { - static_assert(std::is_base_of_v); - auto result = *this; - result.num_gemm_k_prefetch_stages = k_prefetch_stages; - result.loop_scheduler = scheduler; - return result; - } - constexpr auto with_transpose_params(size_t max_src_scalar_per_vector, size_t max_dst_scalar_per_vector) const { @@ -399,9 +369,17 @@ struct ConvAlgorithmTemplate : Components... constexpr auto with_num_conv_groups_to_merge(size_t num_groups_to_merge) const { - static_assert(std::is_base_of_v); + static_assert(std::is_base_of_v); auto result = *this; - result.num_conv_groups_to_merge = num_groups_to_merge; + result.gemm_pipeline.num_conv_groups_to_merge = num_groups_to_merge; + return result; + } + + constexpr auto with_num_gemm_k_prefetch_stages(size_t num_prefetch_stages) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.gemm_pipeline.num_gemm_k_prefetch_stages = num_prefetch_stages; return result; } @@ -422,6 +400,15 @@ struct ConvAlgorithmTemplate : Components... return result; } + constexpr auto with_gemm_pipeline(const PipelineVersion plv, const PipelineScheduler sch) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.gemm_pipeline.pipeline_version = plv; + result.gemm_pipeline.scheduler = sch; + return result; + } + template constexpr auto with_dl_thread_config(const TC& tc) const { @@ -498,29 +485,24 @@ struct ConvAlgorithmTemplate : Components... 
// Fwd algorithm types -using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = - ConvAlgorithmTemplate, - ConvSpecializationFwd_, - Prefetch_, - GemmBatchOptions_>; +using enum ckb::ConvAlgorithmSpecialization; -using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = +// Covers both XDL and WMMA variants for generic fwd convolution +using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle = ConvAlgorithmTemplate, ConvSpecializationFwd_, - GemmPipeline_>; + GemmPipeline_, + AlgorithmSpecialization_<>>; -using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle = +using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = ConvAlgorithmTemplate, ConvSpecializationFwd_, GemmPipeline_, - Prefetch_, - GemmBatchOptions_>; + AlgorithmSpecialization_>; using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = ConvAlgorithmTemplate, ConvSpecializationFwd_, - Prefetch_, - GemmBatchOptions_, - LargeTensorSpecialization_>; + GemmPipeline_, + AlgorithmSpecialization_>; // CK Tile algorithm using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate; // Reference algorithm descriptor - for GPU reference validation -// This is a simple algorithm that requires no complex configuration, -// just a specialization marker to identify it as a reference implementation. 
-struct ConvAlgorithm_Reference -{ - static constexpr auto specialization = ckb::ConvAlgorithmSpecialization::REFERENCE; - // GPU reference uses simple algorithm, no tile configuration needed -}; +using ConvAlgorithm_Reference = ConvAlgorithmTemplate>; // Bwd weight algorithm types using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle = @@ -560,7 +535,8 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle = WarpGemm_, InputOutputTileTransfer_<4>, ConvSpecializationBwdWeight_, - TransposeParams_>; + TransposeParams_, + AlgorithmSpecialization_<>>; using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle = ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, GemmPipeline_, - Prefetch_>; + AlgorithmSpecialization_<>>; // Covers both XDL and WMMA variants -using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_CShuffle = +using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_CShuffle_V3 = ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, GemmPipeline_, TransposeParams_, - GemmBatchOptions_, - TwoStageSpecialization_>; + AlgorithmSpecialization_>; using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 = ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, - GemmPipeline_>; + GemmPipeline_, + AlgorithmSpecialization_>; using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 = ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, GemmPipeline_, - TransposeParams_>; + TransposeParams_, + AlgorithmSpecialization_>; using ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl = ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, - MultipleDSpecialization_>; + AlgorithmSpecialization_>; using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 = ConvAlgorithmTemplate; + AlgorithmSpecialization_>; using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3 = ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, GemmPipeline_, - MultipleDSpecialization_>; + AlgorithmSpecialization_>; } // namespace ck_tile::builder::test 
diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index bad29a65c0c..7b1bfc9e558 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -98,7 +98,7 @@ template <> inline std::string to_string(GemmPipeline t) { std::ostringstream oss; - oss << to_string(t.scheduler) << "," << to_string(t.pipeline_version); + oss << t.num_gemm_k_prefetch_stages << "," << t.num_conv_groups_to_merge << "," << to_string(t.scheduler) << "," << to_string(t.pipeline_version); return oss.str(); } @@ -281,14 +281,6 @@ inline std::string to_string(ConvSpecializationBwd return oss.str(); } -template <> -inline std::string to_string(Prefetch_ t) -{ - std::ostringstream oss; - oss << t.num_gemm_k_prefetch_stages << "," << to_string(t.loop_scheduler); - return oss.str(); -} - template <> inline std::string to_string(GemmPipeline_ t) { @@ -322,28 +314,8 @@ inline std::string to_string>(DlTransfer_<5> t) // Template specializations for algorithm types template <> -inline std::string to_string( - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle t) -{ - std::ostringstream oss; - oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); - return oss.str(); -} - -template <> -inline std::string to_string( - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 t) -{ - std::ostringstream oss; - oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); - return oss.str(); -} - -template <> -inline std::string to_string( - ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle t) +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle t) { std::ostringstream oss; oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) @@ -425,8 +397,8 @@ inline std::string 
to_string -inline std::string to_string( - ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_CShuffle t) +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_CShuffle_V3 t) { std::ostringstream oss; oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) From 6bcdc10593cb3b33c6b7a8b981ab47a63e29ed38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Fri, 9 Jan 2026 10:48:18 -0500 Subject: [PATCH 63/81] Fix fwd factories after refactoring. --- ...v_bwd_weight_two_stage_wmma_v3_factory.hpp | 2 +- .../factory/conv_bwd_weight_wmma_factory.hpp | 2 +- .../factory/conv_fwd_large_tensor_factory.hpp | 2 +- .../builder/factory/conv_fwd_wmma_factory.hpp | 2 +- .../builder/factory/conv_fwd_xdl_factory.hpp | 4 +- .../factory/helpers/ck/conv_tuning_params.hpp | 4 +- .../conv/ck/test_ckb_conv_fwd_1d_bf16.cpp | 2 +- .../conv/ck/test_ckb_conv_fwd_1d_fp16.cpp | 2 +- .../test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp | 72 +++++++++---------- .../conv/ck/test_ckb_conv_fwd_2d_bf16.cpp | 4 +- ...est_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp | 2 +- .../conv/ck/test_ckb_conv_fwd_2d_fp16.cpp | 2 +- .../conv/ck/test_ckb_conv_fwd_2d_fp32.cpp | 2 +- .../test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp | 2 +- .../conv/ck/test_ckb_conv_fwd_3d_bf16.cpp | 2 +- .../conv/ck/test_ckb_conv_fwd_3d_fp16.cpp | 2 +- .../conv/ck/test_ckb_conv_fwd_3d_fp32.cpp | 2 +- .../test/impl/conv_algorithm_types.hpp | 10 ++- .../test/utils/conv_algorithm_type_utils.hpp | 10 +++ 19 files changed, 74 insertions(+), 56 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp index 48a15a16385..5fa04a3779b 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp +++ 
b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp @@ -106,7 +106,7 @@ struct ConvBwdWeightTwoStageWmmaV3Factory C_BLOCK_TRANSFER.scalar_per_vector, BLOCK_GEMM.scheduler, BLOCK_GEMM.pipeline_version, - ALGORITHM.num_conv_groups_to_merge, + ALGORITHM.gemm_pipeline.num_gemm_k_prefetch_stages, typename Types::OutComputeType, typename Types::InComputeType, ALGORITHM.max_transpose_transfer_src_scalar_per_vector, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp index 8e22958ac13..7c77b0174fe 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp @@ -106,7 +106,7 @@ struct ConvBwdWeightWmmaFactory C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, to_sequence_v, C_BLOCK_TRANSFER.scalar_per_vector, - ALGORITHM.num_gemm_k_prefetch_stages, + ALGORITHM.gemm_pipeline.num_gemm_k_prefetch_stages, LOOP_SCHEDULER, GRIDWISE_GEMM_PIPELINE_VERSION>; }; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp index fdb95d602a8..7c6842a08bc 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp @@ -73,7 +73,7 @@ struct ConvFwdLargeTensorFactory typename Ops::CDEElementwiseOp, SPECIALIZATION.conv_spec, SPECIALIZATION.gemm_spec, - ALGORITHM.num_gemm_k_prefetch_stages, + ALGORITHM.gemm_pipeline.num_gemm_k_prefetch_stages, BLOCK.block_size, BLOCK.per_block.m, BLOCK.per_block.n, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp 
b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp index d52f684d8c0..97335d9b5e0 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp @@ -80,7 +80,7 @@ struct ConvFwdWmmaFactory typename Ops::CDEElementwiseOp, SPECIALIZATION.conv_spec, SPECIALIZATION.gemm_spec, - ALGORITHM.num_gemm_k_prefetch_stages, + ALGORITHM.gemm_pipeline.num_gemm_k_prefetch_stages, BLOCK.block_size, BLOCK.per_block.m, BLOCK.per_block.n, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp index eb2fdfad4d3..97805e6c8cd 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp @@ -73,7 +73,7 @@ struct ConvFwdXdlFactory typename Ops::CDEElementwiseOp, SPECIALIZATION.conv_spec, SPECIALIZATION.gemm_spec, - ALGORITHM.num_gemm_k_prefetch_stages, + ALGORITHM.gemm_pipeline.num_gemm_k_prefetch_stages, BLOCK.block_size, BLOCK.per_block.m, BLOCK.per_block.n, @@ -105,7 +105,7 @@ struct ConvFwdXdlFactory typename Types::AComputeType, typename Types::BComputeType, LOOP_SCHEDULER, - ALGORITHM.num_conv_groups_to_merge>; + ALGORITHM.gemm_pipeline.num_conv_groups_to_merge>; }; } // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp index b64b7d336e8..8db9607f343 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp @@ -75,7 +75,7 @@ consteval BlockGemmSpec SetBlockGemm() template consteval ck::LoopScheduler 
SetLoopScheduler() { - constexpr auto loop_scheduler = ALGORITHM.loop_scheduler; + constexpr auto loop_scheduler = ALGORITHM.gemm_pipeline.scheduler; using ck_loop_sched = ck::LoopScheduler; switch(loop_scheduler) { @@ -89,7 +89,7 @@ consteval ck::LoopScheduler SetLoopScheduler() template consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion() { - constexpr auto pipeline_version = ALGORITHM.pipeline_version; + constexpr auto pipeline_version = ALGORITHM.gemm_pipeline.pipeline_version; using ck_pipeline = ck::PipelineVersion; switch(pipeline_version) { diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp index b7ec4cdac09..ef271cd9879 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp @@ -29,7 +29,7 @@ TEST(FwdConvInstances, .output = {.config = {.layout = NGKW}, .operation = {.elementwise_operation = SCALE}}}; constexpr auto FwdConvAlgorithm = - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle_V3{} .with_thread_block(ThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(Transfer_4x64x1) diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp index abeca068e71..4a30766bdd5 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp @@ -26,7 +26,7 @@ TEST(FwdConvInstances, .output = {.config = {.layout = NWGK}}}; constexpr auto FwdConvAlgorithm = - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle{} .with_thread_block(ThreadBlock_64_64x32x32) .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) .with_transfer(Transfer_4x16x1) diff 
--git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp index 7ef33469a0d..be519a69ab9 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp @@ -11,41 +11,41 @@ using namespace ck_tile::builder::test_utils; // 1D I8 (channels-last) with and DEFAULT specialization // (not supported on gfx11 and gfx12) -#if !defined(__gfx11__) && !defined(__gfx12__) -TEST(FwdConvInstances, - Create_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Instance_1D_FP32_ChannelsFirst_scale) -{ - using enum ck_tile::builder::ConvDirection; - using enum ck_tile::builder::DataType; - using enum ck_tile::builder::TensorLayout; - - constexpr ConvSignature FwdConvSignature{.spatial_dim = 1, - .direction = FORWARD, - .data_type = I8, - .accumulation_data_type = I32, - .input = {.config = {.layout = GNWC}}, - .weight = {.config = {.layout = GKXC}}, - .output = {.config = {.layout = GNWK}}}; - - constexpr auto FwdConvAlgorithm = - ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle{} - .with_thread_block(ThreadBlock_128_64x64x64) - .with_gemm_config(GemmParams_Wmma_2x1_per_wave) - .with_transfer(Transfer_4x32x1) - .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_num_conv_groups_to_merge(2) - .with_with_num_gemm_k_prefetch_stages(3) - .with_gemm_pipeline(PipelineVersion::V1, PipelineScheduler::INTRAWAVE); - - using Builder = ConvBuilder; - - const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); - run_test({"DeviceGroupedConvFwdMultipleD_Wmma_CShuffle", - expected_transfer_parameters, - "GNWC,GKXC,EmptyTuple,GNWK", - "PassThrough,PassThrough,PassThrough", - "Default"}); -} -#endif +//#if !defined(__gfx11__) && !defined(__gfx12__) +// TEST(FwdConvInstances, +// Create_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Instance_1D_FP32_ChannelsFirst_scale) +// { +// using enum 
ck_tile::builder::ConvDirection; +// using enum ck_tile::builder::DataType; +// using enum ck_tile::builder::TensorLayout; + +// constexpr ConvSignature FwdConvSignature{.spatial_dim = 1, +// .direction = FORWARD, +// .data_type = I8, +// .accumulation_data_type = I32, +// .input = {.config = {.layout = GNWC}}, +// .weight = {.config = {.layout = GKXC}}, +// .output = {.config = {.layout = GNWK}}}; + +// constexpr auto FwdConvAlgorithm = +// ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle{} +// .with_thread_block(ThreadBlock_128_64x64x64) +// .with_gemm_config(GemmParams_Wmma_2x1_per_wave) +// .with_transfer(Transfer_4x32x1) +// .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) +// .with_num_conv_groups_to_merge(1) +// .with_num_gemm_k_prefetch_stages(1) +// .with_gemm_pipeline(PipelineScheduler::DEFAULT); + +// using Builder = ConvBuilder; + +// const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); +// run_test({"DeviceGroupedConvFwdMultipleD_Wmma_CShuffle", +// expected_transfer_parameters, +// "GNWC,GKXC,EmptyTuple,GNWK", +// "PassThrough,PassThrough,PassThrough", +// "Default"}); +// } +// #endif } // namespace diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp index e63aa41e059..ebc30272c9d 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp @@ -26,7 +26,7 @@ TEST(FwdConvInstances, .output = {.config = {.layout = NHWGK}}}; constexpr auto FwdConvAlgorithm = - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle_V3{} .with_thread_block(ThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(Transfer_4x64x1) @@ -63,7 +63,7 @@ TEST(FwdConvInstances, .output = {.config = {.layout = NHWGK}}}; constexpr auto 
FwdConvAlgorithm = - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle_V3{} .with_thread_block(ThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(Transfer_4x64x1) diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp index 395640dcf1f..be2fdd689af 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp @@ -31,7 +31,7 @@ TEST(FwdConvInstances, .with_auxiliary_operand_configs()}}; constexpr auto FwdConvAlgorithm = - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle{} .with_thread_block(ThreadBlock_64_64x32x32) .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) .with_transfer(Transfer_4x16x1) diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp index 628394e3ca2..1ea09918aac 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp @@ -24,7 +24,7 @@ constexpr auto SIGNATURE = .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; -constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle_V3{} .with_thread_block(cku::ThreadBlock_256_256x256x32) .with_gemm_config(cku::FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(cku::Transfer_4x64x1) diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp index 
67462426f61..c26ff9a9cc3 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp @@ -25,7 +25,7 @@ TEST(FwdConvInstances, .output = {.config = {.layout = NGKHW}}}; constexpr auto FwdConvAlgorithm = - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle_V3{} .with_thread_block(ThreadBlock_256_128x128x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(Transfer_4x64x1) diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp index 85a87fc3a28..a83c9c76551 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp @@ -26,7 +26,7 @@ TEST(FwdConvInstances, .output = {.config = {.layout = NHWGK}}}; constexpr auto FwdConvAlgorithm = - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle{} .with_thread_block(ThreadBlock_256_256x128x32) .with_gemm_config(FwdGemmParams_Xdl_4x2_per_wave) .with_transfer(Transfer_4x64x1_fp8) diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp index 016e972f3cd..34f3e286468 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp @@ -26,7 +26,7 @@ TEST(FwdConvInstances, .output = {.config = {.layout = GNDHWK}}}; constexpr auto FwdConvAlgorithm = - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle_V3{} .with_thread_block(ThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(Transfer_4x64x1) diff --git 
a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp index 00da06d41aa..07c399a9795 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp @@ -26,7 +26,7 @@ TEST(FwdConvInstances, .output = {.config = {.layout = NDHWGK}}}; constexpr auto FwdConvAlgorithm = - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle_V3{} .with_thread_block(ThreadBlock_256_128x128x32) .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) .with_transfer(Transfer_4x64x1) diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp index 825b9a03330..d33ac55db3c 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp @@ -26,7 +26,7 @@ TEST(FwdConvInstances, .output = {.config = {.layout = NGKDHW}}}; constexpr auto FwdConvAlgorithm = - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle_V3{} .with_thread_block(ThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(Transfer_4x64x1) diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 8c1c3f898ce..6cac5aa49b0 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -400,6 +400,14 @@ struct ConvAlgorithmTemplate : Components... 
return result; } + constexpr auto with_gemm_pipeline(const PipelineScheduler sch) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.gemm_pipeline.scheduler = sch; + return result; + } + constexpr auto with_gemm_pipeline(const PipelineVersion plv, const PipelineScheduler sch) const { static_assert(std::is_base_of_v); @@ -496,7 +504,7 @@ using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle = GemmPipeline_, AlgorithmSpecialization_<>>; -using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = +using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle_V3 = ConvAlgorithmTemplate, diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index 7b1bfc9e558..09cb1551391 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -323,6 +323,16 @@ inline std::string to_string +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle_V3 t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + return oss.str(); +} + template <> inline std::string to_string( ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK t) From 7e027902932197ec193b7efe4149f94466dbfaa4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Mon, 12 Jan 2026 08:47:03 -0500 Subject: [PATCH 64/81] Remove the C++26 extensions. 
--- experimental/builder/test/CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 106608c4196..0a576468c86 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -54,7 +54,6 @@ function(add_ck_builder_test test_name) target_compile_options(${test_name} PRIVATE -Wno-global-constructors -Wno-c++20-compat - -Wno-c++26-extensions # Allow C++26 extensions for better compile-time diagnostics ) target_link_libraries(${test_name} PRIVATE GTest::gtest_main GTest::gmock) endfunction() From 46afc665436c3b8cf008bcb5e61387b988217069 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Mon, 12 Jan 2026 09:26:04 -0500 Subject: [PATCH 65/81] Fix fwd/bwd conv factory tests after tile transfer XDL/WMMA concepts refactoring. --- .../builder/conv_algorithm_concepts.hpp | 14 +- .../builder/factory/conv_algorithms.hpp | 157 +++++++++--------- .../builder/factory/conv_dispatcher.hpp | 34 ++-- .../test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp | 71 ++++---- .../test/utils/ckb_conv_test_configs.hpp | 51 +++++- .../test/utils/conv_algorithm_type_utils.hpp | 43 +++-- 6 files changed, 216 insertions(+), 154 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 946b485bdba..d7f618f0792 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -340,16 +340,14 @@ concept SpecifiesGenericInstance = !requires { requires !!(T::specialization == ConvAlgorithmSpecialization::NONE); }; -template -concept SpecifiesXdl = requires (T t){ - { t.warp_gemm.matrix_instruction } -> std::convertible_to; - { t.warp_gemm.matrix_instruction == MatrixInstructionType::XDL}; +template +concept SpecifiesXdl = requires { + 
requires Algorithm.warp_gemm.matrix_instruction == MatrixInstructionType::XDL; }; -template -concept SpecifiesWmma = requires (T t){ - { t.warp_gemm.matrix_instruction } -> std::convertible_to; - { t.warp_gemm.matrix_instruction == MatrixInstructionType::WMMA}; +template +concept SpecifiesWmma = requires { + requires Algorithm.warp_gemm.matrix_instruction == MatrixInstructionType::WMMA; }; /******************************************** */ diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index 81bafa158dc..b228126106e 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -8,119 +8,128 @@ namespace ck_tile::builder::factory { // Base algorithm concepts -template +template concept FwdXdlAlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesThreadDistribution3D && - SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesWarpGemm && - SpecifiesFwdConvSpecialization && SpecifiesGemmPipeline && - SpecifiesXdl; + ConvAlgorithmDescriptor && SpecifiesThreadBlock && + SpecifiesThreadDistribution3D && SpecifiesLdsTransfer && + SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && + SpecifiesWarpGemm && SpecifiesFwdConvSpecialization && + SpecifiesGemmPipeline && SpecifiesXdl; -template +template concept BwdXdlAlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesThreadDistribution4D && - SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesWarpGemm && - SpecifiesBwdWeightConvSpecialization && SpecifiesXdl; + ConvAlgorithmDescriptor && SpecifiesThreadBlock && + SpecifiesThreadDistribution4D && SpecifiesLdsTransfer && + SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && + 
SpecifiesWarpGemm && SpecifiesBwdWeightConvSpecialization && + SpecifiesXdl; -template +template concept BwdXdlV3AlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesThreadDistribution3D && - SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesWarpGemm && - SpecifiesBwdWeightConvSpecialization && SpecifiesGemmPipeline && SpecifiesXdl && - SpecifiesPipelineV3; + ConvAlgorithmDescriptor && SpecifiesThreadBlock && + SpecifiesThreadDistribution3D && SpecifiesLdsTransfer && + SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && + SpecifiesWarpGemm && SpecifiesBwdWeightConvSpecialization && + SpecifiesGemmPipeline && SpecifiesXdl && SpecifiesPipelineV3; -template +template concept BwdWmmaAlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesThreadDistribution3D && - SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesWarpGemm && - SpecifiesBwdWeightConvSpecialization && SpecifiesWmma; + ConvAlgorithmDescriptor && SpecifiesThreadBlock && + SpecifiesThreadDistribution3D && SpecifiesLdsTransfer && + SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && + SpecifiesWarpGemm && SpecifiesBwdWeightConvSpecialization && + SpecifiesWmma; -template +template concept BwdWmmaV3AlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesThreadDistribution3D && - SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesWarpGemm && - SpecifiesBwdWeightConvSpecialization && SpecifiesGemmPipeline && SpecifiesWmma && - SpecifiesPipelineV3; + ConvAlgorithmDescriptor && SpecifiesThreadBlock && + SpecifiesThreadDistribution3D && SpecifiesLdsTransfer && + SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && + SpecifiesWarpGemm && SpecifiesBwdWeightConvSpecialization && + SpecifiesGemmPipeline && SpecifiesWmma && SpecifiesPipelineV3; // Reference 
algorithm concept -template -concept ReferenceAlgorithm = ConvAlgorithmDescriptor && SpecifiesReferenceAlgorithm; +template +concept ReferenceAlgorithm = ConvAlgorithmDescriptor && SpecifiesReferenceAlgorithm; // Tile-based algorithm concept -template -concept TileAlgorithm = ConvAlgorithmDescriptor && SpecifiesTileThreadBlock && - SpecifiesTileTransfer && SpecifiesTileConvSpecialization && - SpecifiesTileBlockGemm && SpecifiesTileOptimizations; +template +concept TileAlgorithm = ConvAlgorithmDescriptor && SpecifiesTileThreadBlock && + SpecifiesTileTransfer && SpecifiesTileConvSpecialization && + SpecifiesTileBlockGemm && SpecifiesTileOptimizations; // FWD XDL algorithm concepts -template -concept FwdXdlAlgorithm = FwdXdlAlgorithmBase && SpecifiesGenericInstance; +template +concept FwdXdlAlgorithm = FwdXdlAlgorithmBase && SpecifiesGenericInstance; -template -concept LargeTensorAlgorithm = FwdXdlAlgorithmBase && SpecifiesLargeTensorSupport; +template +concept LargeTensorAlgorithm = FwdXdlAlgorithmBase && SpecifiesLargeTensorSupport; -template +template concept FwdXdlV3Algorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesThreadDistribution3D && - SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesWarpGemm && - SpecifiesFwdConvSpecialization && SpecifiesGemmPipeline && SpecifiesXdl && - SpecifiesPipelineV3; + ConvAlgorithmDescriptor && SpecifiesThreadBlock && + SpecifiesThreadDistribution3D && SpecifiesLdsTransfer && + SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && + SpecifiesWarpGemm && SpecifiesFwdConvSpecialization && + SpecifiesGemmPipeline && SpecifiesXdl && SpecifiesPipelineV3; // FWD WMMA algorithm concepts -template +template concept FwdWmmaAlgorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesThreadDistribution3D && - SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesWarpGemm && - 
SpecifiesFwdConvSpecialization && SpecifiesGemmPipeline && SpecifiesWmma; + ConvAlgorithmDescriptor && SpecifiesThreadBlock && + SpecifiesThreadDistribution3D && SpecifiesLdsTransfer && + SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && + SpecifiesWarpGemm && SpecifiesFwdConvSpecialization && + SpecifiesGemmPipeline && SpecifiesWmma; // FWD DL algorithms -template +template concept FwdDlAlgorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesFwdConvSpecialization && - SpecifiesGemmSpecialization && SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && - SpecifiesDlFwdBlockTransfer && SpecifiesDlEpilogue; + ConvAlgorithmDescriptor && SpecifiesThreadBlock && + SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && + SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && + SpecifiesDlFwdBlockTransfer && SpecifiesDlEpilogue; // BWD weight XDL algorithm concepts -template +template concept BwdXdlAlgorithm = - BwdXdlAlgorithmBase && SpecifiesTransposeTransfer && SpecifiesGenericInstance; + BwdXdlAlgorithmBase && SpecifiesTransposeTransfer && + SpecifiesGenericInstance; -template -concept BwdMultiDXdlAlgorithm = BwdXdlAlgorithmBase && SpecifiesMultipleDSupport; +template +concept BwdMultiDXdlAlgorithm = BwdXdlAlgorithmBase && SpecifiesMultipleDSupport; -template -concept BwdXdlV3Algorithm = BwdXdlV3AlgorithmBase; +template +concept BwdXdlV3Algorithm = BwdXdlV3AlgorithmBase; -template -concept BwdTwoStageXdlAlgorithm = BwdXdlV3AlgorithmBase && SpecifiesTransposeTransfer && SpecifiesTwoStageSupport; +template +concept BwdTwoStageXdlAlgorithm = BwdXdlV3AlgorithmBase && SpecifiesTransposeTransfer && + SpecifiesTwoStageSupport; // BWD weight WMMA algorithm concepts -template +template concept BwdWmmaAlgorithm = - BwdWmmaAlgorithmBase && SpecifiesNumPrefetchStages && SpecifiesGemmPipeline && SpecifiesGenericInstance; + BwdWmmaAlgorithmBase && SpecifiesNumPrefetchStages && + SpecifiesGemmPipeline && SpecifiesGenericInstance; 
-template -concept BwdMultiDWmmaV3Algorithm = BwdWmmaV3AlgorithmBase && SpecifiesMultipleDSupport; +template +concept BwdMultiDWmmaV3Algorithm = BwdWmmaV3AlgorithmBase && SpecifiesMultipleDSupport; -template +template concept BwdWmmaV3Algorithm = - BwdWmmaV3AlgorithmBase && SpecifiesTransposeTransfer && SpecifiesGenericInstance; + BwdWmmaV3AlgorithmBase && SpecifiesTransposeTransfer && + SpecifiesGenericInstance; -template -concept BwdTwoStageWmmaV3Algorithm = BwdWmmaV3AlgorithmBase && SpecifiesTransposeTransfer && - SpecifiesTwoStageSupport; +template +concept BwdTwoStageWmmaV3Algorithm = BwdWmmaV3AlgorithmBase && SpecifiesTransposeTransfer && + SpecifiesTwoStageSupport; // BWD weight DL algorithms -template +template concept BwdDlAlgorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && - SpecifiesBwdWeightConvSpecialization && SpecifiesDlThreadConfig && - SpecifiesDlThreadCluster && SpecifiesDlBwdBlockTransfer && SpecifiesDlEpilogue; + ConvAlgorithmDescriptor && SpecifiesThreadBlock && + SpecifiesBwdWeightConvSpecialization && SpecifiesDlThreadConfig && + SpecifiesDlThreadCluster && SpecifiesDlBwdBlockTransfer && + SpecifiesDlEpilogue; } // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index e235db4bb09..61f20b0fd34 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -104,38 +104,36 @@ template constexpr auto make_conv_instance() { - using AlgoType = std::remove_const_t; - // Reference algorithm supports all directions - if constexpr(ReferenceAlgorithm) + if constexpr(ReferenceAlgorithm) { return typename ReferenceFactory::Instance{}; } // CK Tile supports common factory for each direction - else if constexpr(TileAlgorithm) + else if constexpr(TileAlgorithm) { return typename 
ConvTileFactory::Instance{}; } // Forward direction (supports most algorithm variants) else if constexpr(ConvDirectionIsForward) { - if constexpr(FwdXdlV3Algorithm) + if constexpr(FwdXdlV3Algorithm) { return typename ConvFwdXdlV3Factory::Instance{}; } - else if constexpr(FwdXdlAlgorithm) + else if constexpr(FwdXdlAlgorithm) { return typename ConvFwdXdlFactory::Instance{}; } - else if constexpr(FwdWmmaAlgorithm) + else if constexpr(FwdWmmaAlgorithm) { return typename ConvFwdWmmaFactory::Instance{}; } - else if constexpr(FwdDlAlgorithm) + else if constexpr(FwdDlAlgorithm) { return typename ConvFwdDlFactory::Instance{}; } - else if constexpr(LargeTensorAlgorithm) + else if constexpr(LargeTensorAlgorithm) { return typename ConvFwdLargeTensorFactory::Instance{}; } @@ -159,42 +157,42 @@ constexpr auto make_conv_instance() // Backward weight direction (will expand with more algorithms in the future) else if constexpr(ConvDirectionIsBackwardWeight) { - if constexpr(BwdXdlAlgorithm) + if constexpr(BwdXdlAlgorithm) { return typename ConvBwdWeightXdlFactory::Instance{}; } - else if constexpr(BwdXdlV3Algorithm) + else if constexpr(BwdXdlV3Algorithm) { return typename ConvBwdWeightXdlV3Factory::Instance{}; } - else if constexpr(BwdTwoStageXdlAlgorithm) + else if constexpr(BwdTwoStageXdlAlgorithm) { return typename ConvBwdWeightTwoStageXdlFactory::Instance{}; } - else if constexpr(BwdDlAlgorithm) + else if constexpr(BwdDlAlgorithm) { return typename ConvBwdWeightDlFactory::Instance{}; } - else if constexpr(BwdMultiDXdlAlgorithm) + else if constexpr(BwdMultiDXdlAlgorithm) { return typename ConvBwdWeightMultiDXdlFactory::Instance{}; } - else if constexpr(BwdWmmaV3Algorithm) + else if constexpr(BwdWmmaV3Algorithm) { return typename ConvBwdWeightWmmaV3Factory::Instance{}; } - else if constexpr(BwdTwoStageWmmaV3Algorithm) + else if constexpr(BwdTwoStageWmmaV3Algorithm) { return typename ConvBwdWeightTwoStageWmmaV3Factory:: Instance{}; } - else if constexpr(BwdWmmaAlgorithm) + else if 
constexpr(BwdWmmaAlgorithm) { return typename ConvBwdWeightWmmaFactory::Instance{}; } - else if constexpr(BwdMultiDWmmaV3Algorithm) + else if constexpr(BwdMultiDWmmaV3Algorithm) { return typename ConvBwdWeightMultiDWmmaV3Factory:: Instance{}; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp index be519a69ab9..ad36a78c1fd 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp @@ -11,41 +11,40 @@ using namespace ck_tile::builder::test_utils; // 1D I8 (channels-last) with and DEFAULT specialization // (not supported on gfx11 and gfx12) -//#if !defined(__gfx11__) && !defined(__gfx12__) -// TEST(FwdConvInstances, -// Create_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Instance_1D_FP32_ChannelsFirst_scale) -// { -// using enum ck_tile::builder::ConvDirection; -// using enum ck_tile::builder::DataType; -// using enum ck_tile::builder::TensorLayout; - -// constexpr ConvSignature FwdConvSignature{.spatial_dim = 1, -// .direction = FORWARD, -// .data_type = I8, -// .accumulation_data_type = I32, -// .input = {.config = {.layout = GNWC}}, -// .weight = {.config = {.layout = GKXC}}, -// .output = {.config = {.layout = GNWK}}}; - -// constexpr auto FwdConvAlgorithm = -// ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle{} -// .with_thread_block(ThreadBlock_128_64x64x64) -// .with_gemm_config(GemmParams_Wmma_2x1_per_wave) -// .with_transfer(Transfer_4x32x1) -// .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) -// .with_num_conv_groups_to_merge(1) -// .with_num_gemm_k_prefetch_stages(1) -// .with_gemm_pipeline(PipelineScheduler::DEFAULT); - -// using Builder = ConvBuilder; - -// const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); -// run_test({"DeviceGroupedConvFwdMultipleD_Wmma_CShuffle", -// expected_transfer_parameters, -// 
"GNWC,GKXC,EmptyTuple,GNWK", -// "PassThrough,PassThrough,PassThrough", -// "Default"}); -// } -// #endif +#if !defined(__gfx11__) && !defined(__gfx12__) +TEST(FwdConvInstances, + Create_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Instance_1D_FP32_ChannelsFirst_scale) +{ + using enum ck_tile::builder::ConvDirection; + using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + + constexpr ConvSignature FwdConvSignature{.spatial_dim = 1, + .direction = FORWARD, + .data_type = I8, + .accumulation_data_type = I32, + .input = {.config = {.layout = GNWC}}, + .weight = {.config = {.layout = GKXC}}, + .output = {.config = {.layout = GNWK}}}; + + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle{} + .with_thread_block(ThreadBlock_128_64x64x64) + .with_gemm_config(GemmParams_Wmma_16x16_2x2_per_wave) + .with_transfer(Transfer_4x32x1_vector_load_16_generic) + .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_num_gemm_k_prefetch_stages(1) + .with_gemm_pipeline(PipelineVersion::V1, PipelineScheduler::DEFAULT); + + using Builder = ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); + run_test({"DeviceGroupedConvFwdMultipleD_Wmma_CShuffle", + expected_transfer_parameters, + "GNWC,GKXC,EmptyTuple,GNWK", + "PassThrough,PassThrough,PassThrough", + "Default"}); +} +#endif } // namespace diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index fd09c810d04..34682746cee 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -196,7 +196,7 @@ constexpr InputOutputTileTransfer<> Transfer_4x16x1{ .a = { .thread_distribution = {.k0 = 4, .m_n = 16, .k1 = 1}, - .lds_transfer_params = {.global_memory_vector_load_size = 8, + .lds_transfer_params = {.global_memory_vector_load_size = 8, 
.src_vector_dim = 2, .src_scalar_per_vector = 8, .lds_dst_scalar_per_vector = 8, @@ -208,7 +208,7 @@ constexpr InputOutputTileTransfer<> Transfer_4x16x1{ .b = { .thread_distribution = {.k0 = 4, .m_n = 16, .k1 = 1}, - .lds_transfer_params = {.global_memory_vector_load_size = 8, + .lds_transfer_params = {.global_memory_vector_load_size = 8, .src_vector_dim = 2, .src_scalar_per_vector = 8, .lds_dst_scalar_per_vector = 8, @@ -232,7 +232,7 @@ constexpr InputOutputTileTransfer<> Transfer_4x32x1{ .a = { .thread_distribution = {.k0 = 4, .m_n = 32, .k1 = 1}, - .lds_transfer_params = {.global_memory_vector_load_size = 8, + .lds_transfer_params = {.global_memory_vector_load_size = 8, .src_vector_dim = 2, .src_scalar_per_vector = 16, .lds_dst_scalar_per_vector = 16, @@ -244,7 +244,7 @@ constexpr InputOutputTileTransfer<> Transfer_4x32x1{ .b = { .thread_distribution = {.k0 = 4, .m_n = 32, .k1 = 1}, - .lds_transfer_params = {.global_memory_vector_load_size = 8, + .lds_transfer_params = {.global_memory_vector_load_size = 8, .src_vector_dim = 2, .src_scalar_per_vector = 16, .lds_dst_scalar_per_vector = 16, @@ -263,6 +263,41 @@ constexpr InputOutputTileTransfer<> Transfer_4x32x1{ }, }; +constexpr InputOutputTileTransfer<> Transfer_4x32x1_vector_load_16_generic{ + .a = + { + .thread_distribution = {.k0 = 4, .m_n = 32, .k1 = 1}, + .lds_transfer_params = {.global_memory_vector_load_size = 16, + .src_vector_dim = 2, + .src_scalar_per_vector = 1, + .lds_dst_scalar_per_vector = 16, + .is_direct_load = false, + .lds_padding = true}, + .thread_distribution_access_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, + }, + .b = + { + .thread_distribution = {.k0 = 4, .m_n = 32, .k1 = 1}, + .lds_transfer_params = {.global_memory_vector_load_size = 16, + .src_vector_dim = 2, + .src_scalar_per_vector = 1, + .lds_dst_scalar_per_vector = 16, + .is_direct_load = false, + .lds_padding = true}, + .thread_distribution_access_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, + }, + .c = + { + 
.thread_distribution = + {.gemm_m_block_size = 1, .gemm_m_per_block = 32, .gemm_n_block_size = 1, .gemm_n_per_block = 4}, + .epilogue = {.m_xdl_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 1}, + }, +}; + constexpr WarpGemmParams BwdGemmParams_Xdl_4x4_per_wave{ .matrix_instruction = MatrixInstructionType::XDL, .gemm_m_per_instruction = 32, .gemm_n_per_instruction = 32, .gemm_m_iters_per_wave = 4, .gemm_n_iters_per_wave = 4}; @@ -287,14 +322,14 @@ constexpr WarpGemmParams FwdGemmParams_Xdl_2x1_per_wave{ .matrix_instruction = MatrixInstructionType::XDL, .gemm_m_per_instruction = 32, .gemm_n_per_instruction = 32, .gemm_m_iters_per_wave = 2, .gemm_n_iters_per_wave = 1}; -constexpr WarpGemmParams GemmParams_Wmma_2x1_per_wave{ - .matrix_instruction = MatrixInstructionType::WMMA, - .gemm_m_per_instruction = 32, .gemm_n_per_instruction = 32, .gemm_m_iters_per_wave = 2, .gemm_n_iters_per_wave = 1}; - constexpr WarpGemmParams GemmParams_Wmma_16x16_2x1_per_wave{ .matrix_instruction = MatrixInstructionType::WMMA, .gemm_m_per_instruction = 16, .gemm_n_per_instruction = 16, .gemm_m_iters_per_wave = 2, .gemm_n_iters_per_wave = 1}; +constexpr WarpGemmParams GemmParams_Wmma_16x16_2x2_per_wave{ + .matrix_instruction = MatrixInstructionType::WMMA, + .gemm_m_per_instruction = 16, .gemm_n_per_instruction = 16, .gemm_m_iters_per_wave = 2, .gemm_n_iters_per_wave = 2}; + constexpr ThreadBlock ThreadBlock_256_256x256x32{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index 09cb1551391..9498997c814 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -318,7 +318,10 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) + oss << to_string(static_cast(t)) << "," + << 
t.transfer.a.lds_transfer_params.global_memory_vector_load_size << "," + << t.transfer.b.lds_transfer_params.global_memory_vector_load_size << "," + << to_string(static_cast(t)) << "," << to_string(static_cast>(t)); return oss.str(); } @@ -328,7 +331,10 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) + oss << to_string(static_cast(t)) << "," + << t.transfer.a.lds_transfer_params.global_memory_vector_load_size << "," + << t.transfer.b.lds_transfer_params.global_memory_vector_load_size << "," + << to_string(static_cast(t)) << "," << to_string(static_cast>(t)); return oss.str(); } @@ -350,7 +356,10 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) + oss << to_string(static_cast(t)) << "," + << t.transfer.a.lds_transfer_params.global_memory_vector_load_size << "," + << t.transfer.b.lds_transfer_params.global_memory_vector_load_size << "," + << to_string(static_cast(t)) << "," << to_string(static_cast>(t)); return oss.str(); } @@ -360,7 +369,9 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) + oss << to_string(static_cast(t)) << "," + << t.transfer.a.lds_transfer_params.global_memory_vector_load_size << "," + << to_string(static_cast(t)) << "," << to_string(static_cast>(t)); return oss.str(); } @@ -370,7 +381,9 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) + oss << to_string(static_cast(t)) << "," + << t.transfer.a.lds_transfer_params.global_memory_vector_load_size << "," + << to_string(static_cast(t)) << "," << to_string(static_cast>(t)); return oss.str(); } @@ -380,7 +393,9 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) + oss << to_string(static_cast(t)) << "," + << t.transfer.a.lds_transfer_params.global_memory_vector_load_size << "," + << to_string(static_cast(t)) << "," << to_string(static_cast>(t)); return oss.str(); } @@ -390,7 +405,9 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) + oss << to_string(static_cast(t)) << 
"," + << t.transfer.a.lds_transfer_params.global_memory_vector_load_size << "," + << to_string(static_cast(t)) << "," << to_string(static_cast>(t)); return oss.str(); } @@ -400,7 +417,9 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) + oss << to_string(static_cast(t)) << "," + << t.transfer.a.lds_transfer_params.global_memory_vector_load_size << "," + << to_string(static_cast(t)) << "," << to_string(static_cast>(t)); return oss.str(); } @@ -411,7 +430,9 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) + oss << to_string(static_cast(t)) + << t.transfer.a.lds_transfer_params.global_memory_vector_load_size << "," + << "," << to_string(static_cast(t)) << "," << to_string(static_cast>(t)); return oss.str(); } @@ -434,7 +455,9 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) + oss << to_string(static_cast(t)) << "," + << t.transfer.a.lds_transfer_params.global_memory_vector_load_size << "," + << to_string(static_cast(t)) << "," << to_string(static_cast>(t)); return oss.str(); } From 4f721aca8ab81eb8e4fde333b99b6c49a077f9c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Mon, 12 Jan 2026 11:33:46 -0500 Subject: [PATCH 66/81] Fix remaining fwd/bwd instances tests. 
--- .../builder/factory/conv_dispatcher.hpp | 35 ++++++++++--------- .../test/utils/conv_algorithm_type_utils.hpp | 22 ++++++++---- 2 files changed, 34 insertions(+), 23 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 61f20b0fd34..97bfa05e1b0 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -157,46 +157,47 @@ constexpr auto make_conv_instance() // Backward weight direction (will expand with more algorithms in the future) else if constexpr(ConvDirectionIsBackwardWeight) { - if constexpr(BwdXdlAlgorithm) + // Start from more specialized and end with least specialized. + if constexpr(BwdTwoStageXdlAlgorithm) { - return typename ConvBwdWeightXdlFactory::Instance{}; + return + typename ConvBwdWeightTwoStageXdlFactory::Instance{}; } - else if constexpr(BwdXdlV3Algorithm) + else if constexpr(BwdTwoStageWmmaV3Algorithm) { - return typename ConvBwdWeightXdlV3Factory::Instance{}; + return typename ConvBwdWeightTwoStageWmmaV3Factory:: + Instance{}; } - else if constexpr(BwdTwoStageXdlAlgorithm) + else if constexpr(BwdMultiDXdlAlgorithm) { return - typename ConvBwdWeightTwoStageXdlFactory::Instance{}; + typename ConvBwdWeightMultiDXdlFactory::Instance{}; + } + else if constexpr(BwdMultiDWmmaV3Algorithm) + { + return typename ConvBwdWeightMultiDWmmaV3Factory:: + Instance{}; } else if constexpr(BwdDlAlgorithm) { return typename ConvBwdWeightDlFactory::Instance{}; } - else if constexpr(BwdMultiDXdlAlgorithm) + else if constexpr(BwdXdlV3Algorithm) { - return - typename ConvBwdWeightMultiDXdlFactory::Instance{}; + return typename ConvBwdWeightXdlV3Factory::Instance{}; } else if constexpr(BwdWmmaV3Algorithm) { return typename ConvBwdWeightWmmaV3Factory::Instance{}; } - else if constexpr(BwdTwoStageWmmaV3Algorithm) + else if 
constexpr(BwdXdlAlgorithm) { - return typename ConvBwdWeightTwoStageWmmaV3Factory:: - Instance{}; + return typename ConvBwdWeightXdlFactory::Instance{}; } else if constexpr(BwdWmmaAlgorithm) { return typename ConvBwdWeightWmmaFactory::Instance{}; } - else if constexpr(BwdMultiDWmmaV3Algorithm) - { - return typename ConvBwdWeightMultiDWmmaV3Factory:: - Instance{}; - } else { static_assert( diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index 9498997c814..fd2ea47c155 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -318,11 +318,21 @@ inline std::string to_string(t)) << "," - << t.transfer.a.lds_transfer_params.global_memory_vector_load_size << "," - << t.transfer.b.lds_transfer_params.global_memory_vector_load_size << "," - << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + if (t.warp_gemm.matrix_instruction == MatrixInstructionType::WMMA) + { + oss << to_string(static_cast(t)) << "," + << t.transfer.a.lds_transfer_params.global_memory_vector_load_size << "," + << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + } + else + { + oss << to_string(static_cast(t)) << "," + << t.transfer.a.lds_transfer_params.global_memory_vector_load_size << "," + << t.transfer.b.lds_transfer_params.global_memory_vector_load_size << "," + << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + } return oss.str(); } @@ -431,7 +441,7 @@ inline std::string to_string(t)) - << t.transfer.a.lds_transfer_params.global_memory_vector_load_size << "," + << "," << t.transfer.a.lds_transfer_params.global_memory_vector_load_size << "," << to_string(static_cast(t)) << "," << to_string(static_cast>(t)); return oss.str(); From 3e8f3907f295dea0bacb2cecbc8d23d2648022a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Tue, 13 Jan 2026 02:43:37 
-0500 Subject: [PATCH 67/81] Unify conv elementwise ops and layout definitions for fwd and bwd directions. --- .../factory/conv_bwd_weight_dl_factory.hpp | 2 +- ...onv_bwd_weight_multi_d_wmma_v3_factory.hpp | 2 +- .../conv_bwd_weight_multi_d_xdl_factory.hpp | 2 +- ...v_bwd_weight_two_stage_wmma_v3_factory.hpp | 2 +- .../conv_bwd_weight_two_stage_xdl_factory.hpp | 2 +- .../factory/conv_bwd_weight_wmma_factory.hpp | 2 +- .../conv_bwd_weight_wmma_v3_factory.hpp | 2 +- .../factory/conv_bwd_weight_xdl_factory.hpp | 2 +- .../conv_bwd_weight_xdl_v3_factory.hpp | 2 +- .../builder/factory/conv_fwd_dl_factory.hpp | 14 +-- .../factory/conv_fwd_large_tensor_factory.hpp | 14 +-- .../builder/factory/conv_fwd_v3_factory.hpp | 14 +-- .../builder/factory/conv_fwd_wmma_factory.hpp | 14 +-- .../builder/factory/conv_fwd_xdl_factory.hpp | 14 +-- .../helpers/ck/conv_elementwise_op.hpp | 28 ++---- .../factory/helpers/ck/conv_tensor_layout.hpp | 31 ++---- .../ck_tile/builder/testing/conv_fwd.hpp | 14 +-- .../ck_tile/builder/testing/conv_fwd_ck.hpp | 8 +- .../builder/test/unit_conv_tensor_layout.cpp | 96 +++++++++---------- 19 files changed, 121 insertions(+), 144 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp index 80143427c78..60551b86331 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp @@ -25,7 +25,7 @@ struct ConvBwdWeightDlFactory static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = internal::ConvTensorLayouts; using Types = internal::BwdWeightConvTensorDataTypes; - using Ops = internal::ElementwiseOps; + using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); static constexpr auto BWD_CONV_SPECIALIZATION = diff --git 
a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp index 6f5c679b595..9485e050ef3 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp @@ -28,7 +28,7 @@ struct ConvBwdWeightMultiDWmmaV3Factory static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = internal::ConvTensorLayouts; using Types = internal::BwdWeightConvTensorDataTypes; - using Ops = internal::ElementwiseOps; + using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); static constexpr auto BWD_CONV_SPECIALIZATION = diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp index 9f76568ca89..a8c92a39fbc 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp @@ -28,7 +28,7 @@ struct ConvBwdWeightMultiDXdlFactory static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = internal::ConvTensorLayouts; using Types = internal::BwdWeightConvTensorDataTypes; - using Ops = internal::ElementwiseOps; + using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); static constexpr auto BWD_CONV_SPECIALIZATION = diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp index 86c48fe3220..4f51ac52150 100644 --- 
a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp @@ -28,7 +28,7 @@ struct ConvBwdWeightTwoStageWmmaV3Factory static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = internal::ConvTensorLayouts; using Types = internal::BwdWeightConvTensorDataTypes; - using Ops = internal::ElementwiseOps; + using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); static constexpr auto BWD_CONV_SPECIALIZATION = diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp index 9c37beae46c..3fb04ffa2e4 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp @@ -28,7 +28,7 @@ struct ConvBwdWeightTwoStageXdlFactory static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = internal::ConvTensorLayouts; using Types = internal::BwdWeightConvTensorDataTypes; - using Ops = internal::ElementwiseOps; + using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); static constexpr auto BWD_CONV_SPECIALIZATION = diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp index 32161a234ae..eecc2ad1023 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp @@ -28,7 +28,7 @@ struct ConvBwdWeightWmmaFactory static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = 
internal::ConvTensorLayouts; using Types = internal::BwdWeightConvTensorDataTypes; - using Ops = internal::ElementwiseOps; + using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); static constexpr auto BWD_CONV_SPECIALIZATION = diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp index baf84402c34..b7845d7a006 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp @@ -28,7 +28,7 @@ struct ConvBwdWeightWmmaV3Factory static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = internal::ConvTensorLayouts; using Types = internal::BwdWeightConvTensorDataTypes; - using Ops = internal::ElementwiseOps; + using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); static constexpr auto BWD_CONV_SPECIALIZATION = diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp index 91c19d2bd0d..db666ffb925 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp @@ -28,7 +28,7 @@ struct ConvBwdWeightXdlFactory static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = internal::ConvTensorLayouts; using Types = internal::BwdWeightConvTensorDataTypes; - using Ops = internal::ElementwiseOps; + using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); static constexpr auto BWD_CONV_SPECIALIZATION = diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp 
b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp index f3edd0e6d93..9054aa5b88c 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp @@ -28,7 +28,7 @@ struct ConvBwdWeightXdlV3Factory static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = internal::ConvTensorLayouts; using Types = internal::BwdWeightConvTensorDataTypes; - using Ops = internal::ElementwiseOps; + using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); static constexpr auto BWD_CONV_SPECIALIZATION = diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp index 31246eb5a8b..679ce4e59e0 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp @@ -26,7 +26,7 @@ struct ConvFwdDlFactory static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = internal::ConvTensorLayouts; using Types = internal::FwdConvTensorDataTypes; - using Ops = internal::ElementwiseOps; + using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization(); @@ -94,13 +94,13 @@ struct ConvFwdDlFactory typename Types::DsDataTypes, typename Types::EDataType, typename Types::AccDataType, - typename Layouts::ALayout, - typename Layouts::BLayout, + typename Layouts::InLayout, + typename Layouts::WeiLayout, typename Layouts::DsLayout, - typename Layouts::ELayout, - typename Ops::AElementwiseOp, - typename Ops::BElementwiseOp, - typename Ops::CDEElementwiseOp, + typename Layouts::OutLayout, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename 
Ops::OutElementwiseOp, FWD_CONV_SPECIALIZATION, GEMM_SPECIALIZATION, BLOCK.block_size, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp index 9cd56ad7ad2..1ad00d5c7a2 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp @@ -28,7 +28,7 @@ struct ConvFwdLargeTensorFactory static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = internal::ConvTensorLayouts; using Types = internal::FwdConvTensorDataTypes; - using Ops = internal::ElementwiseOps; + using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization(); @@ -59,19 +59,19 @@ struct ConvFwdLargeTensorFactory using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< SPATIAL_DIM, - typename Layouts::ALayout, - typename Layouts::BLayout, + typename Layouts::InLayout, + typename Layouts::WeiLayout, typename Layouts::DsLayout, - typename Layouts::ELayout, + typename Layouts::OutLayout, typename Types::ADataType, typename Types::BDataType, typename Types::AccDataType, typename Types::CShuffleDataType, typename Types::DsDataTypes, typename Types::EDataType, - typename Ops::AElementwiseOp, - typename Ops::BElementwiseOp, - typename Ops::CDEElementwiseOp, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, SPECIALIZATION.conv_spec, SPECIALIZATION.gemm_spec, ALGORITHM.num_gemm_k_prefetch_stages, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp index 7d889a0c01b..210f4393166 100644 --- 
a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp @@ -28,7 +28,7 @@ struct ConvFwdXdlV3Factory static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = internal::ConvTensorLayouts; using Types = internal::FwdConvTensorDataTypes; - using Ops = internal::ElementwiseOps; + using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); static_assert(ALGORITHM.transfer.a.lds_transfer.is_direct_load == @@ -64,19 +64,19 @@ struct ConvFwdXdlV3Factory // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< SPATIAL_DIM, - typename Layouts::ALayout, - typename Layouts::BLayout, + typename Layouts::InLayout, + typename Layouts::WeiLayout, typename Layouts::DsLayout, - typename Layouts::ELayout, + typename Layouts::OutLayout, typename Types::ADataType, typename Types::BDataType, typename Types::AccDataType, typename Types::CShuffleDataType, typename Types::DsDataTypes, typename Types::EDataType, - typename Ops::AElementwiseOp, - typename Ops::BElementwiseOp, - typename Ops::CDEElementwiseOp, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, SPECIALIZATION.conv_spec, SPECIALIZATION.gemm_spec, BLOCK.block_size, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp index 3506f5d1a92..01a78738ce2 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp @@ -28,7 +28,7 @@ struct ConvFwdWmmaFactory static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = internal::ConvTensorLayouts; using Types = internal::FwdConvTensorDataTypes; - using Ops 
= internal::ElementwiseOps; + using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization(); @@ -60,19 +60,19 @@ struct ConvFwdWmmaFactory // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< SPATIAL_DIM, - typename Layouts::ALayout, - typename Layouts::BLayout, + typename Layouts::InLayout, + typename Layouts::WeiLayout, typename Layouts::DsLayout, - typename Layouts::ELayout, + typename Layouts::OutLayout, typename Types::ADataType, typename Types::BDataType, typename Types::AccDataType, typename Types::CShuffleDataType, typename Types::DsDataTypes, typename Types::EDataType, - typename Ops::AElementwiseOp, - typename Ops::BElementwiseOp, - typename Ops::CDEElementwiseOp, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, SPECIALIZATION.conv_spec, SPECIALIZATION.gemm_spec, ALGORITHM.num_gemm_k_prefetch_stages, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp index 446ceceda25..50116a4f877 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp @@ -28,7 +28,7 @@ struct ConvFwdXdlFactory static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = internal::ConvTensorLayouts; using Types = internal::FwdConvTensorDataTypes; - using Ops = internal::ElementwiseOps; + using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization(); @@ -59,19 +59,19 @@ struct ConvFwdXdlFactory // The forward convolution kernel class instance. 
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< SPATIAL_DIM, - typename Layouts::ALayout, - typename Layouts::BLayout, + typename Layouts::InLayout, + typename Layouts::WeiLayout, typename Layouts::DsLayout, - typename Layouts::ELayout, + typename Layouts::OutLayout, typename Types::ADataType, typename Types::BDataType, typename Types::AccDataType, typename Types::CShuffleDataType, typename Types::DsDataTypes, typename Types::EDataType, - typename Ops::AElementwiseOp, - typename Ops::BElementwiseOp, - typename Ops::CDEElementwiseOp, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, SPECIALIZATION.conv_spec, SPECIALIZATION.gemm_spec, ALGORITHM.num_gemm_k_prefetch_stages, diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp index c7d9ce0ac6a..d2f82d3ecd3 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp @@ -62,30 +62,20 @@ consteval auto GetElementwiseOp() } template -struct ElementwiseOps +struct ConvElementwiseOps { - private: static constexpr auto input_op = GetElementwiseOp(); static constexpr auto weight_op = GetElementwiseOp(); static constexpr auto output_op = GetElementwiseOp(); - static constexpr bool is_forward = ConvDirectionIsForward; - static constexpr bool is_bwd_weight = ConvDirectionIsBackwardWeight; - - using InputOp = typename decltype(input_op)::Op; - using WeightOp = typename decltype(weight_op)::Op; - using OutputOp = typename decltype(output_op)::Op; - - public: - // Forward convolution elementwise ops - using AElementwiseOp = std::conditional_t; - using BElementwiseOp = std::conditional_t; - using CDEElementwiseOp = std::conditional_t; - - // Backward weight convolution 
elementwise ops - using InElementwiseOp = std::conditional_t; - using WeiElementwiseOp = std::conditional_t; - using OutElementwiseOp = std::conditional_t; + using InElementwiseOp = typename decltype(input_op)::Op; + using WeiElementwiseOp = typename decltype(weight_op)::Op; + using OutElementwiseOp = typename decltype(output_op)::Op; + + // TODO: Remove, now left for compatibility. Factories do not need it anymore. + // using AElementwiseOp = InElementwiseOp; + // using BElementwiseOp = WeiElementwiseOp; + // using CDEElementwiseOp = OutElementwiseOp; }; } // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp index d08dddff83a..f44ce154944 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp @@ -222,28 +222,15 @@ template ValidConvOutputLayoutForSpatialDim) struct ConvTensorLayouts { - private: - static constexpr bool is_forward = ConvDirectionIsForward; - static constexpr bool is_bwd_weight = ConvDirectionIsBackwardWeight; - - using InputLayout = decltype(TensorLayoutToCK()); - using WeightLayout = decltype(TensorLayoutToCK()); - using OutputLayout = decltype(TensorLayoutToCK()); - using AuxLayout = decltype(GetAuxiliaryTensorLayouts())::type; - - public: - // Forward convolution layouts - using ALayout = std::conditional_t; - using BLayout = std::conditional_t; - using ELayout = std::conditional_t; - - // Backward weight convolution layouts - using InLayout = std::conditional_t; - using WeiLayout = std::conditional_t; - using OutLayout = std::conditional_t; - - // Applicable for all directions - using DsLayout = AuxLayout; + using InLayout = decltype(TensorLayoutToCK()); + using WeiLayout = decltype(TensorLayoutToCK()); + using OutLayout = 
decltype(TensorLayoutToCK()); + using DsLayout = decltype(GetAuxiliaryTensorLayouts())::type; + + // TODO: Remove,now left for compatibility. Factories do not need it anymore. + // using ALayout = InLayout; + // using BLayout = WeiLayout; + // using ELayout = OutLayout; }; } // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp b/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp index 3240033c554..28024cd1b84 100644 --- a/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp @@ -71,7 +71,7 @@ struct Args using OutputDescriptor = TensorDescriptor; // TODO: We shouldn't need to call into an internal namespace here. - using Ops = factory::internal::ElementwiseOps; + using Ops = factory::internal::ConvElementwiseOps; // TODO: We shouldn't need to call into an internal namespace here. using Layouts = factory::internal::ConvTensorLayouts; @@ -88,9 +88,9 @@ struct Args FilterExtent input_left_pad; FilterExtent input_right_pad; - Ops::AElementwiseOp a_elementwise_op; - Ops::BElementwiseOp b_elementwise_op; - Ops::CDEElementwiseOp cde_elementwise_op; + Ops::InElementwiseOp a_elementwise_op; + Ops::WeiElementwiseOp b_elementwise_op; + Ops::OutElementwiseOp cde_elementwise_op; /// This function returns the `TensorDescriptor` corresponding to /// the input-tensor of the convolution problem. This can then @@ -105,7 +105,7 @@ struct Args // function. const auto param = to_ck_conv_param(); const auto desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed< - typename Layouts::ALayout>(param); + typename Layouts::InLayout>(param); using Extent = typename InputDescriptor::Extent; return InputDescriptor(Extent::from_vector(desc.GetLengths()), Extent::from_vector(desc.GetStrides())); @@ -119,7 +119,7 @@ struct Args // See note in implementation of `make_input_descriptor`. 
const auto param = to_ck_conv_param(); const auto desc = ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed< - typename Layouts::BLayout>(param); + typename Layouts::WeiLayout>(param); using Extent = typename WeightDescriptor::Extent; return WeightDescriptor(Extent::from_vector(desc.GetLengths()), Extent::from_vector(desc.GetStrides())); @@ -133,7 +133,7 @@ struct Args // See note in implementation of `make_input_descriptor`. const auto param = to_ck_conv_param(); const auto desc = ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed< - typename Layouts::ELayout>(param); + typename Layouts::OutLayout>(param); using Extent = typename OutputDescriptor::Extent; return OutputDescriptor(Extent::from_vector(desc.GetLengths()), Extent::from_vector(desc.GetStrides())); diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp b/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp index 499e0ef3de1..a90f53ba7d0 100644 --- a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp @@ -27,7 +27,7 @@ template > + typename Ops = factory::internal::ConvElementwiseOps> concept CkConvInstance = requires(Conv& conv, // TODO: This should be changed depending on IsMultiA etc. // Currently that is not yet supported elsewhere anyway. 
@@ -37,9 +37,9 @@ concept CkConvInstance = requires(Conv& conv, std::array lengths, std::array strides, std::array filter, - Ops::AElementwiseOp elementwise_a, - Ops::BElementwiseOp elementwise_b, - Ops::CDEElementwiseOp elementwise_cde) { + Ops::InElementwiseOp elementwise_a, + Ops::WeiElementwiseOp elementwise_b, + Ops::OutElementwiseOp elementwise_cde) { { conv.MakeArgument(p_a, p_b, diff --git a/experimental/builder/test/unit_conv_tensor_layout.cpp b/experimental/builder/test/unit_conv_tensor_layout.cpp index 8c1ba5562eb..0df94d977e7 100644 --- a/experimental/builder/test/unit_conv_tensor_layout.cpp +++ b/experimental/builder/test/unit_conv_tensor_layout.cpp @@ -40,9 +40,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NWGC_GKXC_NWGK) using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -59,9 +59,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKXC_NGKW) using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -78,9 +78,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_GNWC_GKXC_GNWK) using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -97,9 +97,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKCX_NGKW) using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + 
EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -116,9 +116,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKYXC_NGKHW) using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -135,9 +135,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NHWGC_GKYXC_NHWGK) using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -154,9 +154,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_GNHWC_GKYXC_GNHWK) using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -173,9 +173,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKCYX_NGKHW) using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -192,9 +192,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_NGCDHW_GKCZYX_NGKDHW) using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -211,9 +211,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_NDHWGC_GKZYXC_NDHWGK) using TensorLayouts = 
ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -230,9 +230,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_GNDHWC_GKZYXC_GNDHWK) using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -389,9 +389,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasG_K) using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); using ExpectedDsLayout = ck::Tuple; EXPECT_TRUE((std::is_same_v)); @@ -416,9 +416,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasGC) using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); using ExpectedDsLayout = ck::Tuple; EXPECT_TRUE((std::is_same_v)); @@ -444,9 +444,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithTwoAuxiliaryTensors) using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); using ExpectedDsLayout = ck::Tuple; @@ -472,9 +472,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv1DWithBias) using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + 
EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); using ExpectedDsLayout = ck::Tuple; EXPECT_TRUE((std::is_same_v)); @@ -499,9 +499,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv3DWithBias) using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); using ExpectedDsLayout = ck::Tuple; EXPECT_TRUE((std::is_same_v)); From b5d060b6b3bf787d5786b3737126fbc31489fe2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Tue, 13 Jan 2026 02:44:54 -0500 Subject: [PATCH 68/81] Remove old layout and elementwise ops. --- .../builder/factory/helpers/ck/conv_elementwise_op.hpp | 5 ----- .../builder/factory/helpers/ck/conv_tensor_layout.hpp | 5 ----- 2 files changed, 10 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp index d2f82d3ecd3..0cc43fc679b 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp @@ -71,11 +71,6 @@ struct ConvElementwiseOps using InElementwiseOp = typename decltype(input_op)::Op; using WeiElementwiseOp = typename decltype(weight_op)::Op; using OutElementwiseOp = typename decltype(output_op)::Op; - - // TODO: Remove, now left for compatibility. Factories do not need it anymore. 
- // using AElementwiseOp = InElementwiseOp; - // using BElementwiseOp = WeiElementwiseOp; - // using CDEElementwiseOp = OutElementwiseOp; }; } // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp index f44ce154944..86554595969 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp @@ -226,11 +226,6 @@ struct ConvTensorLayouts using WeiLayout = decltype(TensorLayoutToCK()); using OutLayout = decltype(TensorLayoutToCK()); using DsLayout = decltype(GetAuxiliaryTensorLayouts())::type; - - // TODO: Remove,now left for compatibility. Factories do not need it anymore. - // using ALayout = InLayout; - // using BLayout = WeiLayout; - // using ELayout = OutLayout; }; } // namespace ck_tile::builder::factory::internal From 97793cf352781e12cb17e83db27f9be2f5e123c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Tue, 13 Jan 2026 03:08:01 -0500 Subject: [PATCH 69/81] Unify handling of conv tensor types between fwd and bwd directions. 
--- .../factory/conv_bwd_weight_dl_factory.hpp | 2 +- ...onv_bwd_weight_multi_d_wmma_v3_factory.hpp | 2 +- .../conv_bwd_weight_multi_d_xdl_factory.hpp | 2 +- ...v_bwd_weight_two_stage_wmma_v3_factory.hpp | 2 +- .../conv_bwd_weight_two_stage_xdl_factory.hpp | 2 +- .../factory/conv_bwd_weight_wmma_factory.hpp | 2 +- .../conv_bwd_weight_wmma_v3_factory.hpp | 2 +- .../factory/conv_bwd_weight_xdl_factory.hpp | 2 +- .../conv_bwd_weight_xdl_v3_factory.hpp | 2 +- .../builder/factory/conv_fwd_dl_factory.hpp | 10 +++---- .../factory/conv_fwd_large_tensor_factory.hpp | 16 +++++------ .../builder/factory/conv_fwd_v3_factory.hpp | 16 +++++------ .../builder/factory/conv_fwd_wmma_factory.hpp | 12 ++++---- .../builder/factory/conv_fwd_xdl_factory.hpp | 16 +++++------ .../factory/helpers/ck/conv_tensor_type.hpp | 28 +------------------ .../builder/factory/reference_factory.hpp | 8 +++--- .../reflect/instance_traits_reference.hpp | 8 +++--- 17 files changed, 53 insertions(+), 79 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp index 60551b86331..fda1659c75f 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp @@ -24,7 +24,7 @@ struct ConvBwdWeightDlFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = internal::ConvTensorLayouts; - using Types = internal::BwdWeightConvTensorDataTypes; + using Types = internal::ConvTensorDataTypes; using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp index 9485e050ef3..b02dea95589 100644 --- 
a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp @@ -27,7 +27,7 @@ struct ConvBwdWeightMultiDWmmaV3Factory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = internal::ConvTensorLayouts; - using Types = internal::BwdWeightConvTensorDataTypes; + using Types = internal::ConvTensorDataTypes; using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp index a8c92a39fbc..4f6812617aa 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp @@ -27,7 +27,7 @@ struct ConvBwdWeightMultiDXdlFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = internal::ConvTensorLayouts; - using Types = internal::BwdWeightConvTensorDataTypes; + using Types = internal::ConvTensorDataTypes; using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp index 4f51ac52150..adf108bac48 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp @@ -27,7 +27,7 @@ struct ConvBwdWeightTwoStageWmmaV3Factory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = internal::ConvTensorLayouts; - using Types = 
internal::BwdWeightConvTensorDataTypes; + using Types = internal::ConvTensorDataTypes; using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp index 3fb04ffa2e4..d887c1c1ced 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp @@ -27,7 +27,7 @@ struct ConvBwdWeightTwoStageXdlFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = internal::ConvTensorLayouts; - using Types = internal::BwdWeightConvTensorDataTypes; + using Types = internal::ConvTensorDataTypes; using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp index eecc2ad1023..4067845291f 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp @@ -27,7 +27,7 @@ struct ConvBwdWeightWmmaFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = internal::ConvTensorLayouts; - using Types = internal::BwdWeightConvTensorDataTypes; + using Types = internal::ConvTensorDataTypes; using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp index b7845d7a006..027c8a1fba6 100644 --- 
a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp @@ -27,7 +27,7 @@ struct ConvBwdWeightWmmaV3Factory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = internal::ConvTensorLayouts; - using Types = internal::BwdWeightConvTensorDataTypes; + using Types = internal::ConvTensorDataTypes; using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp index db666ffb925..fbb177f3337 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp @@ -27,7 +27,7 @@ struct ConvBwdWeightXdlFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = internal::ConvTensorLayouts; - using Types = internal::BwdWeightConvTensorDataTypes; + using Types = internal::ConvTensorDataTypes; using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp index 9054aa5b88c..66a47c54078 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp @@ -27,7 +27,7 @@ struct ConvBwdWeightXdlV3Factory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = internal::ConvTensorLayouts; - using Types = internal::BwdWeightConvTensorDataTypes; + using Types = internal::ConvTensorDataTypes; using Ops = 
internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp index 679ce4e59e0..1d55772dd65 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp @@ -25,7 +25,7 @@ struct ConvFwdDlFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = internal::ConvTensorLayouts; - using Types = internal::FwdConvTensorDataTypes; + using Types = internal::ConvTensorDataTypes; using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); @@ -89,10 +89,10 @@ struct ConvFwdDlFactory // The DL forward convolution kernel class instance using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< SPATIAL_DIM, - typename Types::ADataType, - typename Types::BDataType, - typename Types::DsDataTypes, - typename Types::EDataType, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::DsDataType, + typename Types::OutDataType, typename Types::AccDataType, typename Layouts::InLayout, typename Layouts::WeiLayout, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp index 1ad00d5c7a2..0ff410d7311 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp @@ -27,7 +27,7 @@ struct ConvFwdLargeTensorFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = internal::ConvTensorLayouts; - using Types = internal::FwdConvTensorDataTypes; + using Types = internal::ConvTensorDataTypes; using Ops = 
internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); @@ -63,12 +63,12 @@ struct ConvFwdLargeTensorFactory typename Layouts::WeiLayout, typename Layouts::DsLayout, typename Layouts::OutLayout, - typename Types::ADataType, - typename Types::BDataType, + typename Types::InDataType, + typename Types::WeiDataType, typename Types::AccDataType, - typename Types::CShuffleDataType, - typename Types::DsDataTypes, - typename Types::EDataType, + typename Types::OutComputeType, + typename Types::DsDataType, + typename Types::OutDataType, typename Ops::InElementwiseOp, typename Ops::WeiElementwiseOp, typename Ops::OutElementwiseOp, @@ -103,8 +103,8 @@ struct ConvFwdLargeTensorFactory C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, to_sequence_v, C_BLOCK_TRANSFER.scalar_per_vector, - typename Types::AComputeType, - typename Types::BComputeType, + typename Types::InComputeType, + typename Types::WeiComputeType, LOOP_SCHEDULER>; }; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp index 210f4393166..dd2fa65eaee 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp @@ -27,7 +27,7 @@ struct ConvFwdXdlV3Factory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = internal::ConvTensorLayouts; - using Types = internal::FwdConvTensorDataTypes; + using Types = internal::ConvTensorDataTypes; using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); @@ -68,12 +68,12 @@ struct ConvFwdXdlV3Factory typename Layouts::WeiLayout, typename Layouts::DsLayout, typename Layouts::OutLayout, - typename Types::ADataType, - typename Types::BDataType, + typename Types::InDataType, + typename Types::WeiDataType, typename Types::AccDataType, - typename Types::CShuffleDataType, - typename 
Types::DsDataTypes, - typename Types::EDataType, + typename Types::OutComputeType, + typename Types::DsDataType, + typename Types::OutDataType, typename Ops::InElementwiseOp, typename Ops::WeiElementwiseOp, typename Ops::OutElementwiseOp, @@ -109,8 +109,8 @@ struct ConvFwdXdlV3Factory C_BLOCK_TRANSFER.scalar_per_vector, BLOCK_GEMM.scheduler, BLOCK_GEMM.pipeline_version, - typename Types::AComputeType, - typename Types::BComputeType, + typename Types::InComputeType, + typename Types::WeiComputeType, IS_DIRECT_LOAD>; }; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp index 01a78738ce2..2d6f7c394b9 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp @@ -27,7 +27,7 @@ struct ConvFwdWmmaFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = internal::ConvTensorLayouts; - using Types = internal::FwdConvTensorDataTypes; + using Types = internal::ConvTensorDataTypes; using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); @@ -64,12 +64,12 @@ struct ConvFwdWmmaFactory typename Layouts::WeiLayout, typename Layouts::DsLayout, typename Layouts::OutLayout, - typename Types::ADataType, - typename Types::BDataType, + typename Types::InDataType, + typename Types::WeiDataType, typename Types::AccDataType, - typename Types::CShuffleDataType, - typename Types::DsDataTypes, - typename Types::EDataType, + typename Types::OutComputeType, + typename Types::DsDataType, + typename Types::OutDataType, typename Ops::InElementwiseOp, typename Ops::WeiElementwiseOp, typename Ops::OutElementwiseOp, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp index 
50116a4f877..e03e0359699 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp @@ -27,7 +27,7 @@ struct ConvFwdXdlFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = internal::ConvTensorLayouts; - using Types = internal::FwdConvTensorDataTypes; + using Types = internal::ConvTensorDataTypes; using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); @@ -63,12 +63,12 @@ struct ConvFwdXdlFactory typename Layouts::WeiLayout, typename Layouts::DsLayout, typename Layouts::OutLayout, - typename Types::ADataType, - typename Types::BDataType, + typename Types::InDataType, + typename Types::WeiDataType, typename Types::AccDataType, - typename Types::CShuffleDataType, - typename Types::DsDataTypes, - typename Types::EDataType, + typename Types::OutComputeType, + typename Types::DsDataType, + typename Types::OutDataType, typename Ops::InElementwiseOp, typename Ops::WeiElementwiseOp, typename Ops::OutElementwiseOp, @@ -103,8 +103,8 @@ struct ConvFwdXdlFactory C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, to_sequence_v, C_BLOCK_TRANSFER.scalar_per_vector, - typename Types::AComputeType, - typename Types::BComputeType, + typename Types::InComputeType, + typename Types::WeiComputeType, LOOP_SCHEDULER, ALGORITHM.num_conv_groups_to_merge>; }; diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp index 7d4abf2933e..0c017e0c47b 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp @@ -156,33 +156,7 @@ consteval auto GetAuxiliaryTensorDataTypes() } template -struct FwdConvTensorDataTypes -{ - static constexpr auto input_types = - 
GetTensorDataAndComputeTypes(); - static constexpr auto weight_types = - GetTensorDataAndComputeTypes(); - static constexpr auto output_types = - GetTensorDataAndComputeTypes(); - - using ADataType = typename decltype(input_types.first)::type; - using AComputeType = typename decltype(input_types.second)::type; - using BDataType = typename decltype(weight_types.first)::type; - using BComputeType = typename decltype(weight_types.second)::type; - using AccDataType = - typename decltype(GetTensorAccumulationType())::type; - using EDataType = typename decltype(output_types.first)::type; - - // This is the "compute" type for output. - using CShuffleDataType = typename decltype(output_types.second)::type; - - // Data types for the auxiliary tensors (e.g., bias). - using DsDataTypes = typename decltype(GetAuxiliaryTensorDataTypes())::type; -}; - -template -struct BwdWeightConvTensorDataTypes +struct ConvTensorDataTypes { static constexpr auto input_types = GetTensorDataAndComputeTypes(); diff --git a/experimental/builder/include/ck_tile/builder/factory/reference_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/reference_factory.hpp index 0748725c968..f6fc2dbda85 100644 --- a/experimental/builder/include/ck_tile/builder/factory/reference_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/reference_factory.hpp @@ -26,11 +26,11 @@ struct ReferenceFactory static constexpr auto kValidation = (internal::ValidateReferenceSignature(), 0); static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; - using Types = internal::FwdConvTensorDataTypes; + using Types = internal::ConvTensorDataTypes; - using InDataType = typename Types::ADataType; - using WeiDataType = typename Types::BDataType; - using OutDataType = typename Types::EDataType; + using InDataType = typename Types::InDataType; + using WeiDataType = typename Types::WeiDataType; + using OutDataType = typename Types::OutDataType; struct Instance { diff --git 
a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_reference.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_reference.hpp index b2e8bb6a7c5..6875e586cdd 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_reference.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_reference.hpp @@ -35,10 +35,10 @@ struct ReferenceCommonTraits typename builder::factory::internal::LayoutToCK::type; // Data types - extract from factory's type helper - using Types = builder::factory::internal::FwdConvTensorDataTypes; - using ADataType = typename Types::ADataType; - using BDataType = typename Types::BDataType; - using EDataType = typename Types::EDataType; + using Types = builder::factory::internal::ConvTensorDataTypes; + using ADataType = typename Types::InDataType; + using BDataType = typename Types::WeiDataType; + using EDataType = typename Types::OutDataType; using AccDataType = float; // Reference uses float accumulation // Elementwise operations - reference only supports PassThrough From 1d519792ca58a2cc39e005fc50d04c4a1d40d4a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Tue, 13 Jan 2026 03:19:10 -0500 Subject: [PATCH 70/81] Unify block transfer for fwd and bwd directions. Rename ThreadSliceDim to ThreadClusterRank. 
--- .../helpers/ck/conv_block_transfer.hpp | 27 ++++++------------- .../test/utils/conv_algorithm_type_utils.hpp | 14 +++++----- 2 files changed, 15 insertions(+), 26 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp index 4ef2f533c98..d873a4b9033 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp @@ -10,11 +10,12 @@ namespace ck_tile::builder::factory::internal { // Block transfer parameters for A or B tensor. +template struct BlockTransfer { - ck::Array thread_cluster_dims{}; // k0, m, k1 - ck::Array thread_cluster_order{}; - ck::Array src_access_order{}; + ck::Array thread_cluster_dims{}; + ck::Array thread_cluster_order{}; + ck::Array src_access_order{}; size_t src_vector_dim = 0; size_t src_scalar_per_vector = 0; size_t lds_dst_scalar_per_vector = 0; @@ -22,27 +23,15 @@ struct BlockTransfer bool lds_padding = false; }; -template -struct BwdBlockTransfer -{ - ck::Array thread_cluster_dims{}; - ck::Array thread_cluster_order{}; - ck::Array src_access_order{}; - size_t src_vector_dim = 0; - size_t src_scalar_per_vector = 0; - size_t lds_dst_scalar_per_vector = 0; - bool lds_padding = false; -}; - template -constexpr BlockTransfer SetFwdConvBlockTransfer() +constexpr BlockTransfer<> SetFwdConvBlockTransfer() { auto& block_xfer = TRANSFER.block_transfer; auto& block_order = TRANSFER.block_transfer_access_order; auto& src_order = TRANSFER.src_access_order; auto& lds_cfg = TRANSFER.lds_transfer; - return BlockTransfer{ + return BlockTransfer<>{ .thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1}, .thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2]}, .src_access_order = {src_order.order[0], src_order.order[1], 
src_order.order[2]}, @@ -68,7 +57,7 @@ constexpr auto SetBwdConvBlockTransfer() if constexpr(array_length == 3) { - return BwdBlockTransfer<3>{ + return BlockTransfer<3>{ .thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1}, .thread_cluster_order = {block_order.order[0], block_order.order[1], @@ -82,7 +71,7 @@ constexpr auto SetBwdConvBlockTransfer() } else if constexpr(array_length == 4) { - return BwdBlockTransfer<4>{ + return BlockTransfer<4>{ .thread_cluster_dims = {block_xfer.k_batch_size, block_xfer.k0, block_xfer.m_n, diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index d80d6a1b8cf..f0c3bcde397 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -120,20 +120,20 @@ inline std::string to_string(BlockGemmPipeline t) return oss.str(); } -template -inline std::string to_string(BlockTransfer t) +template +inline std::string to_string(BlockTransfer t) { - if constexpr(ThreadSliceDim == 4) + if constexpr(ThreadClusterRank == 4) { return array_to_seq(std::array{t.k_batch_size, t.k0, t.m_n, t.k1}); } - else if constexpr(ThreadSliceDim == 3) + else if constexpr(ThreadClusterRank == 3) { return array_to_seq(std::array{t.k0, t.m_n, t.k1}); } else { - static_assert(ThreadSliceDim == 3 || ThreadSliceDim == 4, "Unsupported ThreadSliceDim"); + static_assert(ThreadClusterRank == 3 || ThreadClusterRank == 4, "Unsupported ThreadClusterRank"); } } @@ -288,8 +288,8 @@ inline std::string to_string(WmmaGemm_ t) return to_string(t.gridwise_gemm); } -template -inline std::string to_string(Transfer_ t) +template +inline std::string to_string(Transfer_ t) { return to_string(t.transfer); } From a8f1d44078e13a90df052b4e5af96f1946b6bad3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Tue, 13 Jan 2026 04:17:06 -0500 Subject: [PATCH 71/81] Make 
BlockTransferDescriptor concept parametrized. Introduce a common TileTransferParameters concept for conv algorithms. --- .../builder/conv_algorithm_concepts.hpp | 21 +++----- .../builder/factory/conv_algorithms.hpp | 49 ++++++++++--------- .../test/impl/conv_algorithm_types.hpp | 4 +- 3 files changed, 37 insertions(+), 37 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 2b0f63296be..791924ccd44 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -56,7 +56,7 @@ concept GridwiseWmmaGemmDescriptor = requires(T t) { // Concept for vectorized data transfer for convolution input tensors. template -concept BlockTransferDescriptor = requires(T t) { +concept BlockTransferDescriptor3D = requires(T t) { { t.k0 } -> SizeType; { t.m_n } -> SizeType; { t.k1 } -> SizeType; @@ -70,6 +70,10 @@ concept BlockTransferDescriptor4D = requires(T t) { { t.k_batch_size } -> SizeType; }; +template +concept BlockTransferDescriptor = (ThreadClusterRank == 3 && BlockTransferDescriptor3D) || + (ThreadClusterRank == 4 && BlockTransferDescriptor4D); + // Concept for thread cluster dimensions for GEMM output tensor. template concept ThreadClusterDescriptor = requires(T t) { @@ -202,19 +206,10 @@ concept SpecifiesGridwiseWmmaGemm = requires(T t) { }; // Concept to check if a struct specifies convolution input and output block transfer info. -template +template concept SpecifiesBlockTransfer = requires(T t) { - { T::transfer.a.block_transfer } -> BlockTransferDescriptor; - { T::transfer.b.block_transfer } -> BlockTransferDescriptor; - { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; -}; - -// Concept to check if a struct specifies convolution input and output block transfer info -// for 4D thread slices. 
-template -concept SpecifiesBlockTransfer4D = requires(T t) { - { T::transfer.a.block_transfer } -> BlockTransferDescriptor4D; - { T::transfer.b.block_transfer } -> BlockTransferDescriptor4D; + { T::transfer.a.block_transfer } -> BlockTransferDescriptor; + { T::transfer.b.block_transfer } -> BlockTransferDescriptor; { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; }; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index ffd45efe496..93979440ca0 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -8,39 +8,46 @@ namespace ck_tile::builder::factory { // Base algorithm concepts +template +concept TileTransferParameters = + SpecifiesBlockTransfer && SpecifiesLdsTransfer && + SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder; + +template +concept SpecifiesTileTransferParameters3D = TileTransferParameters; + +template +concept SpecifiesTileTransferParameters4D = TileTransferParameters; + + template concept FwdXdlAlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesBlockTransfer && - SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesGridwiseFwdXdlGemm && - SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && + SpecifiesGridwiseFwdXdlGemm && SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && SpecifiesNumGroupsToMerge && SpecifiesLoopScheduler; template concept BwdXdlAlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesBlockTransfer4D && - SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesGridwiseBwdXdlGemm && + 
ConvAlgorithmDescriptor && SpecifiesThreadBlock && + SpecifiesTileTransferParameters4D && SpecifiesGridwiseBwdXdlGemm && SpecifiesBwdWeightConvSpecialization; template concept BwdXdlV3AlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesBlockTransfer && - SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesGridwiseBwdXdlGemm && + ConvAlgorithmDescriptor && SpecifiesThreadBlock && + SpecifiesTileTransferParameters3D && SpecifiesGridwiseBwdXdlGemm && SpecifiesBwdWeightConvSpecialization && SpecifiesBlockGemm; + template concept BwdWmmaAlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesBlockTransfer && - SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesGridwiseWmmaGemm && + ConvAlgorithmDescriptor && SpecifiesThreadBlock && + SpecifiesTileTransferParameters3D && SpecifiesGridwiseWmmaGemm && SpecifiesBwdWeightConvSpecialization; template concept BwdWmmaV3AlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesBlockTransfer && - SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesGridwiseWmmaGemm && + ConvAlgorithmDescriptor && SpecifiesThreadBlock && + SpecifiesTileTransferParameters3D && SpecifiesGridwiseWmmaGemm && SpecifiesBwdWeightConvSpecialization && SpecifiesBlockGemm; // Reference algorithm concept @@ -62,17 +69,15 @@ concept LargeTensorAlgorithm = FwdXdlAlgorithmBase && SpecifiesLargeTensorSup template concept FwdXdlV3Algorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesBlockTransfer && - SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesGridwiseFwdXdlGemm && + ConvAlgorithmDescriptor && SpecifiesThreadBlock && + SpecifiesTileTransferParameters3D && SpecifiesGridwiseFwdXdlGemm && SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && SpecifiesBlockGemm; // 
FWD WMMA algorithm concepts template concept FwdWmmaAlgorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesBlockTransfer && - SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesGridwiseWmmaGemm && + ConvAlgorithmDescriptor && SpecifiesThreadBlock && + SpecifiesTileTransferParameters3D && SpecifiesGridwiseWmmaGemm && SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && SpecifiesLoopScheduler && SpecifiesGridwiseGemmPipeline; diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 27ba1ec3b60..617686fda14 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -90,8 +90,8 @@ struct BlockTransfer<3> size_t m_n; size_t k1; }; -static_assert(ckb::BlockTransferDescriptor>); -static_assert(ckb::BlockTransferDescriptor>); +static_assert(ckb::BlockTransferDescriptor, 3>); +static_assert(ckb::BlockTransferDescriptor, 4>); // Describe C block transfer thread cluster lengths. 
struct ThreadCluster From bf57fbf48838649915bbc0c492fb4391c1b4913e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Tue, 13 Jan 2026 04:20:08 -0500 Subject: [PATCH 72/81] clang-format --- .../builder/factory/conv_algorithms.hpp | 46 +++++++++---------- .../factory/helpers/ck/conv_tensor_layout.hpp | 2 +- .../test/utils/conv_algorithm_type_utils.hpp | 3 +- 3 files changed, 25 insertions(+), 26 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index 93979440ca0..fc0ee48ec0b 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -9,7 +9,7 @@ namespace ck_tile::builder::factory { // Base algorithm concepts template -concept TileTransferParameters = +concept TileTransferParameters = SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder; @@ -19,36 +19,34 @@ concept SpecifiesTileTransferParameters3D = TileTransferParameters; template concept SpecifiesTileTransferParameters4D = TileTransferParameters; - template concept FwdXdlAlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && - SpecifiesGridwiseFwdXdlGemm && SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && - SpecifiesNumPrefetchStages && SpecifiesNumGroupsToMerge && SpecifiesLoopScheduler; + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && + SpecifiesGridwiseFwdXdlGemm && SpecifiesFwdConvSpecialization && + SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && + SpecifiesNumGroupsToMerge && SpecifiesLoopScheduler; template concept BwdXdlAlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && - SpecifiesTileTransferParameters4D && SpecifiesGridwiseBwdXdlGemm && - 
SpecifiesBwdWeightConvSpecialization; + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters4D && + SpecifiesGridwiseBwdXdlGemm && SpecifiesBwdWeightConvSpecialization; template concept BwdXdlV3AlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && - SpecifiesTileTransferParameters3D && SpecifiesGridwiseBwdXdlGemm && - SpecifiesBwdWeightConvSpecialization && SpecifiesBlockGemm; + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && + SpecifiesGridwiseBwdXdlGemm && SpecifiesBwdWeightConvSpecialization && + SpecifiesBlockGemm; template concept BwdWmmaAlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && - SpecifiesTileTransferParameters3D && SpecifiesGridwiseWmmaGemm && - SpecifiesBwdWeightConvSpecialization; + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && + SpecifiesGridwiseWmmaGemm && SpecifiesBwdWeightConvSpecialization; template concept BwdWmmaV3AlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && - SpecifiesTileTransferParameters3D && SpecifiesGridwiseWmmaGemm && - SpecifiesBwdWeightConvSpecialization && SpecifiesBlockGemm; + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && + SpecifiesGridwiseWmmaGemm && SpecifiesBwdWeightConvSpecialization && + SpecifiesBlockGemm; // Reference algorithm concept template @@ -69,17 +67,17 @@ concept LargeTensorAlgorithm = FwdXdlAlgorithmBase && SpecifiesLargeTensorSup template concept FwdXdlV3Algorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && - SpecifiesTileTransferParameters3D && SpecifiesGridwiseFwdXdlGemm && - SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && SpecifiesBlockGemm; + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && + SpecifiesGridwiseFwdXdlGemm && SpecifiesFwdConvSpecialization && + SpecifiesGemmSpecialization && SpecifiesBlockGemm; // FWD WMMA algorithm 
concepts template concept FwdWmmaAlgorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && - SpecifiesTileTransferParameters3D && SpecifiesGridwiseWmmaGemm && - SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && - SpecifiesNumPrefetchStages && SpecifiesLoopScheduler && SpecifiesGridwiseGemmPipeline; + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && + SpecifiesGridwiseWmmaGemm && SpecifiesFwdConvSpecialization && + SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && SpecifiesLoopScheduler && + SpecifiesGridwiseGemmPipeline; // FWD DL algorithms template diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp index 86554595969..fd6de9ae21e 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp @@ -224,7 +224,7 @@ struct ConvTensorLayouts { using InLayout = decltype(TensorLayoutToCK()); using WeiLayout = decltype(TensorLayoutToCK()); - using OutLayout = decltype(TensorLayoutToCK()); + using OutLayout = decltype(TensorLayoutToCK()); using DsLayout = decltype(GetAuxiliaryTensorLayouts())::type; }; diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index f0c3bcde397..23f4cf33648 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -133,7 +133,8 @@ inline std::string to_string(BlockTransfer t) } else { - static_assert(ThreadClusterRank == 3 || ThreadClusterRank == 4, "Unsupported ThreadClusterRank"); + static_assert(ThreadClusterRank == 3 || ThreadClusterRank == 4, + "Unsupported ThreadClusterRank"); } } From faf91267ecc8c6a1d339d3267d1e7d3a54bd9774 Mon Sep 17 
00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Tue, 13 Jan 2026 08:57:00 -0500 Subject: [PATCH 73/81] Improve dispatcher error messages. Fix builder smoke tests. --- .../builder/conv_algorithm_concepts.hpp | 3 + .../builder/factory/conv_algorithms.hpp | 25 +++ .../builder/factory/conv_dispatcher.hpp | 163 +++++++++++------- .../builder/test/test_conv_description.cpp | 36 ++-- .../builder/test/unit_conv_tuning_params.cpp | 14 +- 5 files changed, 157 insertions(+), 84 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index d7f618f0792..54bda657e4d 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -350,6 +350,9 @@ concept SpecifiesWmma = requires { requires Algorithm.warp_gemm.matrix_instruction == MatrixInstructionType::WMMA; }; +template +concept SpecifiesValidWarpGemm = SpecifiesXdl || SpecifiesWmma; + /******************************************** */ /* DL-specific descriptors and requirements */ /******************************************** */ diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index b228126106e..46af4e49507 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -132,4 +132,29 @@ concept BwdDlAlgorithm = SpecifiesDlThreadCluster && SpecifiesDlBwdBlockTransfer && SpecifiesDlEpilogue; +// Concepts for valid XDL/WMMA algorithms +template +concept SpecifiesValidFwdXdlAlgorithm = +FwdXdlAlgorithm || FwdXdlV3Algorithm || LargeTensorAlgorithm; + +template +concept SpecifiesValidFwdWmmaAlgorithm = FwdWmmaAlgorithm; + +template +concept SpecifiesValidBwdXdlAlgorithm = + BwdXdlAlgorithm 
|| BwdXdlV3Algorithm || + BwdTwoStageXdlAlgorithm || BwdMultiDXdlAlgorithm; + +template +concept SpecifiesValidBwdWmmaAlgorithm = + BwdWmmaAlgorithm || BwdWmmaV3Algorithm || + BwdTwoStageWmmaV3Algorithm || BwdMultiDWmmaV3Algorithm; + + +template +concept FwdWarpGemmOrDL = SpecifiesValidWarpGemm || FwdDlAlgorithm; + +template +concept BwdWarpGemmOrDL = SpecifiesValidWarpGemm || BwdDlAlgorithm; + } // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 97bfa05e1b0..965a231cd1c 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -117,33 +117,49 @@ constexpr auto make_conv_instance() // Forward direction (supports most algorithm variants) else if constexpr(ConvDirectionIsForward) { - if constexpr(FwdXdlV3Algorithm) - { - return typename ConvFwdXdlV3Factory::Instance{}; - } - else if constexpr(FwdXdlAlgorithm) - { - return typename ConvFwdXdlFactory::Instance{}; - } - else if constexpr(FwdWmmaAlgorithm) - { - return typename ConvFwdWmmaFactory::Instance{}; - } - else if constexpr(FwdDlAlgorithm) + if constexpr (SpecifiesXdl) + { + if constexpr(FwdXdlV3Algorithm) + { + return typename ConvFwdXdlV3Factory::Instance{}; + } + else if constexpr(FwdXdlAlgorithm) + { + return typename ConvFwdXdlFactory::Instance{}; + } + else if constexpr(LargeTensorAlgorithm) + { + return typename ConvFwdLargeTensorFactory::Instance{}; + } + else + { + static_assert( + SpecifiesValidFwdXdlAlgorithm, + "No suitable forward convolution XDL kernel factory found for the provided ALGORITHM. 
" + "The ALGORITHM must satisfy requirements for one of: XDL V3, generic XDL, " + "DL (NHWC layout), or Large Tensor variant."); + } + } + else if constexpr (SpecifiesWmma) + { + if constexpr(FwdWmmaAlgorithm) + { + return typename ConvFwdWmmaFactory::Instance{}; + } + else + { + static_assert(FwdWmmaAlgorithm, + "Did not find matching WMMA factory."); + } + } + else if constexpr (FwdDlAlgorithm) { return typename ConvFwdDlFactory::Instance{}; } - else if constexpr(LargeTensorAlgorithm) - { - return typename ConvFwdLargeTensorFactory::Instance{}; - } else { - static_assert( - false, - "No suitable forward convolution kernel factory found for the provided ALGORITHM. " - "The ALGORITHM must satisfy requirements for one of: Reference, Tile, XDL V3, XDL, " - "WMMA, DL (NHWC layout), or Large Tensor variant."); + static_assert(FwdWarpGemmOrDL, + "Forward convolution: Algorithm must specify either DL, XDL or WMMA."); } } // Backward data direction (will expand with more algorithms in the future) @@ -157,55 +173,72 @@ constexpr auto make_conv_instance() // Backward weight direction (will expand with more algorithms in the future) else if constexpr(ConvDirectionIsBackwardWeight) { - // Start from more specialized and end with least specialized. - if constexpr(BwdTwoStageXdlAlgorithm) - { - return - typename ConvBwdWeightTwoStageXdlFactory::Instance{}; - } - else if constexpr(BwdTwoStageWmmaV3Algorithm) - { - return typename ConvBwdWeightTwoStageWmmaV3Factory:: - Instance{}; - } - else if constexpr(BwdMultiDXdlAlgorithm) - { - return - typename ConvBwdWeightMultiDXdlFactory::Instance{}; - } - else if constexpr(BwdMultiDWmmaV3Algorithm) - { - return typename ConvBwdWeightMultiDWmmaV3Factory:: - Instance{}; + if constexpr (SpecifiesXdl) + { + // Start from more specialized and end with least specialized. 
+ if constexpr(BwdTwoStageXdlAlgorithm) + { + return + typename ConvBwdWeightTwoStageXdlFactory::Instance{}; + } + else if constexpr(BwdMultiDXdlAlgorithm) + { + return + typename ConvBwdWeightMultiDXdlFactory::Instance{}; + } + else if constexpr(BwdXdlV3Algorithm) + { + return typename ConvBwdWeightXdlV3Factory::Instance{}; + } + else if constexpr(BwdXdlAlgorithm) + { + return typename ConvBwdWeightXdlFactory::Instance{}; + } + else + { + static_assert( + SpecifiesValidBwdXdlAlgorithm, + "No suitable backward weight convolution XDL kernel factory found for the provided ALGORITHM. " + "The ALGORITHM must satisfy requirements for one of: Two-Stage XDL, Multi-D XDL, DL, " + "generic XDL, or XDL V3 variant."); + } + } + else if constexpr (SpecifiesWmma) + { + // Start from more specialized and end with least specialized. + if constexpr(BwdTwoStageWmmaV3Algorithm) + { + return typename ConvBwdWeightTwoStageWmmaV3Factory::Instance{}; + } + else if constexpr(BwdMultiDWmmaV3Algorithm) + { + return typename ConvBwdWeightMultiDWmmaV3Factory::Instance{}; + } + else if constexpr(BwdWmmaV3Algorithm) + { + return typename ConvBwdWeightWmmaV3Factory::Instance{}; + } + else if constexpr(BwdWmmaAlgorithm) + { + return typename ConvBwdWeightWmmaFactory::Instance{}; + } + else + { + static_assert( + SpecifiesValidBwdWmmaAlgorithm, + "No suitable backward weight convolution WMMA kernel factory found for the provided ALGORITHM. 
" + "The ALGORITHM must satisfy requirements for one of: Two-Stage WMMA V3, Multi-D WMMA V3, " + "WMMA V3, or generic WMMA variant."); + } } else if constexpr(BwdDlAlgorithm) { return typename ConvBwdWeightDlFactory::Instance{}; } - else if constexpr(BwdXdlV3Algorithm) - { - return typename ConvBwdWeightXdlV3Factory::Instance{}; - } - else if constexpr(BwdWmmaV3Algorithm) - { - return typename ConvBwdWeightWmmaV3Factory::Instance{}; - } - else if constexpr(BwdXdlAlgorithm) - { - return typename ConvBwdWeightXdlFactory::Instance{}; - } - else if constexpr(BwdWmmaAlgorithm) - { - return typename ConvBwdWeightWmmaFactory::Instance{}; - } - else + else { - static_assert( - false, - "No suitable backward weight convolution kernel factory found for the provided " - "ALGORITHM. The ALGORITHM must satisfy requirements for one of: Reference, Tile, " - "XDL, XDL V3, Two-Stage XDL, DL, Multi-D XDL, WMMA V3, Two-Stage " - "WMMA V3, WMMA, or Multi-D WMMA V3 variant."); + static_assert(BwdWarpGemmOrDL, + "Backward convolution: Algorithm must specify either DL, XDL or WMMA."); } } else diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index 9e8008ccf02..ac3ce104718 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -117,42 +117,47 @@ static_assert(!ckb::ConvSignatureDescriptor transfer{ + ckb::test::InputOutputTileTransfer<> transfer{ .a = { - .block_transfer = {.k0 = 1, .m_n = 128, .k1 = 2}, - .lds_transfer = {.src_vector_dim = 2, + .thread_distribution = {.k0 = 1, .m_n = 128, .k1 = 2}, + .lds_transfer_params = {.global_memory_vector_load_size = 8, + .src_vector_dim = 2, .src_scalar_per_vector = 2, .lds_dst_scalar_per_vector = 2, .is_direct_load = false, .lds_padding = false}, - .block_transfer_access_order = {.order = {0, 1, 2}}, + .thread_distribution_access_order = {.order = {0, 1, 2}}, .src_access_order = {.order = {0, 1, 2}}, }, .b 
= { - .block_transfer = {.k0 = 1, .m_n = 128, .k1 = 2}, - .lds_transfer = {.src_vector_dim = 2, + .thread_distribution = {.k0 = 1, .m_n = 128, .k1 = 2}, + .lds_transfer_params = {.global_memory_vector_load_size = 8, + .src_vector_dim = 2, .src_scalar_per_vector = 2, .lds_dst_scalar_per_vector = 2, .is_direct_load = false, .lds_padding = false}, - .block_transfer_access_order = {.order = {0, 1, 2}}, + .thread_distribution_access_order = {.order = {0, 1, 2}}, .src_access_order = {.order = {0, 1, 2}}, }, .c = { - .thread_cluster_dims = - {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, + .thread_distribution = + {.gemm_m_block_size = 1, .gemm_m_per_block = 32, .gemm_n_block_size = 1, .gemm_n_per_block = 8}, .epilogue = {.m_xdl_per_wave_per_shuffle = 1, .n_per_wave_per_shuffle = 1, .scalar_per_vector = 2}, @@ -161,9 +166,8 @@ struct DefaultAlgorithm ckb::ConvSpecialization fwd_specialization = ckb::ConvSpecialization::DEFAULT; ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default; - ckb::test::BlockGemmPipeline block_gemm_pipeline{.pipeline_version = ckb::PipelineVersion::V4, - .scheduler = - ckb::PipelineScheduler::INTRAWAVE}; + ckb::test::GemmPipeline gemm_pipeline{.pipeline_version = ckb::PipelineVersion::V4, + .scheduler =ckb::PipelineScheduler::INTRAWAVE}; }; static_assert(ckb::ConvAlgorithmDescriptor); diff --git a/experimental/builder/test/unit_conv_tuning_params.cpp b/experimental/builder/test/unit_conv_tuning_params.cpp index 90057429309..26217dc22e1 100644 --- a/experimental/builder/test/unit_conv_tuning_params.cpp +++ b/experimental/builder/test/unit_conv_tuning_params.cpp @@ -17,9 +17,11 @@ TEST(ConvTuningParams, AssignsBlockGemmParams) { struct BlockGemm { + size_t num_conv_groups_to_merge = 1; + size_t num_gemm_k_prefetch_stages = 1; ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V3; ckb::PipelineScheduler scheduler = ckb::PipelineScheduler::INTRAWAVE; - } block_gemm_pipeline; + } 
gemm_pipeline; } kAlgorithm; constexpr auto block_gemm = SetBlockGemm(); @@ -31,7 +33,10 @@ TEST(ConvTuningParams, AssignsLoopSchedulerParam) { constexpr struct Algorithm { - ckb::PipelineScheduler loop_scheduler = ckb::PipelineScheduler::INTERWAVE; + struct GemmPipeline + { + ckb::PipelineScheduler scheduler = ckb::PipelineScheduler::INTERWAVE; + } gemm_pipeline; } kAlgorithm; constexpr auto loop_scheduler = SetLoopScheduler(); @@ -42,7 +47,10 @@ TEST(ConvTuningParams, AssignsGridwiseGemmPipelineVersion) { constexpr struct Algorithm { - ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V4; + struct GemmPipeline + { + ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V4; + } gemm_pipeline; } kAlgorithm; constexpr auto pipeline_version = SetGridwiseGemmPipelineVersion(); From 6210a83d5e3798b2a17667667982f5678559a7c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Wed, 14 Jan 2026 06:06:05 -0500 Subject: [PATCH 74/81] Rename thread distribution to thread cluster. --- .../builder/conv_algorithm_concepts.hpp | 22 +++--- .../builder/factory/conv_algorithms.hpp | 2 +- .../helpers/ck/conv_block_transfer.hpp | 10 +-- .../test/impl/conv_algorithm_types.hpp | 40 +++++------ .../builder/test/test_conv_description.cpp | 10 +-- .../test/utils/ckb_conv_test_configs.hpp | 70 +++++++++---------- .../test/utils/conv_algorithm_type_utils.hpp | 8 +-- 7 files changed, 81 insertions(+), 81 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index ab786b06a82..794fc3aef7d 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -57,14 +57,14 @@ concept GridwiseWmmaGemmDescriptor = requires(T t) { // Concept for vectorized data transfer for convolution input tensors. 
template -concept InputTileThreadDistributionDescriptor3D = requires(T t) { +concept InputTileThreadClusterDescriptor3D = requires(T t) { { t.k0 } -> SizeType; { t.m_n } -> SizeType; { t.k1 } -> SizeType; }; template -concept InputTileThreadDistributionDescriptor4D = requires(T t) { +concept InputTileThreadClusterDescriptor4D = requires(T t) { { t.k0 } -> SizeType; { t.m_n } -> SizeType; { t.k1 } -> SizeType; @@ -72,12 +72,12 @@ concept InputTileThreadDistributionDescriptor4D = requires(T t) { }; template -concept InputTileThreadDistributionDescriptor = (ThreadClusterRank == 3 && InputTileThreadDistributionDescriptor3D) || - (ThreadClusterRank == 4 && InputTileThreadDistributionDescriptor4D); +concept InputTileThreadClusterDescriptor = (ThreadClusterRank == 3 && InputTileThreadClusterDescriptor3D) || + (ThreadClusterRank == 4 && InputTileThreadClusterDescriptor4D); // Concept for thread cluster dimensions for GEMM output tensor. template -concept OutputTileThreadDistributionDescriptor = requires(T t) { +concept OutputTileThreadClusterDescriptor = requires(T t) { { t.gemm_m_block_size } -> SizeType; { t.gemm_m_per_block } -> SizeType; { t.gemm_n_block_size } -> SizeType; @@ -182,10 +182,10 @@ concept SpecifiesWarpGemm = requires { // Concept to check if a struct specifies convolution input and output block transfer info. 
template -concept SpecifiesThreadDistribution = requires(T t) { - { T::transfer.a.thread_distribution } -> InputTileThreadDistributionDescriptor; - { T::transfer.b.thread_distribution } -> InputTileThreadDistributionDescriptor; - { T::transfer.c.thread_distribution } -> OutputTileThreadDistributionDescriptor; +concept SpecifiesThreadClusters = requires(T t) { + { T::transfer.a.thread_cluster } -> InputTileThreadClusterDescriptor; + { T::transfer.b.thread_cluster } -> InputTileThreadClusterDescriptor; + { T::transfer.c.thread_cluster } -> OutputTileThreadClusterDescriptor; }; // Concept to check if a struct specifies convolution scalar per vector infor for A, B and C. @@ -207,8 +207,8 @@ concept SpecifiesLdsTransfer = requires(T t) { // Concept to check if a struct specifies thread cluster access order info. template concept SpecifiesThreadClusterAccessOrder = requires(T t) { - { T::transfer.a.thread_distribution_access_order } -> AccessOrderDescriptor; - { T::transfer.b.thread_distribution_access_order } -> AccessOrderDescriptor; + { T::transfer.a.thread_cluster_access_order } -> AccessOrderDescriptor; + { T::transfer.b.thread_cluster_access_order } -> AccessOrderDescriptor; }; // Concept to check if a struct specifies source access order info. 
diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index c10d2d62bb0..3b2f7f8c511 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -10,7 +10,7 @@ namespace ck_tile::builder::factory { // Base algorithm concepts template concept TileTransferParameters = - SpecifiesThreadDistribution && SpecifiesLdsTransfer && + SpecifiesThreadClusters && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder; template diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp index 2324416d827..7e811633c08 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp @@ -27,8 +27,8 @@ struct BlockTransfer template constexpr BlockTransfer<> SetFwdConvBlockTransfer() { - auto& block_xfer = TRANSFER.thread_distribution; - auto& block_order = TRANSFER.thread_distribution_access_order; + auto& block_xfer = TRANSFER.thread_cluster; + auto& block_order = TRANSFER.thread_cluster_access_order; auto& src_order = TRANSFER.src_access_order; auto& lds_cfg = TRANSFER.lds_transfer_params; @@ -48,8 +48,8 @@ constexpr BlockTransfer<> SetFwdConvBlockTransfer() template constexpr auto SetBwdConvBlockTransfer() { - auto& block_xfer = TRANSFER.thread_distribution; - auto& block_order = TRANSFER.thread_distribution_access_order; + auto& block_xfer = TRANSFER.thread_cluster; + auto& block_order = TRANSFER.thread_cluster_access_order; auto& src_order = TRANSFER.src_access_order; auto& lds_cfg = TRANSFER.lds_transfer_params; @@ -112,7 +112,7 @@ struct CBlockTransfer template constexpr 
CBlockTransfer SetCBlockTransfer() { - auto& thread_cluster_dims = ALGORITHM.transfer.c.thread_distribution; + auto& thread_cluster_dims = ALGORITHM.transfer.c.thread_cluster; auto& epilogue_config = ALGORITHM.transfer.c.epilogue; return CBlockTransfer{ .m_xdl_per_wave_per_shuffle = epilogue_config.m_xdl_per_wave_per_shuffle, diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index dfc46629b0a..fc8b4770bce 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -48,8 +48,8 @@ struct GemmPipeline static_assert(ckb::GemmPipelineDescriptor); // Describe input tensor thread cluster lengths. -template -struct InputDataThreadDistribution +template +struct InputThreadCluster { size_t k0; size_t m_n; @@ -57,26 +57,26 @@ struct InputDataThreadDistribution size_t k_batch_size; }; -// Specialization for ThreadSliceLength == 3 +// Specialization for ThreadClusterRank == 3 template <> -struct InputDataThreadDistribution<3> +struct InputThreadCluster<3> { size_t k0; size_t m_n; size_t k1; }; -static_assert(ckb::InputTileThreadDistributionDescriptor, 3>); -static_assert(ckb::InputTileThreadDistributionDescriptor, 4>); +static_assert(ckb::InputTileThreadClusterDescriptor, 3>); +static_assert(ckb::InputTileThreadClusterDescriptor, 4>); // Describe C block transfer thread cluster lengths. 
-struct OutputDataThreadDistribution +struct OutputThreadCluster { size_t gemm_m_block_size; size_t gemm_m_per_block; size_t gemm_n_block_size; size_t gemm_n_per_block; }; -static_assert(OutputTileThreadDistributionDescriptor); +static_assert(OutputTileThreadClusterDescriptor); struct LdsInputTransferParams { @@ -97,34 +97,34 @@ struct Epilogue }; static_assert(EpilogueDescriptor); -template +template struct AccessOrder { - std::array order; + std::array order; }; static_assert(AccessOrderDescriptor>); static_assert(AccessOrderDescriptor>); -template +template struct InputTileTransfer { - InputDataThreadDistribution thread_distribution; + InputThreadCluster thread_cluster; LdsInputTransferParams lds_transfer_params; - AccessOrder thread_distribution_access_order; - AccessOrder src_access_order; + AccessOrder thread_cluster_access_order; + AccessOrder src_access_order; }; struct OutputTileTransfer { - OutputDataThreadDistribution thread_distribution; + OutputThreadCluster thread_cluster; Epilogue epilogue; }; -template +template struct InputOutputTileTransfer { - InputTileTransfer a; - InputTileTransfer b; + InputTileTransfer a; + InputTileTransfer b; OutputTileTransfer c; }; @@ -180,10 +180,10 @@ struct WarpGemm_ WarpGemmParams warp_gemm; }; -template +template struct InputOutputTileTransfer_ { - InputOutputTileTransfer transfer; + InputOutputTileTransfer transfer; }; struct ConvSpecializationFwd_ diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index ac3ce104718..4d5cb1f2839 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -131,32 +131,32 @@ struct DefaultAlgorithm ckb::test::InputOutputTileTransfer<> transfer{ .a = { - .thread_distribution = {.k0 = 1, .m_n = 128, .k1 = 2}, + .thread_cluster = {.k0 = 1, .m_n = 128, .k1 = 2}, .lds_transfer_params = {.global_memory_vector_load_size = 8, .src_vector_dim = 2, 
.src_scalar_per_vector = 2, .lds_dst_scalar_per_vector = 2, .is_direct_load = false, .lds_padding = false}, - .thread_distribution_access_order = {.order = {0, 1, 2}}, + .thread_cluster_access_order = {.order = {0, 1, 2}}, .src_access_order = {.order = {0, 1, 2}}, }, .b = { - .thread_distribution = {.k0 = 1, .m_n = 128, .k1 = 2}, + .thread_cluster = {.k0 = 1, .m_n = 128, .k1 = 2}, .lds_transfer_params = {.global_memory_vector_load_size = 8, .src_vector_dim = 2, .src_scalar_per_vector = 2, .lds_dst_scalar_per_vector = 2, .is_direct_load = false, .lds_padding = false}, - .thread_distribution_access_order = {.order = {0, 1, 2}}, + .thread_cluster_access_order = {.order = {0, 1, 2}}, .src_access_order = {.order = {0, 1, 2}}, }, .c = { - .thread_distribution = + .thread_cluster = {.gemm_m_block_size = 1, .gemm_m_per_block = 32, .gemm_n_block_size = 1, .gemm_n_per_block = 8}, .epilogue = {.m_xdl_per_wave_per_shuffle = 1, .n_per_wave_per_shuffle = 1, diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index 34682746cee..b58e9379476 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -55,31 +55,31 @@ constexpr DlTransfer<5> DlTransfer5D{.a = DlBlockTransfer_1x8x1x1x1, constexpr InputOutputTileTransfer<> Transfer_4x64x1{ .a = { - .thread_distribution = {.k0 = 4, .m_n = 64, .k1 = 1}, + .thread_cluster = {.k0 = 4, .m_n = 64, .k1 = 1}, .lds_transfer_params = {.global_memory_vector_load_size = 8, .src_vector_dim = 2, .src_scalar_per_vector = 2, .lds_dst_scalar_per_vector = 8, .is_direct_load = false, .lds_padding = false}, - .thread_distribution_access_order = {1, 0, 2}, + .thread_cluster_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .b = { - .thread_distribution = {.k0 = 4, .m_n = 64, .k1 = 1}, + .thread_cluster = {.k0 = 4, .m_n = 64, .k1 = 1}, .lds_transfer_params = {.global_memory_vector_load_size = 
8, .src_vector_dim = 2, .src_scalar_per_vector = 8, .lds_dst_scalar_per_vector = 8, .is_direct_load = false, .lds_padding = false}, - .thread_distribution_access_order = {1, 0, 2}, + .thread_cluster_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .c = { - .thread_distribution = + .thread_cluster = {.gemm_m_block_size = 1, .gemm_m_per_block = 32, .gemm_n_block_size = 1, .gemm_n_per_block = 8}, .epilogue = {.m_xdl_per_wave_per_shuffle = 1, .n_per_wave_per_shuffle = 1, @@ -90,31 +90,31 @@ constexpr InputOutputTileTransfer<> Transfer_4x64x1{ constexpr InputOutputTileTransfer<4> BwdTransfer_4x64x1{ .a = { - .thread_distribution = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1}, + .thread_cluster = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1}, .lds_transfer_params = {.global_memory_vector_load_size = 8, .src_vector_dim = 2, .src_scalar_per_vector = 2, .lds_dst_scalar_per_vector = 4, .is_direct_load = false, .lds_padding = true}, - .thread_distribution_access_order = {0, 3, 1, 2}, + .thread_cluster_access_order = {0, 3, 1, 2}, .src_access_order = {0, 2, 1, 3}, }, .b = { - .thread_distribution = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1}, + .thread_cluster = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1}, .lds_transfer_params = {.global_memory_vector_load_size = 8, .src_vector_dim = 2, .src_scalar_per_vector = 2, .lds_dst_scalar_per_vector = 4, .is_direct_load = false, .lds_padding = true}, - .thread_distribution_access_order = {0, 3, 1, 2}, + .thread_cluster_access_order = {0, 3, 1, 2}, .src_access_order = {0, 2, 1, 3}, }, .c = { - .thread_distribution = + .thread_cluster = {.gemm_m_block_size = 1, .gemm_m_per_block = 32, .gemm_n_block_size = 1, .gemm_n_per_block = 8}, .epilogue = {.m_xdl_per_wave_per_shuffle = 1, .n_per_wave_per_shuffle = 1, @@ -125,31 +125,31 @@ constexpr InputOutputTileTransfer<4> BwdTransfer_4x64x1{ constexpr InputOutputTileTransfer<> BwdTransfer_4x8x1_4x16x1_v3{ .a = { - .thread_distribution = {.k0 = 4, .m_n = 8, .k1 = 1}, + 
.thread_cluster = {.k0 = 4, .m_n = 8, .k1 = 1}, .lds_transfer_params = {.global_memory_vector_load_size = 8, .src_vector_dim = 1, .src_scalar_per_vector = 2, .lds_dst_scalar_per_vector = 2, .is_direct_load = false, .lds_padding = false}, - .thread_distribution_access_order = {2, 0, 1}, + .thread_cluster_access_order = {2, 0, 1}, .src_access_order = {1, 0, 2}, }, .b = { - .thread_distribution = {.k0 = 4, .m_n = 16, .k1 = 1}, + .thread_cluster = {.k0 = 4, .m_n = 16, .k1 = 1}, .lds_transfer_params = {.global_memory_vector_load_size = 8, .src_vector_dim = 1, .src_scalar_per_vector = 2, .lds_dst_scalar_per_vector = 2, .is_direct_load = false, .lds_padding = false}, - .thread_distribution_access_order = {2, 0, 1}, + .thread_cluster_access_order = {2, 0, 1}, .src_access_order = {1, 0, 2}, }, .c = { - .thread_distribution = + .thread_cluster = {.gemm_m_block_size = 1, .gemm_m_per_block = 8, .gemm_n_block_size = 1, .gemm_n_per_block = 8}, .epilogue = {.m_xdl_per_wave_per_shuffle = 1, .n_per_wave_per_shuffle = 1, @@ -160,31 +160,31 @@ constexpr InputOutputTileTransfer<> BwdTransfer_4x8x1_4x16x1_v3{ constexpr InputOutputTileTransfer<> Transfer_4x64x1_fp8{ .a = { - .thread_distribution = {.k0 = 4, .m_n = 64, .k1 = 1}, + .thread_cluster = {.k0 = 4, .m_n = 64, .k1 = 1}, .lds_transfer_params = {.global_memory_vector_load_size = 8, .src_vector_dim = 2, .src_scalar_per_vector = 8, .lds_dst_scalar_per_vector = 8, .is_direct_load = false, .lds_padding = true}, - .thread_distribution_access_order = {1, 0, 2}, + .thread_cluster_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .b = { - .thread_distribution = {.k0 = 4, .m_n = 64, .k1 = 1}, + .thread_cluster = {.k0 = 4, .m_n = 64, .k1 = 1}, .lds_transfer_params = {.global_memory_vector_load_size = 8, .src_vector_dim = 2, .src_scalar_per_vector = 8, .lds_dst_scalar_per_vector = 8, .is_direct_load = false, .lds_padding = true}, - .thread_distribution_access_order = {1, 0, 2}, + .thread_cluster_access_order = {1, 0, 2}, 
.src_access_order = {1, 0, 2}, }, .c = { - .thread_distribution = + .thread_cluster = {.gemm_m_block_size = 1, .gemm_m_per_block = 32, .gemm_n_block_size = 1, .gemm_n_per_block = 8}, .epilogue = {.m_xdl_per_wave_per_shuffle = 1, .n_per_wave_per_shuffle = 1, @@ -195,31 +195,31 @@ constexpr InputOutputTileTransfer<> Transfer_4x64x1_fp8{ constexpr InputOutputTileTransfer<> Transfer_4x16x1{ .a = { - .thread_distribution = {.k0 = 4, .m_n = 16, .k1 = 1}, + .thread_cluster = {.k0 = 4, .m_n = 16, .k1 = 1}, .lds_transfer_params = {.global_memory_vector_load_size = 8, .src_vector_dim = 2, .src_scalar_per_vector = 8, .lds_dst_scalar_per_vector = 8, .is_direct_load = false, .lds_padding = true}, - .thread_distribution_access_order = {1, 0, 2}, + .thread_cluster_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .b = { - .thread_distribution = {.k0 = 4, .m_n = 16, .k1 = 1}, + .thread_cluster = {.k0 = 4, .m_n = 16, .k1 = 1}, .lds_transfer_params = {.global_memory_vector_load_size = 8, .src_vector_dim = 2, .src_scalar_per_vector = 8, .lds_dst_scalar_per_vector = 8, .is_direct_load = false, .lds_padding = true}, - .thread_distribution_access_order = {1, 0, 2}, + .thread_cluster_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .c = { - .thread_distribution = + .thread_cluster = {.gemm_m_block_size = 1, .gemm_m_per_block = 16, .gemm_n_block_size = 1, .gemm_n_per_block = 4}, .epilogue = {.m_xdl_per_wave_per_shuffle = 1, .n_per_wave_per_shuffle = 1, @@ -231,31 +231,31 @@ constexpr InputOutputTileTransfer<> Transfer_4x16x1{ constexpr InputOutputTileTransfer<> Transfer_4x32x1{ .a = { - .thread_distribution = {.k0 = 4, .m_n = 32, .k1 = 1}, + .thread_cluster = {.k0 = 4, .m_n = 32, .k1 = 1}, .lds_transfer_params = {.global_memory_vector_load_size = 8, .src_vector_dim = 2, .src_scalar_per_vector = 16, .lds_dst_scalar_per_vector = 16, .is_direct_load = false, .lds_padding = true}, - .thread_distribution_access_order = {1, 0, 2}, + .thread_cluster_access_order = {1, 0, 2}, 
.src_access_order = {1, 0, 2}, }, .b = { - .thread_distribution = {.k0 = 4, .m_n = 32, .k1 = 1}, + .thread_cluster = {.k0 = 4, .m_n = 32, .k1 = 1}, .lds_transfer_params = {.global_memory_vector_load_size = 8, .src_vector_dim = 2, .src_scalar_per_vector = 16, .lds_dst_scalar_per_vector = 16, .is_direct_load = false, .lds_padding = true}, - .thread_distribution_access_order = {1, 0, 2}, + .thread_cluster_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .c = { - .thread_distribution = + .thread_cluster = {.gemm_m_block_size = 1, .gemm_m_per_block = 32, .gemm_n_block_size = 1, .gemm_n_per_block = 4}, .epilogue = {.m_xdl_per_wave_per_shuffle = 1, .n_per_wave_per_shuffle = 1, @@ -266,31 +266,31 @@ constexpr InputOutputTileTransfer<> Transfer_4x32x1{ constexpr InputOutputTileTransfer<> Transfer_4x32x1_vector_load_16_generic{ .a = { - .thread_distribution = {.k0 = 4, .m_n = 32, .k1 = 1}, + .thread_cluster = {.k0 = 4, .m_n = 32, .k1 = 1}, .lds_transfer_params = {.global_memory_vector_load_size = 16, .src_vector_dim = 2, .src_scalar_per_vector = 1, .lds_dst_scalar_per_vector = 16, .is_direct_load = false, .lds_padding = true}, - .thread_distribution_access_order = {1, 0, 2}, + .thread_cluster_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .b = { - .thread_distribution = {.k0 = 4, .m_n = 32, .k1 = 1}, + .thread_cluster = {.k0 = 4, .m_n = 32, .k1 = 1}, .lds_transfer_params = {.global_memory_vector_load_size = 16, .src_vector_dim = 2, .src_scalar_per_vector = 1, .lds_dst_scalar_per_vector = 16, .is_direct_load = false, .lds_padding = true}, - .thread_distribution_access_order = {1, 0, 2}, + .thread_cluster_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .c = { - .thread_distribution = + .thread_cluster = {.gemm_m_block_size = 1, .gemm_m_per_block = 32, .gemm_n_block_size = 1, .gemm_n_per_block = 4}, .epilogue = {.m_xdl_per_wave_per_shuffle = 1, .n_per_wave_per_shuffle = 1, diff --git 
a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index 2895da8d380..a79081529c3 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -103,7 +103,7 @@ inline std::string to_string(GemmPipeline t) } template -inline std::string to_string(InputDataThreadDistribution t) +inline std::string to_string(InputThreadCluster t) { if constexpr(ThreadClusterRank == 4) { @@ -121,7 +121,7 @@ inline std::string to_string(InputDataThreadDistribution t) } template <> -inline std::string to_string(OutputDataThreadDistribution t) +inline std::string to_string(OutputThreadCluster t) { return array_to_seq( std::array{t.gemm_m_block_size, t.gemm_m_per_block, t.gemm_n_block_size, t.gemm_n_per_block}); @@ -147,7 +147,7 @@ template inline std::string to_string(InputTileTransfer t) { std::ostringstream oss; - oss << to_string(t.thread_distribution) << "," << to_string(t.thread_distribution_access_order) << "," + oss << to_string(t.thread_cluster) << "," << to_string(t.thread_cluster_access_order) << "," << to_string(t.src_access_order) << "," << t.lds_transfer_params.src_vector_dim << "," << t.lds_transfer_params.src_scalar_per_vector << "," << t.lds_transfer_params.lds_dst_scalar_per_vector << "," << (t.lds_transfer_params.lds_padding ? "true" : "false"); @@ -159,7 +159,7 @@ inline std::string to_string(OutputTileTransfer t) { std::ostringstream oss; oss << t.epilogue.m_xdl_per_wave_per_shuffle << "," << t.epilogue.n_per_wave_per_shuffle << "," - << to_string(t.thread_distribution) << "," << t.epilogue.scalar_per_vector; + << to_string(t.thread_cluster) << "," << t.epilogue.scalar_per_vector; return oss.str(); } From 07608d1a86cca9b67605ba8cd5eb15966439fd1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Wed, 14 Jan 2026 06:15:14 -0500 Subject: [PATCH 75/81] Rename LDS transfer related assets. 
--- .../builder/conv_algorithm_concepts.hpp | 6 ++-- .../builder/factory/conv_fwd_v3_factory.hpp | 6 ++-- .../helpers/ck/conv_block_transfer.hpp | 4 +-- .../test/impl/conv_algorithm_types.hpp | 6 ++-- .../builder/test/test_conv_description.cpp | 4 +-- .../test/utils/ckb_conv_test_configs.hpp | 28 +++++++-------- .../test/utils/conv_algorithm_type_utils.hpp | 36 +++++++++---------- 7 files changed, 45 insertions(+), 45 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 794fc3aef7d..fcbcd62b27a 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -86,7 +86,7 @@ concept OutputTileThreadClusterDescriptor = requires(T t) { // Concept for the LDS transfer for the convolution input tensors. template -concept LdsInputTransferDescriptor = requires(T t) { +concept LdsTransferDescriptor = requires(T t) { { t.global_memory_vector_load_size } -> SizeType; { t.src_vector_dim } -> SizeType; { t.src_scalar_per_vector } -> SizeType; @@ -199,8 +199,8 @@ concept SpecifiesTileTransfer = requires(T t) { // Concept to check if a struct specifies LDS transfer info for tensors A, B, and C. 
template concept SpecifiesLdsTransfer = requires(T t) { - { T::transfer.a.lds_transfer_params } -> LdsInputTransferDescriptor; - { T::transfer.b.lds_transfer_params } -> LdsInputTransferDescriptor; + { T::transfer.a.lds_transfer } -> LdsTransferDescriptor; + { T::transfer.b.lds_transfer } -> LdsTransferDescriptor; { T::transfer.c.epilogue } -> EpilogueDescriptor; }; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp index 7e403a7b297..e6f71b0b258 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp @@ -31,11 +31,11 @@ struct ConvFwdXdlV3Factory using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); - static_assert(ALGORITHM.transfer.a.lds_transfer_params.is_direct_load == - ALGORITHM.transfer.b.lds_transfer_params.is_direct_load, + static_assert(ALGORITHM.transfer.a.lds_transfer.is_direct_load == + ALGORITHM.transfer.b.lds_transfer.is_direct_load, "A and B block transfers must both be direct load or not."); - static constexpr bool IS_DIRECT_LOAD = ALGORITHM.transfer.a.lds_transfer_params.is_direct_load; + static constexpr bool IS_DIRECT_LOAD = ALGORITHM.transfer.a.lds_transfer.is_direct_load; static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization(); static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization(); static constexpr internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION, diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp index 7e811633c08..c48404dd470 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp +++ 
b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp @@ -30,7 +30,7 @@ constexpr BlockTransfer<> SetFwdConvBlockTransfer() auto& block_xfer = TRANSFER.thread_cluster; auto& block_order = TRANSFER.thread_cluster_access_order; auto& src_order = TRANSFER.src_access_order; - auto& lds_cfg = TRANSFER.lds_transfer_params; + auto& lds_cfg = TRANSFER.lds_transfer; return BlockTransfer<>{ .thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1}, @@ -51,7 +51,7 @@ constexpr auto SetBwdConvBlockTransfer() auto& block_xfer = TRANSFER.thread_cluster; auto& block_order = TRANSFER.thread_cluster_access_order; auto& src_order = TRANSFER.src_access_order; - auto& lds_cfg = TRANSFER.lds_transfer_params; + auto& lds_cfg = TRANSFER.lds_transfer; constexpr auto array_length = block_order.order.size(); static_assert(block_order.order.size() == src_order.order.size(), diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index fc8b4770bce..8d315f69b8c 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -78,7 +78,7 @@ struct OutputThreadCluster }; static_assert(OutputTileThreadClusterDescriptor); -struct LdsInputTransferParams +struct LdsTransfer { size_t global_memory_vector_load_size; size_t src_vector_dim; @@ -87,7 +87,7 @@ struct LdsInputTransferParams bool is_direct_load; bool lds_padding; }; -static_assert(LdsInputTransferDescriptor); +static_assert(LdsTransferDescriptor); struct Epilogue { @@ -109,7 +109,7 @@ template struct InputTileTransfer { InputThreadCluster thread_cluster; - LdsInputTransferParams lds_transfer_params; + LdsTransfer lds_transfer; AccessOrder thread_cluster_access_order; AccessOrder src_access_order; }; diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index 4d5cb1f2839..760ada78eae 100644 
--- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -132,7 +132,7 @@ struct DefaultAlgorithm .a = { .thread_cluster = {.k0 = 1, .m_n = 128, .k1 = 2}, - .lds_transfer_params = {.global_memory_vector_load_size = 8, + .lds_transfer = {.global_memory_vector_load_size = 8, .src_vector_dim = 2, .src_scalar_per_vector = 2, .lds_dst_scalar_per_vector = 2, @@ -145,7 +145,7 @@ struct DefaultAlgorithm .b = { .thread_cluster = {.k0 = 1, .m_n = 128, .k1 = 2}, - .lds_transfer_params = {.global_memory_vector_load_size = 8, + .lds_transfer = {.global_memory_vector_load_size = 8, .src_vector_dim = 2, .src_scalar_per_vector = 2, .lds_dst_scalar_per_vector = 2, diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index b58e9379476..a281ca28ad6 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -56,7 +56,7 @@ constexpr InputOutputTileTransfer<> Transfer_4x64x1{ .a = { .thread_cluster = {.k0 = 4, .m_n = 64, .k1 = 1}, - .lds_transfer_params = {.global_memory_vector_load_size = 8, + .lds_transfer = {.global_memory_vector_load_size = 8, .src_vector_dim = 2, .src_scalar_per_vector = 2, .lds_dst_scalar_per_vector = 8, @@ -68,7 +68,7 @@ constexpr InputOutputTileTransfer<> Transfer_4x64x1{ .b = { .thread_cluster = {.k0 = 4, .m_n = 64, .k1 = 1}, - .lds_transfer_params = {.global_memory_vector_load_size = 8, + .lds_transfer = {.global_memory_vector_load_size = 8, .src_vector_dim = 2, .src_scalar_per_vector = 8, .lds_dst_scalar_per_vector = 8, @@ -91,7 +91,7 @@ constexpr InputOutputTileTransfer<4> BwdTransfer_4x64x1{ .a = { .thread_cluster = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1}, - .lds_transfer_params = {.global_memory_vector_load_size = 8, + .lds_transfer = {.global_memory_vector_load_size = 8, .src_vector_dim = 2, .src_scalar_per_vector = 2, 
.lds_dst_scalar_per_vector = 4, @@ -103,7 +103,7 @@ constexpr InputOutputTileTransfer<4> BwdTransfer_4x64x1{ .b = { .thread_cluster = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1}, - .lds_transfer_params = {.global_memory_vector_load_size = 8, + .lds_transfer = {.global_memory_vector_load_size = 8, .src_vector_dim = 2, .src_scalar_per_vector = 2, .lds_dst_scalar_per_vector = 4, @@ -126,7 +126,7 @@ constexpr InputOutputTileTransfer<> BwdTransfer_4x8x1_4x16x1_v3{ .a = { .thread_cluster = {.k0 = 4, .m_n = 8, .k1 = 1}, - .lds_transfer_params = {.global_memory_vector_load_size = 8, + .lds_transfer = {.global_memory_vector_load_size = 8, .src_vector_dim = 1, .src_scalar_per_vector = 2, .lds_dst_scalar_per_vector = 2, @@ -138,7 +138,7 @@ constexpr InputOutputTileTransfer<> BwdTransfer_4x8x1_4x16x1_v3{ .b = { .thread_cluster = {.k0 = 4, .m_n = 16, .k1 = 1}, - .lds_transfer_params = {.global_memory_vector_load_size = 8, + .lds_transfer = {.global_memory_vector_load_size = 8, .src_vector_dim = 1, .src_scalar_per_vector = 2, .lds_dst_scalar_per_vector = 2, @@ -161,7 +161,7 @@ constexpr InputOutputTileTransfer<> Transfer_4x64x1_fp8{ .a = { .thread_cluster = {.k0 = 4, .m_n = 64, .k1 = 1}, - .lds_transfer_params = {.global_memory_vector_load_size = 8, + .lds_transfer = {.global_memory_vector_load_size = 8, .src_vector_dim = 2, .src_scalar_per_vector = 8, .lds_dst_scalar_per_vector = 8, @@ -173,7 +173,7 @@ constexpr InputOutputTileTransfer<> Transfer_4x64x1_fp8{ .b = { .thread_cluster = {.k0 = 4, .m_n = 64, .k1 = 1}, - .lds_transfer_params = {.global_memory_vector_load_size = 8, + .lds_transfer = {.global_memory_vector_load_size = 8, .src_vector_dim = 2, .src_scalar_per_vector = 8, .lds_dst_scalar_per_vector = 8, @@ -196,7 +196,7 @@ constexpr InputOutputTileTransfer<> Transfer_4x16x1{ .a = { .thread_cluster = {.k0 = 4, .m_n = 16, .k1 = 1}, - .lds_transfer_params = {.global_memory_vector_load_size = 8, + .lds_transfer = {.global_memory_vector_load_size = 8, .src_vector_dim = 2, 
.src_scalar_per_vector = 8, .lds_dst_scalar_per_vector = 8, @@ -208,7 +208,7 @@ constexpr InputOutputTileTransfer<> Transfer_4x16x1{ .b = { .thread_cluster = {.k0 = 4, .m_n = 16, .k1 = 1}, - .lds_transfer_params = {.global_memory_vector_load_size = 8, + .lds_transfer = {.global_memory_vector_load_size = 8, .src_vector_dim = 2, .src_scalar_per_vector = 8, .lds_dst_scalar_per_vector = 8, @@ -232,7 +232,7 @@ constexpr InputOutputTileTransfer<> Transfer_4x32x1{ .a = { .thread_cluster = {.k0 = 4, .m_n = 32, .k1 = 1}, - .lds_transfer_params = {.global_memory_vector_load_size = 8, + .lds_transfer = {.global_memory_vector_load_size = 8, .src_vector_dim = 2, .src_scalar_per_vector = 16, .lds_dst_scalar_per_vector = 16, @@ -244,7 +244,7 @@ constexpr InputOutputTileTransfer<> Transfer_4x32x1{ .b = { .thread_cluster = {.k0 = 4, .m_n = 32, .k1 = 1}, - .lds_transfer_params = {.global_memory_vector_load_size = 8, + .lds_transfer = {.global_memory_vector_load_size = 8, .src_vector_dim = 2, .src_scalar_per_vector = 16, .lds_dst_scalar_per_vector = 16, @@ -267,7 +267,7 @@ constexpr InputOutputTileTransfer<> Transfer_4x32x1_vector_load_16_generic{ .a = { .thread_cluster = {.k0 = 4, .m_n = 32, .k1 = 1}, - .lds_transfer_params = {.global_memory_vector_load_size = 16, + .lds_transfer = {.global_memory_vector_load_size = 16, .src_vector_dim = 2, .src_scalar_per_vector = 1, .lds_dst_scalar_per_vector = 16, @@ -279,7 +279,7 @@ constexpr InputOutputTileTransfer<> Transfer_4x32x1_vector_load_16_generic{ .b = { .thread_cluster = {.k0 = 4, .m_n = 32, .k1 = 1}, - .lds_transfer_params = {.global_memory_vector_load_size = 16, + .lds_transfer = {.global_memory_vector_load_size = 16, .src_vector_dim = 2, .src_scalar_per_vector = 1, .lds_dst_scalar_per_vector = 16, diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index a79081529c3..db5aa93a7ce 100644 --- 
a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -128,7 +128,7 @@ inline std::string to_string(OutputThreadCluster t) } template <> -inline std::string to_string(LdsInputTransferParams t) +inline std::string to_string(LdsTransfer t) { std::ostringstream oss; oss << t.global_memory_vector_load_size << "," << t.src_vector_dim << "," << t.src_scalar_per_vector << "," << t.lds_dst_scalar_per_vector @@ -148,9 +148,9 @@ inline std::string to_string(InputTileTransfer t) { std::ostringstream oss; oss << to_string(t.thread_cluster) << "," << to_string(t.thread_cluster_access_order) << "," - << to_string(t.src_access_order) << "," << t.lds_transfer_params.src_vector_dim << "," - << t.lds_transfer_params.src_scalar_per_vector << "," << t.lds_transfer_params.lds_dst_scalar_per_vector - << "," << (t.lds_transfer_params.lds_padding ? "true" : "false"); + << to_string(t.src_access_order) << "," << t.lds_transfer.src_vector_dim << "," + << t.lds_transfer.src_scalar_per_vector << "," << t.lds_transfer.lds_dst_scalar_per_vector + << "," << (t.lds_transfer.lds_padding ? 
"true" : "false"); return oss.str(); } @@ -322,15 +322,15 @@ inline std::string to_string(t)) << "," - << t.transfer.a.lds_transfer_params.global_memory_vector_load_size << "," + << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," << to_string(static_cast(t)) << "," << to_string(static_cast>(t)); } else { oss << to_string(static_cast(t)) << "," - << t.transfer.a.lds_transfer_params.global_memory_vector_load_size << "," - << t.transfer.b.lds_transfer_params.global_memory_vector_load_size << "," + << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," + << t.transfer.b.lds_transfer.global_memory_vector_load_size << "," << to_string(static_cast(t)) << "," << to_string(static_cast>(t)); } @@ -343,8 +343,8 @@ inline std::string to_string(t)) << "," - << t.transfer.a.lds_transfer_params.global_memory_vector_load_size << "," - << t.transfer.b.lds_transfer_params.global_memory_vector_load_size << "," + << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," + << t.transfer.b.lds_transfer.global_memory_vector_load_size << "," << to_string(static_cast(t)) << "," << to_string(static_cast>(t)); return oss.str(); @@ -368,8 +368,8 @@ inline std::string to_string(t)) << "," - << t.transfer.a.lds_transfer_params.global_memory_vector_load_size << "," - << t.transfer.b.lds_transfer_params.global_memory_vector_load_size << "," + << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," + << t.transfer.b.lds_transfer.global_memory_vector_load_size << "," << to_string(static_cast(t)) << "," << to_string(static_cast>(t)); return oss.str(); @@ -381,7 +381,7 @@ inline std::string to_string(t)) << "," - << t.transfer.a.lds_transfer_params.global_memory_vector_load_size << "," + << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," << to_string(static_cast(t)) << "," << to_string(static_cast>(t)); return oss.str(); @@ -393,7 +393,7 @@ inline std::string to_string(t)) << "," - << 
t.transfer.a.lds_transfer_params.global_memory_vector_load_size << "," + << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," << to_string(static_cast(t)) << "," << to_string(static_cast>(t)); return oss.str(); @@ -405,7 +405,7 @@ inline std::string to_string(t)) << "," - << t.transfer.a.lds_transfer_params.global_memory_vector_load_size << "," + << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," << to_string(static_cast(t)) << "," << to_string(static_cast>(t)); return oss.str(); @@ -417,7 +417,7 @@ inline std::string to_string(t)) << "," - << t.transfer.a.lds_transfer_params.global_memory_vector_load_size << "," + << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," << to_string(static_cast(t)) << "," << to_string(static_cast>(t)); return oss.str(); @@ -429,7 +429,7 @@ inline std::string to_string(t)) << "," - << t.transfer.a.lds_transfer_params.global_memory_vector_load_size << "," + << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," << to_string(static_cast(t)) << "," << to_string(static_cast>(t)); return oss.str(); @@ -442,7 +442,7 @@ inline std::string to_string(t)) - << "," << t.transfer.a.lds_transfer_params.global_memory_vector_load_size + << "," << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," << to_string(static_cast(t)) << "," << to_string(static_cast>(t)); return oss.str(); @@ -467,7 +467,7 @@ inline std::string to_string(t)) << "," - << t.transfer.a.lds_transfer_params.global_memory_vector_load_size << "," + << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," << to_string(static_cast(t)) << "," << to_string(static_cast>(t)); return oss.str(); From 3f7b250d338ef91e4e400142fc2e445539f83fc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Wed, 14 Jan 2026 06:22:46 -0500 Subject: [PATCH 76/81] Small concepts clean-up. 
--- .../ck_tile/builder/conv_algorithm_concepts.hpp | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index fcbcd62b27a..adec5643062 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -38,23 +38,15 @@ concept WarpGemmDescriptor = requires(T t) { { t.gemm_n_iters_per_wave } -> SizeType; }; -// Concept for parameter that describe block GEMM problem. +// Concept for parameters that describe the GEMM pipeline. template concept GemmPipelineDescriptor = requires(T t) { + { t.num_conv_groups_to_merge } -> SizeType; + { t.num_gemm_k_prefetch_stages } -> SizeType; { t.pipeline_version } -> std::convertible_to; { t.scheduler } -> std::convertible_to; }; -// Concept for parameters that describe a gridwise WMMA GEMM problem. -template -concept GridwiseWmmaGemmDescriptor = requires(T t) { - { t.k1 } -> SizeType; - { t.m_per_wmma } -> SizeType; - { t.n_per_wmma } -> SizeType; - { t.m_wmma_per_wave } -> SizeType; - { t.n_wmma_per_wave } -> SizeType; -}; - // Concept for vectorized data transfer for convolution input tensors. template concept InputTileThreadClusterDescriptor3D = requires(T t) { From 096592eb995f819da20caa1efc7a874e55b34220 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Wed, 14 Jan 2026 08:29:49 -0500 Subject: [PATCH 77/81] Refactor conv algorithms into more categorized form. 
--- .../include/ck_tile/builder/factory/README.md | 228 ++++++++++++++++++ .../builder/factory/conv_algorithms.hpp | 87 +++---- .../test/impl/conv_algorithm_types.hpp | 4 +- 3 files changed, 267 insertions(+), 52 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/factory/README.md b/experimental/builder/include/ck_tile/builder/factory/README.md index d1794349ab5..4c30348f9a7 100644 --- a/experimental/builder/include/ck_tile/builder/factory/README.md +++ b/experimental/builder/include/ck_tile/builder/factory/README.md @@ -29,3 +29,231 @@ using Factory = decltype(make_conv_instance()); ``` The dispatcher automatically selects the appropriate factory following explicit logic. + +# Convolution Algorithm Hierarchy + +This section illustrates the hierarchy of convolution algorithm concepts defined in `conv_algorithms.hpp`. + +## Overview + +The convolution algorithms are organized into three main categories: + +1. **XDL Algorithms** - GPU matrix multiplication using XDL (matrix core instructions) +2. **WMMA Algorithms** - GPU matrix multiplication using WMMA (Wave Matrix Multiply-Accumulate) +3. **DL Algorithms** - Special vectorized dot-product kernels optimized for specific data layouts, with a separate implementation. + +XDL and WMMA algorithms share a common base, while DL algorithms have their own independent base.
+ +## Common Base Hierarchy (XDL & WMMA) + +Both XDL and WMMA algorithms share the following foundational concepts: + +``` +ConvAlgorithm (Base Concept) +│ +│ Requirements: +│ • ConvAlgorithmDescriptor +│ • SpecifiesThreadBlock +│ • SpecifiesTileTransferParameters (ThreadClusters, LdsTransfer, AccessOrders) +│ • SpecifiesWarpGemm +│ • SpecifiesGemmPipeline +│ +├─── FwdAlgorithm (Forward Convolution) +│ │ +│ │ Additional: SpecifiesFwdConvSpecialization +│ │ +│ └─── FwdAlgorithmV3 +│ │ +│ │ Additional: SpecifiesPipelineV3 +│ │ +│ +└─── BwdAlgorithm (Backward Weight Convolution) + │ + │ Additional: SpecifiesBwdWeightConvSpecialization + │ + └─── BwdAlgorithmV3 + │ + │ Additional: SpecifiesPipelineV3 + │ +``` + +--- + +## XDL Algorithm Hierarchy + +### Forward XDL Algorithms + +``` +FwdAlgorithm + SpecifiesXdl +│ +├─── FwdXdlAlgorithmBase + │ + ├─── FwdXdlAlgorithm + │ │ + │ └─ Requirements: Base + SpecifiesGenericInstance + │ + ├─── LargeTensorAlgorithm + │ │ + │ └─ Requirements: Base + SpecifiesLargeTensorSupport + │ + └─── FwdXdlV3Algorithm + │ + └─ Based on: FwdAlgorithmV3 + SpecifiesXdl +``` + +### Backward XDL Algorithms + +``` +BwdAlgorithm + SpecifiesXdl +│ +├─── BwdXdlAlgorithmBase (ThreadClusterRank=4) +│ │ +│ ├─── BwdXdlAlgorithm +│ │ │ +│ │ └─ Requirements: Base + SpecifiesTransposeTransfer + SpecifiesGenericInstance +│ │ +│ └─── BwdMultiDXdlAlgorithm +│ │ +│ └─ Requirements: Base + SpecifiesMultipleDSupport +│ +└─── BwdXdlV3AlgorithmBase + │ + ├─── BwdXdlV3Algorithm + │ │ + │ └─ Requirements: Base + │ + └─── BwdTwoStageXdlAlgorithm + │ + └─ Requirements: Base + SpecifiesTransposeTransfer + SpecifiesTwoStageSupport +``` + +**Valid XDL Algorithms:** +- FwdXdlAlgorithm +- FwdXdlV3Algorithm +- LargeTensorAlgorithm +- BwdXdlAlgorithm +- BwdXdlV3Algorithm +- BwdTwoStageXdlAlgorithm +- BwdMultiDXdlAlgorithm + +--- + +## WMMA Algorithm Hierarchy + +### Forward WMMA Algorithms + +``` +FwdAlgorithm + SpecifiesWmma +│ +└─── FwdWmmaAlgorithm + │ + └─ Requirements: 
Base + SpecifiesWmma +``` + +### Backward WMMA Algorithms + +``` +BwdAlgorithm + SpecifiesWmma +│ +├─── BwdWmmaAlgorithmBase (ThreadClusterRank=3) +│ │ +│ └─── BwdWmmaAlgorithm +│ │ +│ └─ Requirements: Base + SpecifiesNumPrefetchStages + SpecifiesGemmPipeline + SpecifiesGenericInstance +│ +└─── BwdWmmaV3AlgorithmBase (Based on BwdAlgorithmV3) + │ + ├─── BwdMultiDWmmaV3Algorithm + │ │ + │ └─ Requirements: Base + SpecifiesMultipleDSupport + │ + ├─── BwdWmmaV3Algorithm + │ │ + │ └─ Requirements: Base + SpecifiesTransposeTransfer + SpecifiesGenericInstance + │ + └─── BwdTwoStageWmmaV3Algorithm + │ + └─ Requirements: Base + SpecifiesTransposeTransfer + SpecifiesTwoStageSupport +``` + +**Valid WMMA Algorithms:** +- FwdWmmaAlgorithm +- BwdWmmaAlgorithm +- BwdWmmaV3Algorithm +- BwdTwoStageWmmaV3Algorithm +- BwdMultiDWmmaV3Algorithm + +--- + +## DL Algorithm Hierarchy + +DL algorithms have a separate base and do not share the common hierarchy with XDL/WMMA algorithms. + +``` +DlAlgorithm +│ +│ Requirements: +│ • ConvAlgorithmDescriptor +│ • SpecifiesThreadBlock +│ • SpecifiesDlThreadConfig +│ • SpecifiesDlThreadCluster +│ • SpecifiesDlEpilogue +│ +├─── FwdDlAlgorithmBase +│ │ +│ │ Requirements: Base + SpecifiesFwdConvSpecialization + SpecifiesDlFwdBlockTransfer + SpecifiesGemmSpecialization +│ │ +│ └─── FwdDlAlgorithm +│ +└─── BwdDlAlgorithm + │ + └─ Requirements: Base + SpecifiesBwdWeightConvSpecialization + SpecifiesDlBwdBlockTransfer +``` + +**Valid DL Algorithms:** +- FwdDlAlgorithm +- BwdDlAlgorithm + +--- + + +## Reference Algorithms + +``` +ReferenceAlgorithm +│ +└─ Requirements: ConvAlgorithmDescriptor + + SpecifiesReferenceAlgorithm +``` + +Used for reference implementations and testing. 
+ +## CK Tile Algorithms + +``` +TileAlgorithm +│ +└─ Requirements: ConvAlgorithmDescriptor + + SpecifiesTileThreadBlock + + SpecifiesTileTransfer + + SpecifiesTileConvSpecialization + + SpecifiesTileBlockGemm + + SpecifiesTileOptimizations +``` + +The CK Tile algorithms are applicable to forward convolution as well as backward convolution (weight and data). + +--- + +## Summary for XDL/WMMA/DL algorithms + +| Category | Algorithm Type | Forward Variants | Backward Variants | +|----------|---------------|------------------|-------------------| +| **XDL** | Base | FwdXdlAlgorithmBase | BwdXdlAlgorithmBase, BwdXdlV3AlgorithmBase | +| | Concrete | • FwdXdlAlgorithm
• FwdXdlV3Algorithm
• LargeTensorAlgorithm | • BwdXdlAlgorithm
• BwdXdlV3Algorithm
• BwdTwoStageXdlAlgorithm
• BwdMultiDXdlAlgorithm | +| **WMMA** | Base | FwdAlgorithm | BwdWmmaAlgorithmBase, BwdWmmaV3AlgorithmBase | +| | Concrete | • FwdWmmaAlgorithm | • BwdWmmaAlgorithm
• BwdWmmaV3Algorithm
• BwdTwoStageWmmaV3Algorithm
• BwdMultiDWmmaV3Algorithm | +| **DL** | Base | FwdDlAlgorithmBase | DlAlgorithm | +| | Concrete | • FwdDlAlgorithm | • BwdDlAlgorithm | + +--- diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index 3b2f7f8c511..5c169d96a7e 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -7,52 +7,53 @@ namespace ck_tile::builder::factory { -// Base algorithm concepts template -concept TileTransferParameters = +concept SpecifiesTileTransferParameters = SpecifiesThreadClusters && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder; +// Base algorithm concepts +template +concept ConvAlgorithm = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && + SpecifiesTileTransferParameters && + SpecifiesWarpGemm && SpecifiesGemmPipeline;; + +template +concept FwdAlgorithm = ConvAlgorithm && SpecifiesFwdConvSpecialization; + +template +concept FwdAlgorithmV3 = FwdAlgorithm && SpecifiesPipelineV3; + +template +concept BwdAlgorithm = ConvAlgorithm && SpecifiesBwdWeightConvSpecialization; + template -concept SpecifiesTileTransferParameters3D = TileTransferParameters; +concept BwdAlgorithmV3 = BwdAlgorithm && SpecifiesPipelineV3; template -concept SpecifiesTileTransferParameters4D = TileTransferParameters; +concept DlAlgorithm = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && + SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && SpecifiesDlEpilogue; + +template +concept FwdDlAlgorithmBase = DlAlgorithm && SpecifiesFwdConvSpecialization && + SpecifiesDlFwdBlockTransfer && SpecifiesGemmSpecialization; template -concept FwdXdlAlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && - SpecifiesTileTransferParameters3D && - SpecifiesWarpGemm && SpecifiesFwdConvSpecialization && - SpecifiesGemmPipeline && SpecifiesXdl; 
+concept FwdXdlAlgorithmBase = FwdAlgorithm && SpecifiesXdl; template -concept BwdXdlAlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && - SpecifiesTileTransferParameters4D && - SpecifiesWarpGemm && SpecifiesBwdWeightConvSpecialization && - SpecifiesXdl; +concept BwdXdlAlgorithmBase = BwdAlgorithm && SpecifiesXdl; template -concept BwdXdlV3AlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && - SpecifiesTileTransferParameters3D && - SpecifiesWarpGemm && SpecifiesBwdWeightConvSpecialization && - SpecifiesGemmPipeline && SpecifiesXdl && SpecifiesPipelineV3; +concept BwdXdlV3AlgorithmBase = BwdAlgorithmV3 && SpecifiesXdl; template -concept BwdWmmaAlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && - SpecifiesTileTransferParameters3D && - SpecifiesWarpGemm && SpecifiesBwdWeightConvSpecialization && - SpecifiesWmma; +concept BwdWmmaAlgorithmBase = BwdAlgorithm && SpecifiesWmma; template -concept BwdWmmaV3AlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && - SpecifiesTileTransferParameters3D && - SpecifiesWarpGemm && SpecifiesBwdWeightConvSpecialization && - SpecifiesGemmPipeline && SpecifiesWmma && SpecifiesPipelineV3; +concept BwdWmmaV3AlgorithmBase = BwdAlgorithmV3 && SpecifiesWmma; // Reference algorithm concept template @@ -72,28 +73,16 @@ template concept LargeTensorAlgorithm = FwdXdlAlgorithmBase && SpecifiesLargeTensorSupport; template -concept FwdXdlV3Algorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && - SpecifiesTileTransferParameters3D && - SpecifiesWarpGemm && SpecifiesFwdConvSpecialization && - SpecifiesGemmPipeline && SpecifiesXdl && SpecifiesPipelineV3; +concept FwdXdlV3Algorithm = FwdAlgorithmV3 && SpecifiesXdl; // FWD WMMA algorithm concepts template -concept FwdWmmaAlgorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && - SpecifiesTileTransferParameters3D && - SpecifiesWarpGemm && SpecifiesFwdConvSpecialization && - SpecifiesGemmPipeline && SpecifiesWmma; +concept 
FwdWmmaAlgorithm = FwdAlgorithm && SpecifiesWmma; // FWD DL algorithms template -concept FwdDlAlgorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && - SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && - SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && - SpecifiesDlFwdBlockTransfer && SpecifiesDlEpilogue; - +concept FwdDlAlgorithm = FwdDlAlgorithmBase; + // BWD weight XDL algorithm concepts template concept BwdXdlAlgorithm = @@ -130,11 +119,8 @@ concept BwdTwoStageWmmaV3Algorithm = BwdWmmaV3AlgorithmBase && SpecifiesT // BWD weight DL algorithms template -concept BwdDlAlgorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && - SpecifiesBwdWeightConvSpecialization && SpecifiesDlThreadConfig && - SpecifiesDlThreadCluster && SpecifiesDlBwdBlockTransfer && - SpecifiesDlEpilogue; +concept BwdDlAlgorithm = DlAlgorithm && SpecifiesBwdWeightConvSpecialization && + SpecifiesDlBwdBlockTransfer; // Concepts for valid XDL/WMMA algorithms template @@ -154,7 +140,6 @@ concept SpecifiesValidBwdWmmaAlgorithm = BwdWmmaAlgorithm || BwdWmmaV3Algorithm || BwdTwoStageWmmaV3Algorithm || BwdMultiDWmmaV3Algorithm; - template concept FwdWarpGemmOrDL = SpecifiesValidWarpGemm || FwdDlAlgorithm; diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 8d315f69b8c..264f69ad811 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -42,7 +42,7 @@ struct GemmPipeline { size_t num_gemm_k_prefetch_stages{1}; size_t num_conv_groups_to_merge{1}; - PipelineVersion pipeline_version; + PipelineVersion pipeline_version{PipelineVersion::V1}; PipelineScheduler scheduler{PipelineScheduler::DEFAULT}; }; static_assert(ckb::GemmPipelineDescriptor); @@ -543,6 +543,7 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle = WarpGemm_, InputOutputTileTransfer_<4>, ConvSpecializationBwdWeight_, + GemmPipeline_, // Not 
needed, but we need this to meet the ConvAlgorithm concept. TransposeParams_, AlgorithmSpecialization_<>>; @@ -592,6 +593,7 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle = ConvAlgorithmTemplate, + GemmPipeline_, // Not needed, but we need this to meet the ConvAlgorithm concept. ConvSpecializationBwdWeight_, AlgorithmSpecialization_>; From 75d20e08f115e6bcea4a70e0d5fd393ef94f07d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Wed, 14 Jan 2026 09:29:04 -0500 Subject: [PATCH 78/81] clang-format --- .../builder/conv_algorithm_concepts.hpp | 16 +- .../builder/factory/conv_algorithms.hpp | 81 +++-- ...onv_bwd_weight_multi_d_wmma_v3_factory.hpp | 2 +- .../conv_bwd_weight_multi_d_xdl_factory.hpp | 4 +- ...v_bwd_weight_two_stage_wmma_v3_factory.hpp | 2 +- .../conv_bwd_weight_two_stage_xdl_factory.hpp | 4 +- .../factory/conv_bwd_weight_wmma_factory.hpp | 2 +- .../conv_bwd_weight_wmma_v3_factory.hpp | 2 +- .../factory/conv_bwd_weight_xdl_factory.hpp | 5 +- .../conv_bwd_weight_xdl_v3_factory.hpp | 4 +- .../builder/factory/conv_dispatcher.hpp | 67 ++-- .../builder/factory/conv_fwd_v3_factory.hpp | 4 +- .../helpers/ck/conv_block_transfer.hpp | 76 ++-- .../factory/helpers/ck/conv_tuning_params.hpp | 9 +- .../builder/include/ck_tile/builder/types.hpp | 10 +- ...test_ckb_conv_bwd_weight_wmma_cshuffle.cpp | 13 +- .../conv/ck/test_ckb_conv_fwd_2d_bf16.cpp | 15 +- .../test/impl/conv_algorithm_types.hpp | 28 +- .../builder/test/test_conv_description.cpp | 53 +-- .../test/utils/ckb_conv_test_configs.hpp | 344 ++++++++++-------- .../test/utils/conv_algorithm_type_utils.hpp | 83 +++-- 21 files changed, 439 insertions(+), 385 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index adec5643062..ce60bf4a94a 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ 
b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -64,8 +64,9 @@ concept InputTileThreadClusterDescriptor4D = requires(T t) { }; template -concept InputTileThreadClusterDescriptor = (ThreadClusterRank == 3 && InputTileThreadClusterDescriptor3D) || - (ThreadClusterRank == 4 && InputTileThreadClusterDescriptor4D); +concept InputTileThreadClusterDescriptor = + (ThreadClusterRank == 3 && InputTileThreadClusterDescriptor3D) || + (ThreadClusterRank == 4 && InputTileThreadClusterDescriptor4D); // Concept for thread cluster dimensions for GEMM output tensor. template @@ -285,7 +286,6 @@ template concept TransposeTransferWellDefinedIfProvided = !HasTransposeTransfer || SpecifiesTransposeTransfer; - /******************************************** */ /* Algorithm specialization concepts */ /******************************************** */ @@ -328,14 +328,12 @@ concept SpecifiesGenericInstance = !requires { }; template -concept SpecifiesXdl = requires { - requires Algorithm.warp_gemm.matrix_instruction == MatrixInstructionType::XDL; -}; +concept SpecifiesXdl = + requires { requires Algorithm.warp_gemm.matrix_instruction == MatrixInstructionType::XDL; }; template -concept SpecifiesWmma = requires { - requires Algorithm.warp_gemm.matrix_instruction == MatrixInstructionType::WMMA; -}; +concept SpecifiesWmma = + requires { requires Algorithm.warp_gemm.matrix_instruction == MatrixInstructionType::WMMA; }; template concept SpecifiesValidWarpGemm = SpecifiesXdl || SpecifiesWmma; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index 5c169d96a7e..ad72609f848 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -14,31 +14,32 @@ concept SpecifiesTileTransferParameters = // Base algorithm concepts template -concept ConvAlgorithm = - 
ConvAlgorithmDescriptor && SpecifiesThreadBlock && - SpecifiesTileTransferParameters && - SpecifiesWarpGemm && SpecifiesGemmPipeline;; +concept ConvAlgorithm = ConvAlgorithmDescriptor && SpecifiesThreadBlock && + SpecifiesTileTransferParameters && + SpecifiesWarpGemm && SpecifiesGemmPipeline; +; template -concept FwdAlgorithm = ConvAlgorithm && SpecifiesFwdConvSpecialization; +concept FwdAlgorithm = ConvAlgorithm && SpecifiesFwdConvSpecialization; template concept FwdAlgorithmV3 = FwdAlgorithm && SpecifiesPipelineV3; template -concept BwdAlgorithm = ConvAlgorithm && SpecifiesBwdWeightConvSpecialization; +concept BwdAlgorithm = + ConvAlgorithm && SpecifiesBwdWeightConvSpecialization; template concept BwdAlgorithmV3 = BwdAlgorithm && SpecifiesPipelineV3; template -concept DlAlgorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && - SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && SpecifiesDlEpilogue; +concept DlAlgorithm = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesDlThreadConfig && + SpecifiesDlThreadCluster && SpecifiesDlEpilogue; template -concept FwdDlAlgorithmBase = DlAlgorithm && SpecifiesFwdConvSpecialization && - SpecifiesDlFwdBlockTransfer && SpecifiesGemmSpecialization; +concept FwdDlAlgorithmBase = DlAlgorithm && SpecifiesFwdConvSpecialization && + SpecifiesDlFwdBlockTransfer && SpecifiesGemmSpecialization; template concept FwdXdlAlgorithmBase = FwdAlgorithm && SpecifiesXdl; @@ -57,20 +58,23 @@ concept BwdWmmaV3AlgorithmBase = BwdAlgorithmV3 && SpecifiesWmm // Reference algorithm concept template -concept ReferenceAlgorithm = ConvAlgorithmDescriptor && SpecifiesReferenceAlgorithm; +concept ReferenceAlgorithm = + ConvAlgorithmDescriptor && SpecifiesReferenceAlgorithm; // Tile-based algorithm concept template -concept TileAlgorithm = ConvAlgorithmDescriptor && SpecifiesTileThreadBlock && - SpecifiesTileTransfer && SpecifiesTileConvSpecialization && - SpecifiesTileBlockGemm && SpecifiesTileOptimizations; +concept TileAlgorithm 
= + ConvAlgorithmDescriptor && SpecifiesTileThreadBlock && + SpecifiesTileTransfer && SpecifiesTileConvSpecialization && + SpecifiesTileBlockGemm && SpecifiesTileOptimizations; // FWD XDL algorithm concepts template concept FwdXdlAlgorithm = FwdXdlAlgorithmBase && SpecifiesGenericInstance; template -concept LargeTensorAlgorithm = FwdXdlAlgorithmBase && SpecifiesLargeTensorSupport; +concept LargeTensorAlgorithm = + FwdXdlAlgorithmBase && SpecifiesLargeTensorSupport; template concept FwdXdlV3Algorithm = FwdAlgorithmV3 && SpecifiesXdl; @@ -82,63 +86,68 @@ concept FwdWmmaAlgorithm = FwdAlgorithm && SpecifiesWmma // FWD DL algorithms template concept FwdDlAlgorithm = FwdDlAlgorithmBase; - + // BWD weight XDL algorithm concepts template concept BwdXdlAlgorithm = - BwdXdlAlgorithmBase && SpecifiesTransposeTransfer && + BwdXdlAlgorithmBase && SpecifiesTransposeTransfer && SpecifiesGenericInstance; template -concept BwdMultiDXdlAlgorithm = BwdXdlAlgorithmBase && SpecifiesMultipleDSupport; +concept BwdMultiDXdlAlgorithm = + BwdXdlAlgorithmBase && SpecifiesMultipleDSupport; template concept BwdXdlV3Algorithm = BwdXdlV3AlgorithmBase; template -concept BwdTwoStageXdlAlgorithm = BwdXdlV3AlgorithmBase && SpecifiesTransposeTransfer && - SpecifiesTwoStageSupport; +concept BwdTwoStageXdlAlgorithm = + BwdXdlV3AlgorithmBase && SpecifiesTransposeTransfer && + SpecifiesTwoStageSupport; // BWD weight WMMA algorithm concepts template concept BwdWmmaAlgorithm = - BwdWmmaAlgorithmBase && SpecifiesNumPrefetchStages && + BwdWmmaAlgorithmBase && SpecifiesNumPrefetchStages && SpecifiesGemmPipeline && SpecifiesGenericInstance; template -concept BwdMultiDWmmaV3Algorithm = BwdWmmaV3AlgorithmBase && SpecifiesMultipleDSupport; +concept BwdMultiDWmmaV3Algorithm = + BwdWmmaV3AlgorithmBase && SpecifiesMultipleDSupport; template concept BwdWmmaV3Algorithm = - BwdWmmaV3AlgorithmBase && SpecifiesTransposeTransfer && + BwdWmmaV3AlgorithmBase && SpecifiesTransposeTransfer && SpecifiesGenericInstance; 
template -concept BwdTwoStageWmmaV3Algorithm = BwdWmmaV3AlgorithmBase && SpecifiesTransposeTransfer && - SpecifiesTwoStageSupport; +concept BwdTwoStageWmmaV3Algorithm = + BwdWmmaV3AlgorithmBase && SpecifiesTransposeTransfer && + SpecifiesTwoStageSupport; // BWD weight DL algorithms template -concept BwdDlAlgorithm = DlAlgorithm && SpecifiesBwdWeightConvSpecialization && - SpecifiesDlBwdBlockTransfer; +concept BwdDlAlgorithm = + DlAlgorithm && SpecifiesBwdWeightConvSpecialization && + SpecifiesDlBwdBlockTransfer; // Concepts for valid XDL/WMMA algorithms template -concept SpecifiesValidFwdXdlAlgorithm = -FwdXdlAlgorithm || FwdXdlV3Algorithm || LargeTensorAlgorithm; +concept SpecifiesValidFwdXdlAlgorithm = + FwdXdlAlgorithm || FwdXdlV3Algorithm || LargeTensorAlgorithm; template concept SpecifiesValidFwdWmmaAlgorithm = FwdWmmaAlgorithm; template -concept SpecifiesValidBwdXdlAlgorithm = - BwdXdlAlgorithm || BwdXdlV3Algorithm || - BwdTwoStageXdlAlgorithm || BwdMultiDXdlAlgorithm; +concept SpecifiesValidBwdXdlAlgorithm = + BwdXdlAlgorithm || BwdXdlV3Algorithm || BwdTwoStageXdlAlgorithm || + BwdMultiDXdlAlgorithm; template -concept SpecifiesValidBwdWmmaAlgorithm = - BwdWmmaAlgorithm || BwdWmmaV3Algorithm || - BwdTwoStageWmmaV3Algorithm || BwdMultiDWmmaV3Algorithm; +concept SpecifiesValidBwdWmmaAlgorithm = + BwdWmmaAlgorithm || BwdWmmaV3Algorithm || BwdTwoStageWmmaV3Algorithm || + BwdMultiDWmmaV3Algorithm; template concept FwdWarpGemmOrDL = SpecifiesValidWarpGemm || FwdDlAlgorithm; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp index 8333a784ba7..9f86028c337 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp @@ -34,7 +34,7 @@ struct 
ConvBwdWeightMultiDWmmaV3Factory static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); - static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); static constexpr auto WARP_GEMM = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp index 9f8dae1de8b..ec67a3c6b09 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp @@ -34,8 +34,8 @@ struct ConvBwdWeightMultiDXdlFactory static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); - static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp index 9409f7e8bd3..afdc5e1a17f 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp @@ -34,7 +34,7 @@ struct ConvBwdWeightTwoStageWmmaV3Factory static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); - static constexpr auto BLOCK = 
internal::SetThreadBlockInfo(); + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); static constexpr auto WARP_GEMM = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp index 986b9fd946d..cac6de591a8 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp @@ -34,8 +34,8 @@ struct ConvBwdWeightTwoStageXdlFactory static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); - static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp index a1fc53f7c3e..0fd089b6c08 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp @@ -34,7 +34,7 @@ struct ConvBwdWeightWmmaFactory static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); - static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); static constexpr auto WARP_GEMM = ALGORITHM.warp_gemm; static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION = 
internal::SetGridwiseGemmPipelineVersion(); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp index ba5302f82c9..cb3c87905d8 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp @@ -34,7 +34,7 @@ struct ConvBwdWeightWmmaV3Factory static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); - static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); static constexpr auto WARP_GEMM = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp index 1b0cc36747a..b382b7d3d9e 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp @@ -34,8 +34,8 @@ struct ConvBwdWeightXdlFactory static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); - static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -57,7 +57,6 @@ struct ConvBwdWeightXdlFactory "A nd B block transfer vector load size need to be the same"); static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size; - // The 
forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle< SPATIAL_DIM, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp index f37da289c74..9c9bbd21af9 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp @@ -34,8 +34,8 @@ struct ConvBwdWeightXdlV3Factory static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); - static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 965a231cd1c..cb51cb70dc7 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -117,7 +117,7 @@ constexpr auto make_conv_instance() // Forward direction (supports most algorithm variants) else if constexpr(ConvDirectionIsForward) { - if constexpr (SpecifiesXdl) + if constexpr(SpecifiesXdl) { if constexpr(FwdXdlV3Algorithm) { @@ -129,30 +129,31 @@ constexpr auto make_conv_instance() } else if constexpr(LargeTensorAlgorithm) { - return typename ConvFwdLargeTensorFactory::Instance{}; + return + typename ConvFwdLargeTensorFactory::Instance{}; } else { static_assert( SpecifiesValidFwdXdlAlgorithm, - "No suitable forward convolution XDL kernel 
factory found for the provided ALGORITHM. " + "No suitable forward convolution XDL kernel factory found for the provided " + "ALGORITHM. " "The ALGORITHM must satisfy requirements for one of: XDL V3, generic XDL, " "DL (NHWC layout), or Large Tensor variant."); } } - else if constexpr (SpecifiesWmma) + else if constexpr(SpecifiesWmma) { if constexpr(FwdWmmaAlgorithm) { return typename ConvFwdWmmaFactory::Instance{}; } - else + else { - static_assert(FwdWmmaAlgorithm, - "Did not find matching WMMA factory."); + static_assert(FwdWmmaAlgorithm, "Did not find matching WMMA factory."); } } - else if constexpr (FwdDlAlgorithm) + else if constexpr(FwdDlAlgorithm) { return typename ConvFwdDlFactory::Instance{}; } @@ -173,22 +174,23 @@ constexpr auto make_conv_instance() // Backward weight direction (will expand with more algorithms in the future) else if constexpr(ConvDirectionIsBackwardWeight) { - if constexpr (SpecifiesXdl) + if constexpr(SpecifiesXdl) { // Start from more specialized and end with least specialized. if constexpr(BwdTwoStageXdlAlgorithm) { - return - typename ConvBwdWeightTwoStageXdlFactory::Instance{}; + return typename ConvBwdWeightTwoStageXdlFactory:: + Instance{}; } else if constexpr(BwdMultiDXdlAlgorithm) { - return - typename ConvBwdWeightMultiDXdlFactory::Instance{}; + return typename ConvBwdWeightMultiDXdlFactory:: + Instance{}; } else if constexpr(BwdXdlV3Algorithm) { - return typename ConvBwdWeightXdlV3Factory::Instance{}; + return + typename ConvBwdWeightXdlV3Factory::Instance{}; } else if constexpr(BwdXdlAlgorithm) { @@ -196,46 +198,51 @@ constexpr auto make_conv_instance() } else { - static_assert( - SpecifiesValidBwdXdlAlgorithm, - "No suitable backward weight convolution XDL kernel factory found for the provided ALGORITHM. 
" - "The ALGORITHM must satisfy requirements for one of: Two-Stage XDL, Multi-D XDL, DL, " - "generic XDL, or XDL V3 variant."); + static_assert(SpecifiesValidBwdXdlAlgorithm, + "No suitable backward weight convolution XDL kernel factory found " + "for the provided ALGORITHM. " + "The ALGORITHM must satisfy requirements for one of: Two-Stage XDL, " + "Multi-D XDL, DL, " + "generic XDL, or XDL V3 variant."); } } - else if constexpr (SpecifiesWmma) + else if constexpr(SpecifiesWmma) { // Start from more specialized and end with least specialized. if constexpr(BwdTwoStageWmmaV3Algorithm) { - return typename ConvBwdWeightTwoStageWmmaV3Factory::Instance{}; + return typename ConvBwdWeightTwoStageWmmaV3Factory:: + Instance{}; } else if constexpr(BwdMultiDWmmaV3Algorithm) { - return typename ConvBwdWeightMultiDWmmaV3Factory::Instance{}; + return typename ConvBwdWeightMultiDWmmaV3Factory:: + Instance{}; } else if constexpr(BwdWmmaV3Algorithm) { - return typename ConvBwdWeightWmmaV3Factory::Instance{}; + return + typename ConvBwdWeightWmmaV3Factory::Instance{}; } else if constexpr(BwdWmmaAlgorithm) { return typename ConvBwdWeightWmmaFactory::Instance{}; } - else + else { - static_assert( - SpecifiesValidBwdWmmaAlgorithm, - "No suitable backward weight convolution WMMA kernel factory found for the provided ALGORITHM. " - "The ALGORITHM must satisfy requirements for one of: Two-Stage WMMA V3, Multi-D WMMA V3, " - "WMMA V3, or generic WMMA variant."); + static_assert(SpecifiesValidBwdWmmaAlgorithm, + "No suitable backward weight convolution WMMA kernel factory found " + "for the provided ALGORITHM. 
" + "The ALGORITHM must satisfy requirements for one of: Two-Stage WMMA " + "V3, Multi-D WMMA V3, " + "WMMA V3, or generic WMMA variant."); } } else if constexpr(BwdDlAlgorithm) { return typename ConvBwdWeightDlFactory::Instance{}; } - else + else { static_assert(BwdWarpGemmOrDL, "Backward convolution: Algorithm must specify either DL, XDL or WMMA."); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp index e6f71b0b258..ad6f91b0a4d 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp @@ -41,8 +41,8 @@ struct ConvFwdXdlV3Factory static constexpr internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION, .gemm_spec = GEMM_SPECIALIZATION}; - static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp index c48404dd470..6259e239a9c 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp @@ -17,11 +17,11 @@ struct BlockTransfer ck::Array thread_cluster_order{}; ck::Array src_access_order{}; size_t global_memory_vector_load_size = 0; - size_t src_vector_dim = 0; - size_t src_scalar_per_vector = 0; - size_t lds_dst_scalar_per_vector = 0; - bool is_direct_load = false; - bool lds_padding = false; + size_t 
src_vector_dim = 0; + size_t src_scalar_per_vector = 0; + size_t lds_dst_scalar_per_vector = 0; + bool is_direct_load = false; + bool lds_padding = false; }; template @@ -33,15 +33,15 @@ constexpr BlockTransfer<> SetFwdConvBlockTransfer() auto& lds_cfg = TRANSFER.lds_transfer; return BlockTransfer<>{ - .thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1}, - .thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2]}, - .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]}, + .thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1}, + .thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2]}, + .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]}, .global_memory_vector_load_size = lds_cfg.global_memory_vector_load_size, - .src_vector_dim = lds_cfg.src_vector_dim, - .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, - .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, - .is_direct_load = lds_cfg.is_direct_load, - .lds_padding = lds_cfg.lds_padding, + .src_vector_dim = lds_cfg.src_vector_dim, + .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, + .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, + .is_direct_load = lds_cfg.is_direct_load, + .lds_padding = lds_cfg.lds_padding, }; } @@ -60,38 +60,38 @@ constexpr auto SetBwdConvBlockTransfer() if constexpr(array_length == 3) { return BlockTransfer<3>{ - .thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1}, - .thread_cluster_order = {block_order.order[0], - block_order.order[1], - block_order.order[2]}, - .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]}, + .thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1}, + .thread_cluster_order = {block_order.order[0], + block_order.order[1], + block_order.order[2]}, + .src_access_order = {src_order.order[0], src_order.order[1], 
src_order.order[2]}, .global_memory_vector_load_size = lds_cfg.global_memory_vector_load_size, - .src_vector_dim = lds_cfg.src_vector_dim, - .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, - .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, - .lds_padding = lds_cfg.lds_padding, + .src_vector_dim = lds_cfg.src_vector_dim, + .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, + .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, + .lds_padding = lds_cfg.lds_padding, }; } else if constexpr(array_length == 4) { return BlockTransfer<4>{ - .thread_cluster_dims = {block_xfer.k_batch_size, - block_xfer.k0, - block_xfer.m_n, - block_xfer.k1}, - .thread_cluster_order = {block_order.order[0], - block_order.order[1], - block_order.order[2], - block_order.order[3]}, - .src_access_order = {src_order.order[0], - src_order.order[1], - src_order.order[2], - src_order.order[3]}, + .thread_cluster_dims = {block_xfer.k_batch_size, + block_xfer.k0, + block_xfer.m_n, + block_xfer.k1}, + .thread_cluster_order = {block_order.order[0], + block_order.order[1], + block_order.order[2], + block_order.order[3]}, + .src_access_order = {src_order.order[0], + src_order.order[1], + src_order.order[2], + src_order.order[3]}, .global_memory_vector_load_size = lds_cfg.global_memory_vector_load_size, - .src_vector_dim = lds_cfg.src_vector_dim, - .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, - .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, - .lds_padding = lds_cfg.lds_padding, + .src_vector_dim = lds_cfg.src_vector_dim, + .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, + .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, + .lds_padding = lds_cfg.lds_padding, }; } else diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp index 8db9607f343..f123edbfa95 100644 --- 
a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp @@ -65,11 +65,10 @@ consteval BlockGemmSpec SetBlockGemm() default: throw "Unknown PipelineVersion"; } - return BlockGemmSpec{ - .num_conv_groups_to_merge = BG.num_conv_groups_to_merge, - .num_gemm_k_prefetch_stages = BG.num_gemm_k_prefetch_stages, - .pipeline_version = version, - .scheduler = scheduler}; + return BlockGemmSpec{.num_conv_groups_to_merge = BG.num_conv_groups_to_merge, + .num_gemm_k_prefetch_stages = BG.num_gemm_k_prefetch_stages, + .pipeline_version = version, + .scheduler = scheduler}; } template diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index af9cc44115b..4d09ad8cb4e 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -232,12 +232,12 @@ enum class PipelineScheduler enum class ConvAlgorithmSpecialization { - NONE = 0, + NONE = 0, LARGE_TENSOR = 1 << 0, - REFERENCE = 1 << 1, // GPU reference implementation for validation, - TWO_STAGE = 1 << 2, - MULTIPLE_D = 1 << 3, - PIPELINE_V3 = 1 << 4 + REFERENCE = 1 << 1, // GPU reference implementation for validation, + TWO_STAGE = 1 << 2, + MULTIPLE_D = 1 << 3, + PIPELINE_V3 = 1 << 4 }; constexpr ConvAlgorithmSpecialization operator|(ConvAlgorithmSpecialization lhs, diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp index 562599da936..136c8fea40d 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp @@ -19,12 +19,13 @@ constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 3, .weight = {.config = {.layout = GKZYXC}}, .output = 
{.config = {.layout = NGKDHW}}}; -constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle{} - .with_thread_block(cku::ThreadBlock_64_32x32x32) - .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) - .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) - .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) - .with_gemm_pipeline(ckb::PipelineVersion::V1, ckb::PipelineScheduler::DEFAULT); +constexpr auto ALGORITHM = + cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle{} + .with_thread_block(cku::ThreadBlock_64_32x32x32) + .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) + .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) + .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) + .with_gemm_pipeline(ckb::PipelineVersion::V1, ckb::PipelineScheduler::DEFAULT); using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp index ebc30272c9d..9ed5758b4e8 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp @@ -62,14 +62,13 @@ TEST(FwdConvInstances, .weight = {.config = {.layout = GKYXC}}, .output = {.config = {.layout = NHWGK}}}; - constexpr auto FwdConvAlgorithm = - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle_V3{} - .with_thread_block(ThreadBlock_256_256x256x32) - .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_transfer(Transfer_4x64x1) - .with_fwd_specializations(ConvSpecialization::FILTER_3x3, - GemmSpecialization::MNKPadding) - .with_gemm_pipeline(BlockGemmDesc_v5_intrawave); + constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle_V3{} + .with_thread_block(ThreadBlock_256_256x256x32) + .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) + .with_transfer(Transfer_4x64x1) + 
.with_fwd_specializations(ConvSpecialization::FILTER_3x3, + GemmSpecialization::MNKPadding) + .with_gemm_pipeline(BlockGemmDesc_v5_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 264f69ad811..cf4bc25c0f9 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -31,10 +31,10 @@ static_assert(ckb::ThreadBlockDescriptor); struct WarpGemmParams { MatrixInstructionType matrix_instruction; - size_t gemm_m_per_instruction = 0; - size_t gemm_n_per_instruction = 0; - size_t gemm_m_iters_per_wave = 0; - size_t gemm_n_iters_per_wave = 0; + size_t gemm_m_per_instruction = 0; + size_t gemm_n_per_instruction = 0; + size_t gemm_m_iters_per_wave = 0; + size_t gemm_n_iters_per_wave = 0; }; static_assert(ckb::WarpGemmDescriptor); @@ -370,7 +370,7 @@ struct ConvAlgorithmTemplate : Components... constexpr auto with_num_conv_groups_to_merge(size_t num_groups_to_merge) const { static_assert(std::is_base_of_v); - auto result = *this; + auto result = *this; result.gemm_pipeline.num_conv_groups_to_merge = num_groups_to_merge; return result; } @@ -378,7 +378,7 @@ struct ConvAlgorithmTemplate : Components... constexpr auto with_num_gemm_k_prefetch_stages(size_t num_prefetch_stages) const { static_assert(std::is_base_of_v); - auto result = *this; + auto result = *this; result.gemm_pipeline.num_gemm_k_prefetch_stages = num_prefetch_stages; return result; } @@ -387,7 +387,7 @@ struct ConvAlgorithmTemplate : Components... constexpr auto with_gemm_pipeline(const PL& pl) const { static_assert(std::is_base_of_v); - auto result = *this; + auto result = *this; result.gemm_pipeline = pl; return result; } @@ -395,7 +395,7 @@ struct ConvAlgorithmTemplate : Components... 
constexpr auto with_gemm_pipeline(const PipelineVersion plv) const { static_assert(std::is_base_of_v); - auto result = *this; + auto result = *this; result.gemm_pipeline.pipeline_version = plv; return result; } @@ -403,7 +403,7 @@ struct ConvAlgorithmTemplate : Components... constexpr auto with_gemm_pipeline(const PipelineScheduler sch) const { static_assert(std::is_base_of_v); - auto result = *this; + auto result = *this; result.gemm_pipeline.scheduler = sch; return result; } @@ -411,7 +411,7 @@ struct ConvAlgorithmTemplate : Components... constexpr auto with_gemm_pipeline(const PipelineVersion plv, const PipelineScheduler sch) const { static_assert(std::is_base_of_v); - auto result = *this; + auto result = *this; result.gemm_pipeline.pipeline_version = plv; result.gemm_pipeline.scheduler = sch; return result; @@ -535,7 +535,7 @@ using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate; // Reference algorithm descriptor - for GPU reference validation -using ConvAlgorithm_Reference = ConvAlgorithmTemplate>; +using ConvAlgorithm_Reference = ConvAlgorithmTemplate>; // Bwd weight algorithm types using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle = @@ -543,7 +543,8 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle = WarpGemm_, InputOutputTileTransfer_<4>, ConvSpecializationBwdWeight_, - GemmPipeline_, // Not needed, but we need this to meet the ConvAlgorithm concept. + GemmPipeline_, // Not needed, but we need this to meet the ConvAlgorithm + // concept. TransposeParams_, AlgorithmSpecialization_<>>; @@ -593,7 +594,8 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle = ConvAlgorithmTemplate, - GemmPipeline_, // Not needed, but we need this to meet the ConvAlgorithm concept. + GemmPipeline_, // Not needed, but we need this to meet the ConvAlgorithm + // concept. 
ConvSpecializationBwdWeight_, AlgorithmSpecialization_>; diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index 760ada78eae..c8cfbe6b196 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -117,57 +117,60 @@ static_assert(!ckb::ConvSignatureDescriptor transfer{ .a = { - .thread_cluster = {.k0 = 1, .m_n = 128, .k1 = 2}, - .lds_transfer = {.global_memory_vector_load_size = 8, - .src_vector_dim = 2, - .src_scalar_per_vector = 2, - .lds_dst_scalar_per_vector = 2, - .is_direct_load = false, - .lds_padding = false}, + .thread_cluster = {.k0 = 1, .m_n = 128, .k1 = 2}, + .lds_transfer = {.global_memory_vector_load_size = 8, + .src_vector_dim = 2, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 2, + .is_direct_load = false, + .lds_padding = false}, .thread_cluster_access_order = {.order = {0, 1, 2}}, .src_access_order = {.order = {0, 1, 2}}, }, .b = { - .thread_cluster = {.k0 = 1, .m_n = 128, .k1 = 2}, - .lds_transfer = {.global_memory_vector_load_size = 8, - .src_vector_dim = 2, - .src_scalar_per_vector = 2, - .lds_dst_scalar_per_vector = 2, - .is_direct_load = false, - .lds_padding = false}, + .thread_cluster = {.k0 = 1, .m_n = 128, .k1 = 2}, + .lds_transfer = {.global_memory_vector_load_size = 8, + .src_vector_dim = 2, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 2, + .is_direct_load = false, + .lds_padding = false}, .thread_cluster_access_order = {.order = {0, 1, 2}}, .src_access_order = {.order = {0, 1, 2}}, }, .c = { - .thread_cluster = - {.gemm_m_block_size = 1, .gemm_m_per_block = 32, .gemm_n_block_size = 1, .gemm_n_per_block = 8}, - .epilogue = {.m_xdl_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 2}, + .thread_cluster = {.gemm_m_block_size = 1, + .gemm_m_per_block = 32, + .gemm_n_block_size = 1, + .gemm_n_per_block = 8}, + .epilogue = 
{.m_xdl_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 2}, }, }; ckb::ConvSpecialization fwd_specialization = ckb::ConvSpecialization::DEFAULT; ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default; ckb::test::GemmPipeline gemm_pipeline{.pipeline_version = ckb::PipelineVersion::V4, - .scheduler =ckb::PipelineScheduler::INTRAWAVE}; + .scheduler = ckb::PipelineScheduler::INTRAWAVE}; }; static_assert(ckb::ConvAlgorithmDescriptor); diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index a281ca28ad6..4209d708cdf 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -55,13 +55,13 @@ constexpr DlTransfer<5> DlTransfer5D{.a = DlBlockTransfer_1x8x1x1x1, constexpr InputOutputTileTransfer<> Transfer_4x64x1{ .a = { - .thread_cluster = {.k0 = 4, .m_n = 64, .k1 = 1}, + .thread_cluster = {.k0 = 4, .m_n = 64, .k1 = 1}, .lds_transfer = {.global_memory_vector_load_size = 8, - .src_vector_dim = 2, - .src_scalar_per_vector = 2, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = false}, + .src_vector_dim = 2, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = false}, .thread_cluster_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, @@ -69,21 +69,23 @@ constexpr InputOutputTileTransfer<> Transfer_4x64x1{ { .thread_cluster = {.k0 = 4, .m_n = 64, .k1 = 1}, .lds_transfer = {.global_memory_vector_load_size = 8, - .src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = false}, + .src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = false}, .thread_cluster_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .c = { - 
.thread_cluster = - {.gemm_m_block_size = 1, .gemm_m_per_block = 32, .gemm_n_block_size = 1, .gemm_n_per_block = 8}, - .epilogue = {.m_xdl_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 4}, + .thread_cluster = {.gemm_m_block_size = 1, + .gemm_m_per_block = 32, + .gemm_n_block_size = 1, + .gemm_n_per_block = 8}, + .epilogue = {.m_xdl_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 4}, }, }; @@ -92,11 +94,11 @@ constexpr InputOutputTileTransfer<4> BwdTransfer_4x64x1{ { .thread_cluster = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1}, .lds_transfer = {.global_memory_vector_load_size = 8, - .src_vector_dim = 2, - .src_scalar_per_vector = 2, - .lds_dst_scalar_per_vector = 4, - .is_direct_load = false, - .lds_padding = true}, + .src_vector_dim = 2, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 4, + .is_direct_load = false, + .lds_padding = true}, .thread_cluster_access_order = {0, 3, 1, 2}, .src_access_order = {0, 2, 1, 3}, }, @@ -104,21 +106,23 @@ constexpr InputOutputTileTransfer<4> BwdTransfer_4x64x1{ { .thread_cluster = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1}, .lds_transfer = {.global_memory_vector_load_size = 8, - .src_vector_dim = 2, - .src_scalar_per_vector = 2, - .lds_dst_scalar_per_vector = 4, - .is_direct_load = false, - .lds_padding = true}, + .src_vector_dim = 2, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 4, + .is_direct_load = false, + .lds_padding = true}, .thread_cluster_access_order = {0, 3, 1, 2}, .src_access_order = {0, 2, 1, 3}, }, .c = { - .thread_cluster = - {.gemm_m_block_size = 1, .gemm_m_per_block = 32, .gemm_n_block_size = 1, .gemm_n_per_block = 8}, - .epilogue = {.m_xdl_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, + .thread_cluster = {.gemm_m_block_size = 1, + .gemm_m_per_block = 32, + .gemm_n_block_size = 1, + .gemm_n_per_block = 8}, + .epilogue = {.m_xdl_per_wave_per_shuffle = 1, + 
.n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, }, }; @@ -127,11 +131,11 @@ constexpr InputOutputTileTransfer<> BwdTransfer_4x8x1_4x16x1_v3{ { .thread_cluster = {.k0 = 4, .m_n = 8, .k1 = 1}, .lds_transfer = {.global_memory_vector_load_size = 8, - .src_vector_dim = 1, - .src_scalar_per_vector = 2, - .lds_dst_scalar_per_vector = 2, - .is_direct_load = false, - .lds_padding = false}, + .src_vector_dim = 1, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 2, + .is_direct_load = false, + .lds_padding = false}, .thread_cluster_access_order = {2, 0, 1}, .src_access_order = {1, 0, 2}, }, @@ -139,21 +143,23 @@ constexpr InputOutputTileTransfer<> BwdTransfer_4x8x1_4x16x1_v3{ { .thread_cluster = {.k0 = 4, .m_n = 16, .k1 = 1}, .lds_transfer = {.global_memory_vector_load_size = 8, - .src_vector_dim = 1, - .src_scalar_per_vector = 2, - .lds_dst_scalar_per_vector = 2, - .is_direct_load = false, - .lds_padding = false}, + .src_vector_dim = 1, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 2, + .is_direct_load = false, + .lds_padding = false}, .thread_cluster_access_order = {2, 0, 1}, .src_access_order = {1, 0, 2}, }, .c = { - .thread_cluster = - {.gemm_m_block_size = 1, .gemm_m_per_block = 8, .gemm_n_block_size = 1, .gemm_n_per_block = 8}, - .epilogue = {.m_xdl_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 2}, + .thread_cluster = {.gemm_m_block_size = 1, + .gemm_m_per_block = 8, + .gemm_n_block_size = 1, + .gemm_n_per_block = 8}, + .epilogue = {.m_xdl_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 2}, }, }; @@ -162,11 +168,11 @@ constexpr InputOutputTileTransfer<> Transfer_4x64x1_fp8{ { .thread_cluster = {.k0 = 4, .m_n = 64, .k1 = 1}, .lds_transfer = {.global_memory_vector_load_size = 8, - .src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, + .src_vector_dim = 2, + .src_scalar_per_vector = 8, + 
.lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, .thread_cluster_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, @@ -174,21 +180,23 @@ constexpr InputOutputTileTransfer<> Transfer_4x64x1_fp8{ { .thread_cluster = {.k0 = 4, .m_n = 64, .k1 = 1}, .lds_transfer = {.global_memory_vector_load_size = 8, - .src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, + .src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, .thread_cluster_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .c = { - .thread_cluster = - {.gemm_m_block_size = 1, .gemm_m_per_block = 32, .gemm_n_block_size = 1, .gemm_n_per_block = 8}, - .epilogue = {.m_xdl_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, + .thread_cluster = {.gemm_m_block_size = 1, + .gemm_m_per_block = 32, + .gemm_n_block_size = 1, + .gemm_n_per_block = 8}, + .epilogue = {.m_xdl_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, }, }; @@ -196,34 +204,36 @@ constexpr InputOutputTileTransfer<> Transfer_4x16x1{ .a = { .thread_cluster = {.k0 = 4, .m_n = 16, .k1 = 1}, - .lds_transfer = {.global_memory_vector_load_size = 8, - .src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, + .lds_transfer = {.global_memory_vector_load_size = 8, + .src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, .thread_cluster_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .b = { .thread_cluster = {.k0 = 4, .m_n = 16, .k1 = 1}, - .lds_transfer = {.global_memory_vector_load_size = 8, - .src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = 
true}, + .lds_transfer = {.global_memory_vector_load_size = 8, + .src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, .thread_cluster_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .c = { - .thread_cluster = - {.gemm_m_block_size = 1, .gemm_m_per_block = 16, .gemm_n_block_size = 1, .gemm_n_per_block = 4}, - .epilogue = {.m_xdl_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, + .thread_cluster = {.gemm_m_block_size = 1, + .gemm_m_per_block = 16, + .gemm_n_block_size = 1, + .gemm_n_per_block = 4}, + .epilogue = {.m_xdl_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, }, }; @@ -232,34 +242,36 @@ constexpr InputOutputTileTransfer<> Transfer_4x32x1{ .a = { .thread_cluster = {.k0 = 4, .m_n = 32, .k1 = 1}, - .lds_transfer = {.global_memory_vector_load_size = 8, - .src_vector_dim = 2, - .src_scalar_per_vector = 16, - .lds_dst_scalar_per_vector = 16, - .is_direct_load = false, - .lds_padding = true}, + .lds_transfer = {.global_memory_vector_load_size = 8, + .src_vector_dim = 2, + .src_scalar_per_vector = 16, + .lds_dst_scalar_per_vector = 16, + .is_direct_load = false, + .lds_padding = true}, .thread_cluster_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .b = { .thread_cluster = {.k0 = 4, .m_n = 32, .k1 = 1}, - .lds_transfer = {.global_memory_vector_load_size = 8, - .src_vector_dim = 2, - .src_scalar_per_vector = 16, - .lds_dst_scalar_per_vector = 16, - .is_direct_load = false, - .lds_padding = true}, + .lds_transfer = {.global_memory_vector_load_size = 8, + .src_vector_dim = 2, + .src_scalar_per_vector = 16, + .lds_dst_scalar_per_vector = 16, + .is_direct_load = false, + .lds_padding = true}, .thread_cluster_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .c = { - .thread_cluster = - {.gemm_m_block_size = 1, .gemm_m_per_block = 32, .gemm_n_block_size = 1, .gemm_n_per_block = 4}, - 
.epilogue = {.m_xdl_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, + .thread_cluster = {.gemm_m_block_size = 1, + .gemm_m_per_block = 32, + .gemm_n_block_size = 1, + .gemm_n_per_block = 4}, + .epilogue = {.m_xdl_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, }, }; @@ -267,68 +279,94 @@ constexpr InputOutputTileTransfer<> Transfer_4x32x1_vector_load_16_generic{ .a = { .thread_cluster = {.k0 = 4, .m_n = 32, .k1 = 1}, - .lds_transfer = {.global_memory_vector_load_size = 16, - .src_vector_dim = 2, - .src_scalar_per_vector = 1, - .lds_dst_scalar_per_vector = 16, - .is_direct_load = false, - .lds_padding = true}, + .lds_transfer = {.global_memory_vector_load_size = 16, + .src_vector_dim = 2, + .src_scalar_per_vector = 1, + .lds_dst_scalar_per_vector = 16, + .is_direct_load = false, + .lds_padding = true}, .thread_cluster_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .b = { .thread_cluster = {.k0 = 4, .m_n = 32, .k1 = 1}, - .lds_transfer = {.global_memory_vector_load_size = 16, - .src_vector_dim = 2, - .src_scalar_per_vector = 1, - .lds_dst_scalar_per_vector = 16, - .is_direct_load = false, - .lds_padding = true}, + .lds_transfer = {.global_memory_vector_load_size = 16, + .src_vector_dim = 2, + .src_scalar_per_vector = 1, + .lds_dst_scalar_per_vector = 16, + .is_direct_load = false, + .lds_padding = true}, .thread_cluster_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .c = { - .thread_cluster = - {.gemm_m_block_size = 1, .gemm_m_per_block = 32, .gemm_n_block_size = 1, .gemm_n_per_block = 4}, - .epilogue = {.m_xdl_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 1}, + .thread_cluster = {.gemm_m_block_size = 1, + .gemm_m_per_block = 32, + .gemm_n_block_size = 1, + .gemm_n_per_block = 4}, + .epilogue = {.m_xdl_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 1}, }, }; -constexpr WarpGemmParams 
BwdGemmParams_Xdl_4x4_per_wave{ - .matrix_instruction = MatrixInstructionType::XDL, - .gemm_m_per_instruction = 32, .gemm_n_per_instruction = 32, .gemm_m_iters_per_wave = 4, .gemm_n_iters_per_wave = 4}; - -constexpr WarpGemmParams BwdGemmParams_Xdl_1x1_per_wave{ - .matrix_instruction = MatrixInstructionType::XDL, - .gemm_m_per_instruction = 32, .gemm_n_per_instruction = 32, .gemm_m_iters_per_wave = 1, .gemm_n_iters_per_wave = 1}; - -constexpr WarpGemmParams FwdGemmParams_Xdl_4x4_per_wave{ - .matrix_instruction = MatrixInstructionType::XDL, - .gemm_m_per_instruction = 32, .gemm_n_per_instruction = 32, .gemm_m_iters_per_wave = 4, .gemm_n_iters_per_wave = 4}; - -constexpr WarpGemmParams FwdGemmParams_Xdl_4x2_per_wave{ - .matrix_instruction = MatrixInstructionType::XDL, - .gemm_m_per_instruction = 32, .gemm_n_per_instruction = 32, .gemm_m_iters_per_wave = 4, .gemm_n_iters_per_wave = 2}; - -constexpr WarpGemmParams FwdGemmParams_Xdl_2x2_per_wave{ - .matrix_instruction = MatrixInstructionType::XDL, - .gemm_m_per_instruction = 32, .gemm_n_per_instruction = 32, .gemm_m_iters_per_wave = 2, .gemm_n_iters_per_wave = 2}; - -constexpr WarpGemmParams FwdGemmParams_Xdl_2x1_per_wave{ - .matrix_instruction = MatrixInstructionType::XDL, - .gemm_m_per_instruction = 32, .gemm_n_per_instruction = 32, .gemm_m_iters_per_wave = 2, .gemm_n_iters_per_wave = 1}; - -constexpr WarpGemmParams GemmParams_Wmma_16x16_2x1_per_wave{ - .matrix_instruction = MatrixInstructionType::WMMA, - .gemm_m_per_instruction = 16, .gemm_n_per_instruction = 16, .gemm_m_iters_per_wave = 2, .gemm_n_iters_per_wave = 1}; - -constexpr WarpGemmParams GemmParams_Wmma_16x16_2x2_per_wave{ - .matrix_instruction = MatrixInstructionType::WMMA, - .gemm_m_per_instruction = 16, .gemm_n_per_instruction = 16, .gemm_m_iters_per_wave = 2, .gemm_n_iters_per_wave = 2}; +constexpr WarpGemmParams BwdGemmParams_Xdl_4x4_per_wave{.matrix_instruction = + MatrixInstructionType::XDL, + .gemm_m_per_instruction = 32, + .gemm_n_per_instruction = 
32, + .gemm_m_iters_per_wave = 4, + .gemm_n_iters_per_wave = 4}; + +constexpr WarpGemmParams BwdGemmParams_Xdl_1x1_per_wave{.matrix_instruction = + MatrixInstructionType::XDL, + .gemm_m_per_instruction = 32, + .gemm_n_per_instruction = 32, + .gemm_m_iters_per_wave = 1, + .gemm_n_iters_per_wave = 1}; + +constexpr WarpGemmParams FwdGemmParams_Xdl_4x4_per_wave{.matrix_instruction = + MatrixInstructionType::XDL, + .gemm_m_per_instruction = 32, + .gemm_n_per_instruction = 32, + .gemm_m_iters_per_wave = 4, + .gemm_n_iters_per_wave = 4}; + +constexpr WarpGemmParams FwdGemmParams_Xdl_4x2_per_wave{.matrix_instruction = + MatrixInstructionType::XDL, + .gemm_m_per_instruction = 32, + .gemm_n_per_instruction = 32, + .gemm_m_iters_per_wave = 4, + .gemm_n_iters_per_wave = 2}; + +constexpr WarpGemmParams FwdGemmParams_Xdl_2x2_per_wave{.matrix_instruction = + MatrixInstructionType::XDL, + .gemm_m_per_instruction = 32, + .gemm_n_per_instruction = 32, + .gemm_m_iters_per_wave = 2, + .gemm_n_iters_per_wave = 2}; + +constexpr WarpGemmParams FwdGemmParams_Xdl_2x1_per_wave{.matrix_instruction = + MatrixInstructionType::XDL, + .gemm_m_per_instruction = 32, + .gemm_n_per_instruction = 32, + .gemm_m_iters_per_wave = 2, + .gemm_n_iters_per_wave = 1}; + +constexpr WarpGemmParams GemmParams_Wmma_16x16_2x1_per_wave{.matrix_instruction = + MatrixInstructionType::WMMA, + .gemm_m_per_instruction = 16, + .gemm_n_per_instruction = 16, + .gemm_m_iters_per_wave = 2, + .gemm_n_iters_per_wave = 1}; + +constexpr WarpGemmParams GemmParams_Wmma_16x16_2x2_per_wave{.matrix_instruction = + MatrixInstructionType::WMMA, + .gemm_m_per_instruction = 16, + .gemm_n_per_instruction = 16, + .gemm_m_iters_per_wave = 2, + .gemm_n_iters_per_wave = 2}; constexpr ThreadBlock ThreadBlock_256_256x256x32{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; @@ -357,19 +395,19 @@ constexpr ThreadBlock ThreadBlock_128_128x128x32{.block_size = 128, constexpr ThreadBlock ThreadBlock_128_64x64x64{.block_size = 128, 
.tile_size = {.m = 64, .n = 64, .k = 64}}; -constexpr GemmPipeline BlockGemmDesc_v1_intrawave = { - .pipeline_version = PipelineVersion::V1, .scheduler = PipelineScheduler::INTRAWAVE}; +constexpr GemmPipeline BlockGemmDesc_v1_intrawave = {.pipeline_version = PipelineVersion::V1, + .scheduler = PipelineScheduler::INTRAWAVE}; -constexpr GemmPipeline BlockGemmDesc_v2_intrawave = { - .pipeline_version = PipelineVersion::V2, .scheduler = PipelineScheduler::INTRAWAVE}; +constexpr GemmPipeline BlockGemmDesc_v2_intrawave = {.pipeline_version = PipelineVersion::V2, + .scheduler = PipelineScheduler::INTRAWAVE}; -constexpr GemmPipeline BlockGemmDesc_v3_intrawave = { - .pipeline_version = PipelineVersion::V3, .scheduler = PipelineScheduler::INTRAWAVE}; +constexpr GemmPipeline BlockGemmDesc_v3_intrawave = {.pipeline_version = PipelineVersion::V3, + .scheduler = PipelineScheduler::INTRAWAVE}; -constexpr GemmPipeline BlockGemmDesc_v4_intrawave = { - .pipeline_version = PipelineVersion::V4, .scheduler = PipelineScheduler::INTRAWAVE}; +constexpr GemmPipeline BlockGemmDesc_v4_intrawave = {.pipeline_version = PipelineVersion::V4, + .scheduler = PipelineScheduler::INTRAWAVE}; -constexpr GemmPipeline BlockGemmDesc_v5_intrawave = { - .pipeline_version = PipelineVersion::V5, .scheduler = PipelineScheduler::INTRAWAVE}; +constexpr GemmPipeline BlockGemmDesc_v5_intrawave = {.pipeline_version = PipelineVersion::V5, + .scheduler = PipelineScheduler::INTRAWAVE}; } // namespace ck_tile::builder::test_utils diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index db5aa93a7ce..2f66a633fad 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -98,7 +98,8 @@ template <> inline std::string to_string(GemmPipeline t) { std::ostringstream oss; - oss << t.num_gemm_k_prefetch_stages << "," << t.num_conv_groups_to_merge << "," 
<< to_string(t.scheduler) << "," << to_string(t.pipeline_version); + oss << t.num_gemm_k_prefetch_stages << "," << t.num_conv_groups_to_merge << "," + << to_string(t.scheduler) << "," << to_string(t.pipeline_version); return oss.str(); } @@ -123,17 +124,17 @@ inline std::string to_string(InputThreadCluster t) template <> inline std::string to_string(OutputThreadCluster t) { - return array_to_seq( - std::array{t.gemm_m_block_size, t.gemm_m_per_block, t.gemm_n_block_size, t.gemm_n_per_block}); + return array_to_seq(std::array{ + t.gemm_m_block_size, t.gemm_m_per_block, t.gemm_n_block_size, t.gemm_n_per_block}); } template <> inline std::string to_string(LdsTransfer t) { std::ostringstream oss; - oss << t.global_memory_vector_load_size << "," << t.src_vector_dim << "," << t.src_scalar_per_vector << "," << t.lds_dst_scalar_per_vector - << "," << (t.lds_padding ? "true" : "false") << "," - << (t.is_direct_load ? "true" : "false"); + oss << t.global_memory_vector_load_size << "," << t.src_vector_dim << "," + << t.src_scalar_per_vector << "," << t.lds_dst_scalar_per_vector << "," + << (t.lds_padding ? "true" : "false") << "," << (t.is_direct_load ? 
"true" : "false"); return oss.str(); } @@ -259,7 +260,6 @@ inline std::string to_string(WarpGemm_ t) return to_string(t.warp_gemm); } - template inline std::string to_string(InputOutputTileTransfer_ t) { @@ -319,21 +319,21 @@ inline std::string to_string(t)) << "," << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," - << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + << to_string(static_cast(t)) << "," + << to_string(static_cast>(t)); } - else + else { oss << to_string(static_cast(t)) << "," << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," << t.transfer.b.lds_transfer.global_memory_vector_load_size << "," - << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); - } + << to_string(static_cast(t)) << "," + << to_string(static_cast>(t)); + } return oss.str(); } @@ -342,11 +342,11 @@ inline std::string to_string(t)) << "," + oss << to_string(static_cast(t)) << "," << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," << t.transfer.b.lds_transfer.global_memory_vector_load_size << "," - << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + << to_string(static_cast(t)) << "," + << to_string(static_cast>(t)); return oss.str(); } @@ -367,11 +367,11 @@ inline std::string to_string(t)) << "," + oss << to_string(static_cast(t)) << "," << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," << t.transfer.b.lds_transfer.global_memory_vector_load_size << "," - << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + << to_string(static_cast(t)) << "," + << to_string(static_cast>(t)); return oss.str(); } @@ -380,10 +380,10 @@ inline std::string to_string(t)) << "," + oss << to_string(static_cast(t)) << "," << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," - << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + << to_string(static_cast(t)) << "," + << to_string(static_cast>(t)); return oss.str(); } @@ -392,10 +392,10 @@ inline 
std::string to_string(t)) << "," + oss << to_string(static_cast(t)) << "," << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," - << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + << to_string(static_cast(t)) << "," + << to_string(static_cast>(t)); return oss.str(); } @@ -404,10 +404,10 @@ inline std::string to_string(t)) << "," + oss << to_string(static_cast(t)) << "," << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," - << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + << to_string(static_cast(t)) << "," + << to_string(static_cast>(t)); return oss.str(); } @@ -416,10 +416,10 @@ inline std::string to_string(t)) << "," + oss << to_string(static_cast(t)) << "," << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," - << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + << to_string(static_cast(t)) << "," + << to_string(static_cast>(t)); return oss.str(); } @@ -430,8 +430,8 @@ inline std::string to_string(t)) << "," << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," - << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + << to_string(static_cast(t)) << "," + << to_string(static_cast>(t)); return oss.str(); } @@ -441,14 +441,13 @@ inline std::string to_string(t)) - << "," << t.transfer.a.lds_transfer.global_memory_vector_load_size - << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + oss << to_string(static_cast(t)) << "," + << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," + << to_string(static_cast(t)) << "," + << to_string(static_cast>(t)); return oss.str(); } - template <> inline std::string to_string( ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl t) @@ -466,10 +465,10 @@ inline std::string to_string(t)) << "," + oss << to_string(static_cast(t)) << "," << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," - << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + 
<< to_string(static_cast(t)) << "," + << to_string(static_cast>(t)); return oss.str(); } From dc6f8f1067f018d35ec447a9f48012fe0c214ab3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Thu, 15 Jan 2026 05:27:22 -0500 Subject: [PATCH 79/81] Fix WMMA conv algorithms hierarchy. --- .../include/ck_tile/builder/conv_algorithm_concepts.hpp | 5 ----- .../builder/include/ck_tile/builder/factory/README.md | 4 ++-- .../include/ck_tile/builder/factory/conv_algorithms.hpp | 7 +++---- ...test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp | 2 +- 4 files changed, 6 insertions(+), 12 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index ce60bf4a94a..ffe265b0cc8 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -260,11 +260,6 @@ concept SpecifiesGemmSpecialization = requires { { T::gemm_specialization } -> std::convertible_to; }; -template -concept SpecifiesNumPrefetchStages = requires { - { T::num_gemm_k_prefetch_stages } -> SizeType; -}; - template concept SpecifiesNumGroupsToMerge = requires { { T::num_conv_groups_to_merge } -> SizeType; diff --git a/experimental/builder/include/ck_tile/builder/factory/README.md b/experimental/builder/include/ck_tile/builder/factory/README.md index 4c30348f9a7..ab3fe737558 100644 --- a/experimental/builder/include/ck_tile/builder/factory/README.md +++ b/experimental/builder/include/ck_tile/builder/factory/README.md @@ -159,7 +159,7 @@ BwdAlgorithm + SpecifiesWmma │ │ │ └─── BwdWmmaAlgorithm │ │ -│ └─ Requirements: Base + SpecifiesNumPrefetchStages + SpecifiesGemmPipeline + SpecifiesGenericInstance +│ └─ Requirements: Base + SpecifiesGemmPipeline + SpecifiesGenericInstance │ └─── BwdWmmaV3AlgorithmBase (Based on BwdAlgorithmV3) │ @@ -169,7 +169,7 @@ BwdAlgorithm + SpecifiesWmma │ 
├─── BwdWmmaV3Algorithm │ │ - │ └─ Requirements: Base + SpecifiesTransposeTransfer + SpecifiesGenericInstance + │ └─ Requirements: Base + SpecifiesTransposeTransfer │ └─── BwdTwoStageWmmaV3Algorithm │ diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index ad72609f848..8327b7a2ad2 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -108,8 +108,8 @@ concept BwdTwoStageXdlAlgorithm = // BWD weight WMMA algorithm concepts template concept BwdWmmaAlgorithm = - BwdWmmaAlgorithmBase && SpecifiesNumPrefetchStages && - SpecifiesGemmPipeline && SpecifiesGenericInstance; + BwdWmmaAlgorithmBase && SpecifiesGemmPipeline && + SpecifiesGenericInstance; template concept BwdMultiDWmmaV3Algorithm = @@ -117,8 +117,7 @@ concept BwdMultiDWmmaV3Algorithm = template concept BwdWmmaV3Algorithm = - BwdWmmaV3AlgorithmBase && SpecifiesTransposeTransfer && - SpecifiesGenericInstance; + BwdWmmaV3AlgorithmBase && SpecifiesTransposeTransfer; template concept BwdTwoStageWmmaV3Algorithm = diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp index 9310f5a9a69..9ff27f6f997 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp @@ -19,7 +19,7 @@ constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 2, .weight = {.config = {.layout = GKYXC}}, .output = {.config = {.layout = NGKHW}}}; -constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_CShuffle{} +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_CShuffle_V3{} 
.with_thread_block(cku::ThreadBlock_64_32x32x32) .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) From 3b784ca60336df55a1e64d84b38cce36055eec4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Thu, 15 Jan 2026 05:31:52 -0500 Subject: [PATCH 80/81] clang-format --- .../include/ck_tile/builder/factory/conv_algorithms.hpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index 8327b7a2ad2..df921a0cf6c 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -107,9 +107,8 @@ concept BwdTwoStageXdlAlgorithm = // BWD weight WMMA algorithm concepts template -concept BwdWmmaAlgorithm = - BwdWmmaAlgorithmBase && SpecifiesGemmPipeline && - SpecifiesGenericInstance; +concept BwdWmmaAlgorithm = BwdWmmaAlgorithmBase && SpecifiesGemmPipeline && + SpecifiesGenericInstance; template concept BwdMultiDWmmaV3Algorithm = From e5be038a5e8e64ec6eb56369b910bdccb8df6fb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Fri, 16 Jan 2026 05:32:29 -0500 Subject: [PATCH 81/81] Fix GEMM pipeline concept usage. 
--- .../include/ck_tile/builder/factory/README.md | 7 +++---- .../ck_tile/builder/factory/conv_algorithms.hpp | 13 ++++++------- .../builder/test/impl/conv_algorithm_types.hpp | 4 ---- 3 files changed, 9 insertions(+), 15 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/factory/README.md b/experimental/builder/include/ck_tile/builder/factory/README.md index 4c30348f9a7..f82b5f9ad67 100644 --- a/experimental/builder/include/ck_tile/builder/factory/README.md +++ b/experimental/builder/include/ck_tile/builder/factory/README.md @@ -49,14 +49,13 @@ XDL and WMMA algorithms share a common base, while DL algorithms have their own Both XDL and WMMA algorithms share the following foundational concepts: ``` -ConvAlgorithm (Base Concept) +ConvWarpGemmAlgorithm (Base Concept) │ │ Requirements: │ • ConvAlgorithmDescriptor │ • SpecifiesThreadBlock │ • SpecifiesTileTransferParameters (ThreadClusters, LdsTransfer, AccessOrders) │ • SpecifiesWarpGemm -│ • SpecifiesGemmPipeline │ ├─── FwdAlgorithm (Forward Convolution) │ │ @@ -64,7 +63,7 @@ ConvAlgorithm (Base Concept) │ │ │ └─── FwdAlgorithmV3 │ │ -│ │ Additional: SpecifiesPipelineV3 +│ │ Additional: SpecifiesPipelineV3 + SpecifiesGemmPipeline │ │ │ └─── BwdAlgorithm (Backward Weight Convolution) @@ -73,7 +72,7 @@ ConvAlgorithm (Base Concept) │ └─── BwdAlgorithmV3 │ - │ Additional: SpecifiesPipelineV3 + │ Additional: SpecifiesPipelineV3 + SpecifiesGemmPipeline │ ``` diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index ad72609f848..0c4d766fff2 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -14,23 +14,22 @@ concept SpecifiesTileTransferParameters = // Base algorithm concepts template -concept ConvAlgorithm = ConvAlgorithmDescriptor && SpecifiesThreadBlock && +concept 
ConvWarpGemmAlgorithm = ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters && - SpecifiesWarpGemm && SpecifiesGemmPipeline; -; + SpecifiesWarpGemm; template -concept FwdAlgorithm = ConvAlgorithm && SpecifiesFwdConvSpecialization; +concept FwdAlgorithm = ConvWarpGemmAlgorithm && SpecifiesFwdConvSpecialization; template -concept FwdAlgorithmV3 = FwdAlgorithm && SpecifiesPipelineV3; +concept FwdAlgorithmV3 = FwdAlgorithm && SpecifiesPipelineV3 && SpecifiesGemmPipeline; template concept BwdAlgorithm = - ConvAlgorithm && SpecifiesBwdWeightConvSpecialization; + ConvWarpGemmAlgorithm && SpecifiesBwdWeightConvSpecialization; template -concept BwdAlgorithmV3 = BwdAlgorithm && SpecifiesPipelineV3; +concept BwdAlgorithmV3 = BwdAlgorithm && SpecifiesPipelineV3 && SpecifiesGemmPipeline; template concept DlAlgorithm = diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index cf4bc25c0f9..6989887a264 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -543,8 +543,6 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle = WarpGemm_, InputOutputTileTransfer_<4>, ConvSpecializationBwdWeight_, - GemmPipeline_, // Not needed, but we need this to meet the ConvAlgorithm - // concept. TransposeParams_, AlgorithmSpecialization_<>>; @@ -594,8 +592,6 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle = ConvAlgorithmTemplate, - GemmPipeline_, // Not needed, but we need this to meet the ConvAlgorithm - // concept. ConvSpecializationBwdWeight_, AlgorithmSpecialization_>;