diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 791924ccd44..ffe265b0cc8 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -30,40 +30,33 @@ concept ThreadBlockDescriptor = requires(T t) { // Concept for parameters that describe a gridwise XDL GEMM problem. template -concept GridwiseXdlGemmDescriptor = requires(T t) { - { t.m_per_xdl } -> SizeType; - { t.n_per_xdl } -> SizeType; - { t.m_xdl_per_wave } -> SizeType; - { t.n_xdl_per_wave } -> SizeType; +concept WarpGemmDescriptor = requires(T t) { + { t.matrix_instruction } -> std::convertible_to; + { t.gemm_m_per_instruction } -> SizeType; + { t.gemm_n_per_instruction } -> SizeType; + { t.gemm_m_iters_per_wave } -> SizeType; + { t.gemm_n_iters_per_wave } -> SizeType; }; -// Concept for parameter that describe block GEMM problem. +// Concept for parameters that describe the GEMM pipeline. template -concept BlockGemmPipelineDescriptor = requires(T t) { +concept GemmPipelineDescriptor = requires(T t) { + { t.num_conv_groups_to_merge } -> SizeType; + { t.num_gemm_k_prefetch_stages } -> SizeType; { t.pipeline_version } -> std::convertible_to; { t.scheduler } -> std::convertible_to; }; -// Concept for parameters that describe a gridwise WMMA GEMM problem. -template -concept GridwiseWmmaGemmDescriptor = requires(T t) { - { t.k1 } -> SizeType; - { t.m_per_wmma } -> SizeType; - { t.n_per_wmma } -> SizeType; - { t.m_wmma_per_wave } -> SizeType; - { t.n_wmma_per_wave } -> SizeType; -}; - // Concept for vectorized data transfer for convolution input tensors. 
template -concept BlockTransferDescriptor3D = requires(T t) { +concept InputTileThreadClusterDescriptor3D = requires(T t) { { t.k0 } -> SizeType; { t.m_n } -> SizeType; { t.k1 } -> SizeType; }; template -concept BlockTransferDescriptor4D = requires(T t) { +concept InputTileThreadClusterDescriptor4D = requires(T t) { { t.k0 } -> SizeType; { t.m_n } -> SizeType; { t.k1 } -> SizeType; @@ -71,21 +64,23 @@ concept BlockTransferDescriptor4D = requires(T t) { }; template -concept BlockTransferDescriptor = (ThreadClusterRank == 3 && BlockTransferDescriptor3D) || - (ThreadClusterRank == 4 && BlockTransferDescriptor4D); +concept InputTileThreadClusterDescriptor = + (ThreadClusterRank == 3 && InputTileThreadClusterDescriptor3D) || + (ThreadClusterRank == 4 && InputTileThreadClusterDescriptor4D); // Concept for thread cluster dimensions for GEMM output tensor. template -concept ThreadClusterDescriptor = requires(T t) { - { t.m_block } -> SizeType; - { t.m_wave_per_xdl } -> SizeType; - { t.n_block } -> SizeType; - { t.n_wave_per_xdl } -> SizeType; +concept OutputTileThreadClusterDescriptor = requires(T t) { + { t.gemm_m_block_size } -> SizeType; + { t.gemm_m_per_block } -> SizeType; + { t.gemm_n_block_size } -> SizeType; + { t.gemm_n_per_block } -> SizeType; }; // Concept for the LDS transfer for the convolution input tensors. template concept LdsTransferDescriptor = requires(T t) { + { t.global_memory_vector_load_size } -> SizeType; { t.src_vector_dim } -> SizeType; { t.src_scalar_per_vector } -> SizeType; { t.lds_dst_scalar_per_vector } -> SizeType; @@ -172,45 +167,18 @@ concept SpecifiesTileThreadBlock = requires { { T::thread_block } -> TileThreadBlockDescriptor; }; -// Concept to check if a struct specifies gridwise XDL GEMM info. -template -concept GridwiseFwdXdlGemmDescriptor = requires(T t) { - { t.ak1 } -> SizeType; - { t.bk1 } -> SizeType; - { t.xdl_params } -> GridwiseXdlGemmDescriptor; -}; - -// Concept to check if a struct specifies gridwise XDL GEMM info. 
-template -concept GridwiseBwdXdlGemmDescriptor = requires(T t) { - { t.k1 } -> SizeType; - { t.xdl_params } -> GridwiseXdlGemmDescriptor; -}; - -// Concept to check if a struct specifies gridwise XDL GEMM info. -template -concept SpecifiesGridwiseFwdXdlGemm = requires(T t) { - { t.gridwise_gemm } -> GridwiseFwdXdlGemmDescriptor; -}; - -// Concept to check if a struct specifies gridwise XDL GEMM info. -template -concept SpecifiesGridwiseBwdXdlGemm = requires(T t) { - { t.gridwise_gemm } -> GridwiseBwdXdlGemmDescriptor; -}; - -// Concept to check if a struct specifies gridwise WMMA GEMM info. +// Concept to check if a struct specifies warp GEMM info. template -concept SpecifiesGridwiseWmmaGemm = requires(T t) { - { t.gridwise_gemm } -> GridwiseWmmaGemmDescriptor; +concept SpecifiesWarpGemm = requires { + { T::warp_gemm } -> WarpGemmDescriptor; }; // Concept to check if a struct specifies convolution input and output block transfer info. template -concept SpecifiesBlockTransfer = requires(T t) { - { T::transfer.a.block_transfer } -> BlockTransferDescriptor; - { T::transfer.b.block_transfer } -> BlockTransferDescriptor; - { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; +concept SpecifiesThreadClusters = requires(T t) { + { T::transfer.a.thread_cluster } -> InputTileThreadClusterDescriptor; + { T::transfer.b.thread_cluster } -> InputTileThreadClusterDescriptor; + { T::transfer.c.thread_cluster } -> OutputTileThreadClusterDescriptor; }; // Concept to check if a struct specifies convolution scalar per vector infor for A, B and C. @@ -232,8 +200,8 @@ concept SpecifiesLdsTransfer = requires(T t) { // Concept to check if a struct specifies thread cluster access order info. 
template concept SpecifiesThreadClusterAccessOrder = requires(T t) { - { T::transfer.a.block_transfer_access_order } -> AccessOrderDescriptor; - { T::transfer.b.block_transfer_access_order } -> AccessOrderDescriptor; + { T::transfer.a.thread_cluster_access_order } -> AccessOrderDescriptor; + { T::transfer.b.thread_cluster_access_order } -> AccessOrderDescriptor; }; // Concept to check if a struct specifies source access order info. @@ -245,13 +213,8 @@ concept SpecifiesSourceAccessOrder = requires(T t) { // Concept to check if struct specifies block GEMM. template -concept SpecifiesBlockGemm = requires { - { T::block_gemm_pipeline } -> BlockGemmPipelineDescriptor; -}; - -template -concept SpecifiesGridwiseGemmPipeline = requires { - { T::pipeline_version } -> std::convertible_to; +concept SpecifiesGemmPipeline = requires { + { T::gemm_pipeline } -> GemmPipelineDescriptor; }; // Concept to check if struct specifies block GEMM (CK Tile). @@ -297,26 +260,11 @@ concept SpecifiesGemmSpecialization = requires { { T::gemm_specialization } -> std::convertible_to; }; -template -concept SpecifiesNumPrefetchStages = requires { - { T::num_gemm_k_prefetch_stages } -> SizeType; -}; - template concept SpecifiesNumGroupsToMerge = requires { { T::num_conv_groups_to_merge } -> SizeType; }; -template -concept SpecifiesLoopScheduler = requires { - { T::loop_scheduler } -> std::convertible_to; -}; - -template -concept SpecifiesGenericInstance = !requires { - { T::specialization }; -}; - template concept SpecifiesTransposeTransfer = requires { { T::max_transpose_transfer_src_scalar_per_vector } -> SizeType; @@ -333,38 +281,58 @@ template concept TransposeTransferWellDefinedIfProvided = !HasTransposeTransfer || SpecifiesTransposeTransfer; -template -concept SpecifiesGemmBatchOptions = requires { - { T::num_conv_groups_to_merge } -> SizeType; -}; - /******************************************** */ /* Algorithm specialization concepts */ /******************************************** */ 
template concept SpecifiesLargeTensorSupport = requires { { T::specialization } -> std::convertible_to; - requires T::specialization == ConvAlgorithmSpecialization::LARGE_TENSOR; + requires !!(T::specialization & ConvAlgorithmSpecialization::LARGE_TENSOR); }; template concept SpecifiesReferenceAlgorithm = requires { { T::specialization } -> std::convertible_to; - requires T::specialization == ConvAlgorithmSpecialization::REFERENCE; + requires !!(T::specialization & ConvAlgorithmSpecialization::REFERENCE); }; template concept SpecifiesTwoStageSupport = requires { { T::specialization } -> std::convertible_to; - requires T::specialization == ConvAlgorithmSpecialization::TWO_STAGE; + requires !!(T::specialization & ConvAlgorithmSpecialization::TWO_STAGE); }; template concept SpecifiesMultipleDSupport = requires { { T::specialization } -> std::convertible_to; - requires T::specialization == ConvAlgorithmSpecialization::MULTIPLE_D; + requires !!(T::specialization & ConvAlgorithmSpecialization::MULTIPLE_D); +}; + +template +concept SpecifiesPipelineV3 = requires { + { T::specialization } -> std::convertible_to; + requires !!(T::specialization & ConvAlgorithmSpecialization::PIPELINE_V3); +}; + +template +concept SpecifiesGenericInstance = !requires { + { T::specialization }; +} || requires { + { T::specialization } -> std::convertible_to; + requires !!(T::specialization == ConvAlgorithmSpecialization::NONE); }; +template +concept SpecifiesXdl = + requires { requires Algorithm.warp_gemm.matrix_instruction == MatrixInstructionType::XDL; }; + +template +concept SpecifiesWmma = + requires { requires Algorithm.warp_gemm.matrix_instruction == MatrixInstructionType::WMMA; }; + +template +concept SpecifiesValidWarpGemm = SpecifiesXdl || SpecifiesWmma; + /******************************************** */ /* DL-specific descriptors and requirements */ /******************************************** */ diff --git a/experimental/builder/include/ck_tile/builder/factory/README.md 
b/experimental/builder/include/ck_tile/builder/factory/README.md index d1794349ab5..96f2d8411f2 100644 --- a/experimental/builder/include/ck_tile/builder/factory/README.md +++ b/experimental/builder/include/ck_tile/builder/factory/README.md @@ -29,3 +29,230 @@ using Factory = decltype(make_conv_instance()); ``` The dispatcher automatically selects the appropriate factory following explicit logic. + +# Convolution Algorithm Hierarchy + +This section illustrates the hierarchy of convolution algorithm concepts defined in `conv_algorithms.hpp`. + +## Overview + +The convolution algorithms are organized into three main categories: + +1. **XDL Algorithms** - GPU matrix multiplication using XDL (matrix core instructions) +2. **WMMA Algorithms** - GPU matrix multiplication using WMMA (Wave Matrix Multiply-Accumulate) +3. **DL Algorithms** - Special vectorized dot-product kernels optimized for specific data layouts with separate implementation. + +XDL and WMMA algorithms share a common base, while DL algorithms have their own independent base. 
+ +## Common Base Hierarchy (XDL & WMMA) + +Both XDL and WMMA algorithms share the following foundational concepts: + +``` +ConvWarpGemmAlgorithm (Base Concept) +│ +│ Requirements: +│ • ConvAlgorithmDescriptor +│ • SpecifiesThreadBlock +│ • SpecifiesTileTransferParameters (ThreadClusters, LdsTransfer, AccessOrders) +│ • SpecifiesWarpGemm +│ +├─── FwdAlgorithm (Forward Convolution) +│ │ +│ │ Additional: SpecifiesFwdConvSpecialization +│ │ +│ └─── FwdAlgorithmV3 +│ │ +│ │ Additional: SpecifiesPipelineV3 + SpecifiesGemmPipeline +│ │ +│ +└─── BwdAlgorithm (Backward Weight Convolution) + │ + │ Additional: SpecifiesBwdWeightConvSpecialization + │ + └─── BwdAlgorithmV3 + │ + │ Additional: SpecifiesPipelineV3 + SpecifiesGemmPipeline + │ +``` + +--- + +## XDL Algorithm Hierarchy + +### Forward XDL Algorithms + +``` +FwdAlgorithm + SpecifiesXdl +│ +├─── FwdXdlAlgorithmBase + │ + ├─── FwdXdlAlgorithm + │ │ + │ └─ Requirements: Base + SpecifiesGenericInstance + │ + ├─── LargeTensorAlgorithm + │ │ + │ └─ Requirements: Base + SpecifiesLargeTensorSupport + │ + └─── FwdXdlV3Algorithm + │ + └─ Based on: FwdAlgorithmV3 + SpecifiesXdl +``` + +### Backward XDL Algorithms + +``` +BwdAlgorithm + SpecifiesXdl +│ +├─── BwdXdlAlgorithmBase (ThreadClusterRank=4) +│ │ +│ ├─── BwdXdlAlgorithm +│ │ │ +│ │ └─ Requirements: Base + SpecifiesTransposeTransfer + SpecifiesGenericInstance +│ │ +│ └─── BwdMultiDXdlAlgorithm +│ │ +│ └─ Requirements: Base + SpecifiesMultipleDSupport +│ +└─── BwdXdlV3AlgorithmBase + │ + ├─── BwdXdlV3Algorithm + │ │ + │ └─ Requirements: Base + │ + └─── BwdTwoStageXdlAlgorithm + │ + └─ Requirements: Base + SpecifiesTransposeTransfer + SpecifiesTwoStageSupport +``` + +**Valid XDL Algorithms:** +- FwdXdlAlgorithm +- FwdXdlV3Algorithm +- LargeTensorAlgorithm +- BwdXdlAlgorithm +- BwdXdlV3Algorithm +- BwdTwoStageXdlAlgorithm +- BwdMultiDXdlAlgorithm + +--- + +## WMMA Algorithm Hierarchy + +### Forward WMMA Algorithms + +``` +FwdAlgorithm + SpecifiesWmma +│ +└─── 
FwdWmmaAlgorithm + │ + └─ Requirements: Base + SpecifiesWmma +``` + +### Backward WMMA Algorithms + +``` +BwdAlgorithm + SpecifiesWmma +│ +├─── BwdWmmaAlgorithmBase (ThreadClusterRank=3) +│ │ +│ └─── BwdWmmaAlgorithm +│ │ +│ └─ Requirements: Base + SpecifiesGemmPipeline + SpecifiesGenericInstance +│ +└─── BwdWmmaV3AlgorithmBase (Based on BwdAlgorithmV3) + │ + ├─── BwdMultiDWmmaV3Algorithm + │ │ + │ └─ Requirements: Base + SpecifiesMultipleDSupport + │ + ├─── BwdWmmaV3Algorithm + │ │ + │ └─ Requirements: Base + SpecifiesTransposeTransfer + │ + └─── BwdTwoStageWmmaV3Algorithm + │ + └─ Requirements: Base + SpecifiesTransposeTransfer + SpecifiesTwoStageSupport +``` + +**Valid WMMA Algorithms:** +- FwdWmmaAlgorithm +- BwdWmmaAlgorithm +- BwdWmmaV3Algorithm +- BwdTwoStageWmmaV3Algorithm +- BwdMultiDWmmaV3Algorithm + +--- + +## DL Algorithm Hierarchy + +DL algorithms have a separate base and do not share the common hierarchy with XDL/WMMA algorithms. + +``` +DlAlgorithm +│ +│ Requirements: +│ • ConvAlgorithmDescriptor +│ • SpecifiesThreadBlock +│ • SpecifiesDlThreadConfig +│ • SpecifiesDlThreadCluster +│ • SpecifiesDlEpilogue +│ +├─── FwdDlAlgorithmBase +│ │ +│ │ Requirements: Base + SpecifiesFwdConvSpecialization + SpecifiesDlFwdBlockTransfer + SpecifiesGemmSpecialization +│ │ +│ └─── FwdDlAlgorithm +│ +└─── BwdDlAlgorithm + │ + └─ Requirements: Base + SpecifiesBwdWeightConvSpecialization + SpecifiesDlBwdBlockTransfer +``` + +**Valid DL Algorithms:** +- FwdDlAlgorithm +- BwdDlAlgorithm + +--- + + +## Reference Algorithms + +``` +ReferenceAlgorithm +│ +└─ Requirements: ConvAlgorithmDescriptor + + SpecifiesReferenceAlgorithm +``` + +Used for reference implementations and testing. 
+ +## CK Tile Algorithms + +``` +TileAlgorithm +│ +└─ Requirements: ConvAlgorithmDescriptor + + SpecifiesTileThreadBlock + + SpecifiesTileTransfer + + SpecifiesTileConvSpecialization + + SpecifiesTileBlockGemm + + SpecifiesTileOptimizations +``` + +The CK Tile algorithms are applicable to forward convolution as well as backward convolution (weight and data). + +--- + +## Summary for XDL/WMMA/DL algorithms + +| Category | Algorithm Type | Forward Variants | Backward Variants | +|----------|---------------|------------------|-------------------| +| **XDL** | Base | FwdXdlAlgorithmBase | BwdXdlAlgorithmBase, BwdXdlV3AlgorithmBase | +| | Concrete | • FwdXdlAlgorithm
• FwdXdlV3Algorithm
• LargeTensorAlgorithm | • BwdXdlAlgorithm
• BwdXdlV3Algorithm
• BwdTwoStageXdlAlgorithm
• BwdMultiDXdlAlgorithm | +| **WMMA** | Base | FwdAlgorithm | BwdWmmaAlgorithmBase, BwdWmmaV3AlgorithmBase | +| | Concrete | • FwdWmmaAlgorithm | • BwdWmmaAlgorithm
• BwdWmmaV3Algorithm
• BwdTwoStageWmmaV3Algorithm
• BwdMultiDWmmaV3Algorithm | +| **DL** | Base | FwdDlAlgorithmBase | DlAlgorithm | +| | Concrete | • FwdDlAlgorithm | • BwdDlAlgorithm | + +--- diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index fc0ee48ec0b..2e65dccdd51 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -7,122 +7,149 @@ namespace ck_tile::builder::factory { -// Base algorithm concepts template -concept TileTransferParameters = - SpecifiesBlockTransfer && SpecifiesLdsTransfer && +concept SpecifiesTileTransferParameters = + SpecifiesThreadClusters && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder; -template -concept SpecifiesTileTransferParameters3D = TileTransferParameters; +// Base algorithm concepts +template +concept ConvWarpGemmAlgorithm = ConvAlgorithmDescriptor && SpecifiesThreadBlock && + SpecifiesTileTransferParameters && + SpecifiesWarpGemm; template -concept SpecifiesTileTransferParameters4D = TileTransferParameters; +concept FwdAlgorithm = ConvWarpGemmAlgorithm && SpecifiesFwdConvSpecialization; template -concept FwdXdlAlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && - SpecifiesGridwiseFwdXdlGemm && SpecifiesFwdConvSpecialization && - SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && - SpecifiesNumGroupsToMerge && SpecifiesLoopScheduler; +concept FwdAlgorithmV3 = FwdAlgorithm && SpecifiesPipelineV3 && SpecifiesGemmPipeline; -template -concept BwdXdlAlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters4D && - SpecifiesGridwiseBwdXdlGemm && SpecifiesBwdWeightConvSpecialization; +template +concept BwdAlgorithm = + ConvWarpGemmAlgorithm && SpecifiesBwdWeightConvSpecialization; template -concept 
BwdXdlV3AlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && - SpecifiesGridwiseBwdXdlGemm && SpecifiesBwdWeightConvSpecialization && - SpecifiesBlockGemm; +concept BwdAlgorithmV3 = BwdAlgorithm && SpecifiesPipelineV3 && SpecifiesGemmPipeline; template -concept BwdWmmaAlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && - SpecifiesGridwiseWmmaGemm && SpecifiesBwdWeightConvSpecialization; +concept DlAlgorithm = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesDlThreadConfig && + SpecifiesDlThreadCluster && SpecifiesDlEpilogue; template -concept BwdWmmaV3AlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && - SpecifiesGridwiseWmmaGemm && SpecifiesBwdWeightConvSpecialization && - SpecifiesBlockGemm; +concept FwdDlAlgorithmBase = DlAlgorithm && SpecifiesFwdConvSpecialization && + SpecifiesDlFwdBlockTransfer && SpecifiesGemmSpecialization; + +template +concept FwdXdlAlgorithmBase = FwdAlgorithm && SpecifiesXdl; + +template +concept BwdXdlAlgorithmBase = BwdAlgorithm && SpecifiesXdl; + +template +concept BwdXdlV3AlgorithmBase = BwdAlgorithmV3 && SpecifiesXdl; + +template +concept BwdWmmaAlgorithmBase = BwdAlgorithm && SpecifiesWmma; + +template +concept BwdWmmaV3AlgorithmBase = BwdAlgorithmV3 && SpecifiesWmma; // Reference algorithm concept -template -concept ReferenceAlgorithm = ConvAlgorithmDescriptor && SpecifiesReferenceAlgorithm; +template +concept ReferenceAlgorithm = + ConvAlgorithmDescriptor && SpecifiesReferenceAlgorithm; // Tile-based algorithm concept -template -concept TileAlgorithm = ConvAlgorithmDescriptor && SpecifiesTileThreadBlock && - SpecifiesTileTransfer && SpecifiesTileConvSpecialization && - SpecifiesTileBlockGemm && SpecifiesTileOptimizations; +template +concept TileAlgorithm = + ConvAlgorithmDescriptor && SpecifiesTileThreadBlock && + SpecifiesTileTransfer && 
SpecifiesTileConvSpecialization && + SpecifiesTileBlockGemm && SpecifiesTileOptimizations; // FWD XDL algorithm concepts -template -concept FwdXdlAlgorithm = FwdXdlAlgorithmBase && SpecifiesGenericInstance; +template +concept FwdXdlAlgorithm = FwdXdlAlgorithmBase && SpecifiesGenericInstance; -template -concept LargeTensorAlgorithm = FwdXdlAlgorithmBase && SpecifiesLargeTensorSupport; +template +concept LargeTensorAlgorithm = + FwdXdlAlgorithmBase && SpecifiesLargeTensorSupport; -template -concept FwdXdlV3Algorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && - SpecifiesGridwiseFwdXdlGemm && SpecifiesFwdConvSpecialization && - SpecifiesGemmSpecialization && SpecifiesBlockGemm; +template +concept FwdXdlV3Algorithm = FwdAlgorithmV3 && SpecifiesXdl; // FWD WMMA algorithm concepts -template -concept FwdWmmaAlgorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && - SpecifiesGridwiseWmmaGemm && SpecifiesFwdConvSpecialization && - SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && SpecifiesLoopScheduler && - SpecifiesGridwiseGemmPipeline; +template +concept FwdWmmaAlgorithm = FwdAlgorithm && SpecifiesWmma; // FWD DL algorithms -template -concept FwdDlAlgorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesFwdConvSpecialization && - SpecifiesGemmSpecialization && SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && - SpecifiesDlFwdBlockTransfer && SpecifiesDlEpilogue; +template +concept FwdDlAlgorithm = FwdDlAlgorithmBase; // BWD weight XDL algorithm concepts -template +template concept BwdXdlAlgorithm = - BwdXdlAlgorithmBase && SpecifiesTransposeTransfer && SpecifiesGenericInstance; + BwdXdlAlgorithmBase && SpecifiesTransposeTransfer && + SpecifiesGenericInstance; -template -concept BwdMultiDXdlAlgorithm = BwdXdlAlgorithmBase && SpecifiesMultipleDSupport; +template +concept BwdMultiDXdlAlgorithm = + BwdXdlAlgorithmBase && SpecifiesMultipleDSupport; 
-template -concept BwdXdlV3Algorithm = BwdXdlV3AlgorithmBase && SpecifiesGenericInstance; +template +concept BwdXdlV3Algorithm = BwdXdlV3AlgorithmBase; -template -concept BwdTwoStageXdlAlgorithm = BwdXdlV3AlgorithmBase && SpecifiesTransposeTransfer && - SpecifiesGemmBatchOptions && SpecifiesTwoStageSupport; +template +concept BwdTwoStageXdlAlgorithm = + BwdXdlV3AlgorithmBase && SpecifiesTransposeTransfer && + SpecifiesTwoStageSupport; // BWD weight WMMA algorithm concepts -template -concept BwdWmmaAlgorithm = - BwdWmmaAlgorithmBase && SpecifiesNumPrefetchStages && SpecifiesLoopScheduler && - SpecifiesGridwiseGemmPipeline && SpecifiesGenericInstance; +template +concept BwdWmmaAlgorithm = BwdWmmaAlgorithmBase && SpecifiesGemmPipeline && + SpecifiesGenericInstance; -template -concept BwdMultiDWmmaV3Algorithm = BwdWmmaV3AlgorithmBase && SpecifiesMultipleDSupport; +template +concept BwdMultiDWmmaV3Algorithm = + BwdWmmaV3AlgorithmBase && SpecifiesMultipleDSupport; -template +template concept BwdWmmaV3Algorithm = - BwdWmmaV3AlgorithmBase && SpecifiesTransposeTransfer && SpecifiesGenericInstance; + BwdWmmaV3AlgorithmBase && SpecifiesTransposeTransfer; -template -concept BwdTwoStageWmmaV3Algorithm = BwdWmmaV3AlgorithmBase && SpecifiesTransposeTransfer && - SpecifiesGemmBatchOptions && SpecifiesTwoStageSupport; +template +concept BwdTwoStageWmmaV3Algorithm = + BwdWmmaV3AlgorithmBase && SpecifiesTransposeTransfer && + SpecifiesTwoStageSupport; -// BWD weigth DL algorithms -template +// BWD weight DL algorithms +template concept BwdDlAlgorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && - SpecifiesBwdWeightConvSpecialization && SpecifiesDlThreadConfig && - SpecifiesDlThreadCluster && SpecifiesDlBwdBlockTransfer && SpecifiesDlEpilogue; + DlAlgorithm && SpecifiesBwdWeightConvSpecialization && + SpecifiesDlBwdBlockTransfer; + +// Concepts for valid XDL/WMMA algorithms +template +concept SpecifiesValidFwdXdlAlgorithm = + FwdXdlAlgorithm || FwdXdlV3Algorithm || 
LargeTensorAlgorithm; + +template +concept SpecifiesValidFwdWmmaAlgorithm = FwdWmmaAlgorithm; + +template +concept SpecifiesValidBwdXdlAlgorithm = + BwdXdlAlgorithm || BwdXdlV3Algorithm || BwdTwoStageXdlAlgorithm || + BwdMultiDXdlAlgorithm; + +template +concept SpecifiesValidBwdWmmaAlgorithm = + BwdWmmaAlgorithm || BwdWmmaV3Algorithm || BwdTwoStageWmmaV3Algorithm || + BwdMultiDWmmaV3Algorithm; + +template +concept FwdWarpGemmOrDL = SpecifiesValidWarpGemm || FwdDlAlgorithm; + +template +concept BwdWarpGemmOrDL = SpecifiesValidWarpGemm || BwdDlAlgorithm; } // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp index b02dea95589..9f86028c337 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp @@ -34,8 +34,8 @@ struct ConvBwdWeightMultiDWmmaV3Factory static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); - static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto WARP_GEMM = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -57,6 +57,11 @@ struct ConvBwdWeightMultiDWmmaV3Factory static_assert(AccessOrderLimits3D, "Invalid B source access order"); + static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size == + B_BLOCK_TRANSFER.global_memory_vector_load_size, + "A nd B block transfer vector load size need to be the same"); + static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size; + // The forward 
convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< @@ -78,11 +83,11 @@ struct ConvBwdWeightMultiDWmmaV3Factory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.k1, - GRIDWISE_GEMM.m_per_wmma, - GRIDWISE_GEMM.n_per_wmma, - GRIDWISE_GEMM.m_wmma_per_wave, - GRIDWISE_GEMM.n_wmma_per_wave, + GMEM_VECTOR_LOAD_SIZE, + WARP_GEMM.gemm_m_per_instruction, + WARP_GEMM.gemm_n_per_instruction, + WARP_GEMM.gemm_m_iters_per_wave, + WARP_GEMM.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp index 4f6812617aa..ec67a3c6b09 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp @@ -34,9 +34,8 @@ struct ConvBwdWeightMultiDXdlFactory static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); - static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; - static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -53,6 +52,11 @@ struct ConvBwdWeightMultiDXdlFactory static_assert(AccessOrderLimits4D); static_assert(AccessOrderLimits4D); + static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size == + B_BLOCK_TRANSFER.global_memory_vector_load_size, + "A nd B block transfer vector load size need to be the same"); + static constexpr size_t GMEM_VECTOR_LOAD_SIZE = 
A_BLOCK_TRANSFER.global_memory_vector_load_size; + // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< SPATIAL_DIM, @@ -73,11 +77,11 @@ struct ConvBwdWeightMultiDXdlFactory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.k1, - XDL_PARAMS.m_per_xdl, - XDL_PARAMS.n_per_xdl, - XDL_PARAMS.m_xdl_per_wave, - XDL_PARAMS.n_xdl_per_wave, + GMEM_VECTOR_LOAD_SIZE, + XDL_PARAMS.gemm_m_per_instruction, + XDL_PARAMS.gemm_n_per_instruction, + XDL_PARAMS.gemm_m_iters_per_wave, + XDL_PARAMS.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp index adf108bac48..afdc5e1a17f 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp @@ -34,8 +34,8 @@ struct ConvBwdWeightTwoStageWmmaV3Factory static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); - static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto WARP_GEMM = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -57,6 +57,11 @@ struct ConvBwdWeightTwoStageWmmaV3Factory static_assert(AccessOrderLimits3D, "Invalid B source access order"); + static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size == + B_BLOCK_TRANSFER.global_memory_vector_load_size, + "A nd B block transfer vector load size need to be the same"); + static constexpr size_t 
GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size; + // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< @@ -76,11 +81,11 @@ struct ConvBwdWeightTwoStageWmmaV3Factory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.k1, - GRIDWISE_GEMM.m_per_wmma, - GRIDWISE_GEMM.n_per_wmma, - GRIDWISE_GEMM.m_wmma_per_wave, - GRIDWISE_GEMM.n_wmma_per_wave, + GMEM_VECTOR_LOAD_SIZE, + WARP_GEMM.gemm_m_per_instruction, + WARP_GEMM.gemm_n_per_instruction, + WARP_GEMM.gemm_m_iters_per_wave, + WARP_GEMM.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, @@ -101,7 +106,7 @@ struct ConvBwdWeightTwoStageWmmaV3Factory C_BLOCK_TRANSFER.scalar_per_vector, BLOCK_GEMM.scheduler, BLOCK_GEMM.pipeline_version, - ALGORITHM.num_conv_groups_to_merge, + ALGORITHM.gemm_pipeline.num_gemm_k_prefetch_stages, typename Types::OutComputeType, typename Types::InComputeType, ALGORITHM.max_transpose_transfer_src_scalar_per_vector, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp index d887c1c1ced..cac6de591a8 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp @@ -34,9 +34,8 @@ struct ConvBwdWeightTwoStageXdlFactory static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); - static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; - static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; static constexpr auto 
A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -58,6 +57,11 @@ struct ConvBwdWeightTwoStageXdlFactory static_assert(AccessOrderLimits3D, "Invalid B source access order"); + static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size == + B_BLOCK_TRANSFER.global_memory_vector_load_size, + "A nd B block transfer vector load size need to be the same"); + static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size; + // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< SPATIAL_DIM, @@ -76,11 +80,11 @@ struct ConvBwdWeightTwoStageXdlFactory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.k1, - XDL_PARAMS.m_per_xdl, - XDL_PARAMS.n_per_xdl, - XDL_PARAMS.m_xdl_per_wave, - XDL_PARAMS.n_xdl_per_wave, + GMEM_VECTOR_LOAD_SIZE, + XDL_PARAMS.gemm_m_per_instruction, + XDL_PARAMS.gemm_n_per_instruction, + XDL_PARAMS.gemm_m_iters_per_wave, + XDL_PARAMS.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, @@ -101,7 +105,7 @@ struct ConvBwdWeightTwoStageXdlFactory C_BLOCK_TRANSFER.scalar_per_vector, BLOCK_GEMM.scheduler, BLOCK_GEMM.pipeline_version, - ALGORITHM.num_conv_groups_to_merge, + BLOCK_GEMM.num_conv_groups_to_merge, typename Types::OutComputeType, typename Types::InComputeType, ALGORITHM.max_transpose_transfer_src_scalar_per_vector, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp index 4067845291f..0fd089b6c08 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp @@ -34,8 +34,8 @@ struct ConvBwdWeightWmmaFactory static constexpr auto BWD_CONV_SPECIALIZATION = 
internal::SetBwdWeightConvSpecialization(); - static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto WARP_GEMM = ALGORITHM.warp_gemm; static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION = internal::SetGridwiseGemmPipelineVersion(); static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); @@ -60,6 +60,11 @@ struct ConvBwdWeightWmmaFactory static_assert(AccessOrderLimits3D, "Invalid B source access order"); + static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size == + B_BLOCK_TRANSFER.global_memory_vector_load_size, + "A nd B block transfer vector load size need to be the same"); + static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size; + // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle< SPATIAL_DIM, @@ -78,11 +83,11 @@ struct ConvBwdWeightWmmaFactory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.k1, - GRIDWISE_GEMM.m_per_wmma, - GRIDWISE_GEMM.n_per_wmma, - GRIDWISE_GEMM.m_wmma_per_wave, - GRIDWISE_GEMM.n_wmma_per_wave, + GMEM_VECTOR_LOAD_SIZE, + WARP_GEMM.gemm_m_per_instruction, + WARP_GEMM.gemm_n_per_instruction, + WARP_GEMM.gemm_m_iters_per_wave, + WARP_GEMM.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, @@ -101,7 +106,7 @@ struct ConvBwdWeightWmmaFactory C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, to_sequence_v, C_BLOCK_TRANSFER.scalar_per_vector, - ALGORITHM.num_gemm_k_prefetch_stages, + ALGORITHM.gemm_pipeline.num_gemm_k_prefetch_stages, LOOP_SCHEDULER, GRIDWISE_GEMM_PIPELINE_VERSION>; }; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp index 027c8a1fba6..cb3c87905d8 
100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp @@ -34,8 +34,8 @@ struct ConvBwdWeightWmmaV3Factory static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); - static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto WARP_GEMM = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -57,6 +57,11 @@ struct ConvBwdWeightWmmaV3Factory static_assert(AccessOrderLimits3D, "Invalid B source access order"); + static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size == + B_BLOCK_TRANSFER.global_memory_vector_load_size, + "A nd B block transfer vector load size need to be the same"); + static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size; + // The forward convolution kernel class instance. 
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< SPATIAL_DIM, @@ -75,11 +80,11 @@ struct ConvBwdWeightWmmaV3Factory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.k1, - GRIDWISE_GEMM.m_per_wmma, - GRIDWISE_GEMM.n_per_wmma, - GRIDWISE_GEMM.m_wmma_per_wave, - GRIDWISE_GEMM.n_wmma_per_wave, + GMEM_VECTOR_LOAD_SIZE, + WARP_GEMM.gemm_m_per_instruction, + WARP_GEMM.gemm_n_per_instruction, + WARP_GEMM.gemm_m_iters_per_wave, + WARP_GEMM.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp index fbb177f3337..b382b7d3d9e 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp @@ -34,9 +34,8 @@ struct ConvBwdWeightXdlFactory static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); - static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; - static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -53,6 +52,11 @@ struct ConvBwdWeightXdlFactory static_assert(AccessOrderLimits4D); static_assert(AccessOrderLimits4D); + static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size == + B_BLOCK_TRANSFER.global_memory_vector_load_size, + "A nd B block transfer vector load size need to be the same"); + static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size; + // The forward convolution kernel class 
instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle< SPATIAL_DIM, @@ -71,11 +75,11 @@ struct ConvBwdWeightXdlFactory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.k1, - XDL_PARAMS.m_per_xdl, - XDL_PARAMS.n_per_xdl, - XDL_PARAMS.m_xdl_per_wave, - XDL_PARAMS.n_xdl_per_wave, + GMEM_VECTOR_LOAD_SIZE, + XDL_PARAMS.gemm_m_per_instruction, + XDL_PARAMS.gemm_n_per_instruction, + XDL_PARAMS.gemm_m_iters_per_wave, + XDL_PARAMS.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp index 66a47c54078..9c9bbd21af9 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp @@ -34,9 +34,8 @@ struct ConvBwdWeightXdlV3Factory static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); - static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; - static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -58,6 +57,11 @@ struct ConvBwdWeightXdlV3Factory static_assert(AccessOrderLimits3D, "Invalid B source access order"); + static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size == + B_BLOCK_TRANSFER.global_memory_vector_load_size, + "A nd B block transfer vector load size need to be the same"); + static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size; + // The forward convolution kernel class 
instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< SPATIAL_DIM, @@ -76,11 +80,11 @@ struct ConvBwdWeightXdlV3Factory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.k1, - XDL_PARAMS.m_per_xdl, - XDL_PARAMS.n_per_xdl, - XDL_PARAMS.m_xdl_per_wave, - XDL_PARAMS.n_xdl_per_wave, + GMEM_VECTOR_LOAD_SIZE, + XDL_PARAMS.gemm_m_per_instruction, + XDL_PARAMS.gemm_n_per_instruction, + XDL_PARAMS.gemm_m_iters_per_wave, + XDL_PARAMS.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index e235db4bb09..cb51cb70dc7 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -104,48 +104,63 @@ template constexpr auto make_conv_instance() { - using AlgoType = std::remove_const_t; - // Reference algorithm supports all directions - if constexpr(ReferenceAlgorithm) + if constexpr(ReferenceAlgorithm) { return typename ReferenceFactory::Instance{}; } // CK Tile supports common factory for each direction - else if constexpr(TileAlgorithm) + else if constexpr(TileAlgorithm) { return typename ConvTileFactory::Instance{}; } // Forward direction (supports most algorithm variants) else if constexpr(ConvDirectionIsForward) { - if constexpr(FwdXdlV3Algorithm) - { - return typename ConvFwdXdlV3Factory::Instance{}; - } - else if constexpr(FwdXdlAlgorithm) + if constexpr(SpecifiesXdl) { - return typename ConvFwdXdlFactory::Instance{}; + if constexpr(FwdXdlV3Algorithm) + { + return typename ConvFwdXdlV3Factory::Instance{}; + } + else if constexpr(FwdXdlAlgorithm) + { + return typename ConvFwdXdlFactory::Instance{}; + } + else if constexpr(LargeTensorAlgorithm) + { + return + typename ConvFwdLargeTensorFactory::Instance{}; + } + else + { + 
static_assert( + SpecifiesValidFwdXdlAlgorithm, + "No suitable forward convolution XDL kernel factory found for the provided " + "ALGORITHM. " + "The ALGORITHM must satisfy requirements for one of: XDL V3, generic XDL, " + "DL (NHWC layout), or Large Tensor variant."); + } } - else if constexpr(FwdWmmaAlgorithm) + else if constexpr(SpecifiesWmma) { - return typename ConvFwdWmmaFactory::Instance{}; + if constexpr(FwdWmmaAlgorithm) + { + return typename ConvFwdWmmaFactory::Instance{}; + } + else + { + static_assert(FwdWmmaAlgorithm, "Did not find matching WMMA factory."); + } } - else if constexpr(FwdDlAlgorithm) + else if constexpr(FwdDlAlgorithm) { return typename ConvFwdDlFactory::Instance{}; } - else if constexpr(LargeTensorAlgorithm) - { - return typename ConvFwdLargeTensorFactory::Instance{}; - } else { - static_assert( - false, - "No suitable forward convolution kernel factory found for the provided ALGORITHM. " - "The ALGORITHM must satisfy requirements for one of: Reference, Tile, XDL V3, XDL, " - "WMMA, DL (NHWC layout), or Large Tensor variant."); + static_assert(FwdWarpGemmOrDL, + "Forward convolution: Algorithm must specify either DL, XDL or WMMA."); } } // Backward data direction (will expand with more algorithms in the future) @@ -159,54 +174,78 @@ constexpr auto make_conv_instance() // Backward weight direction (will expand with more algorithms in the future) else if constexpr(ConvDirectionIsBackwardWeight) { - if constexpr(BwdXdlAlgorithm) + if constexpr(SpecifiesXdl) { - return typename ConvBwdWeightXdlFactory::Instance{}; + // Start from more specialized and end with least specialized. 
+ if constexpr(BwdTwoStageXdlAlgorithm) + { + return typename ConvBwdWeightTwoStageXdlFactory:: + Instance{}; + } + else if constexpr(BwdMultiDXdlAlgorithm) + { + return typename ConvBwdWeightMultiDXdlFactory:: + Instance{}; + } + else if constexpr(BwdXdlV3Algorithm) + { + return + typename ConvBwdWeightXdlV3Factory::Instance{}; + } + else if constexpr(BwdXdlAlgorithm) + { + return typename ConvBwdWeightXdlFactory::Instance{}; + } + else + { + static_assert(SpecifiesValidBwdXdlAlgorithm, + "No suitable backward weight convolution XDL kernel factory found " + "for the provided ALGORITHM. " + "The ALGORITHM must satisfy requirements for one of: Two-Stage XDL, " + "Multi-D XDL, DL, " + "generic XDL, or XDL V3 variant."); + } } - else if constexpr(BwdXdlV3Algorithm) + else if constexpr(SpecifiesWmma) { - return typename ConvBwdWeightXdlV3Factory::Instance{}; + // Start from more specialized and end with least specialized. + if constexpr(BwdTwoStageWmmaV3Algorithm) + { + return typename ConvBwdWeightTwoStageWmmaV3Factory:: + Instance{}; + } + else if constexpr(BwdMultiDWmmaV3Algorithm) + { + return typename ConvBwdWeightMultiDWmmaV3Factory:: + Instance{}; + } + else if constexpr(BwdWmmaV3Algorithm) + { + return + typename ConvBwdWeightWmmaV3Factory::Instance{}; + } + else if constexpr(BwdWmmaAlgorithm) + { + return typename ConvBwdWeightWmmaFactory::Instance{}; + } + else + { + static_assert(SpecifiesValidBwdWmmaAlgorithm, + "No suitable backward weight convolution WMMA kernel factory found " + "for the provided ALGORITHM. 
" + "The ALGORITHM must satisfy requirements for one of: Two-Stage WMMA " + "V3, Multi-D WMMA V3, " + "WMMA V3, or generic WMMA variant."); + } } - else if constexpr(BwdTwoStageXdlAlgorithm) - { - return - typename ConvBwdWeightTwoStageXdlFactory::Instance{}; - } - else if constexpr(BwdDlAlgorithm) + else if constexpr(BwdDlAlgorithm) { return typename ConvBwdWeightDlFactory::Instance{}; } - else if constexpr(BwdMultiDXdlAlgorithm) - { - return - typename ConvBwdWeightMultiDXdlFactory::Instance{}; - } - else if constexpr(BwdWmmaV3Algorithm) - { - return typename ConvBwdWeightWmmaV3Factory::Instance{}; - } - else if constexpr(BwdTwoStageWmmaV3Algorithm) - { - return typename ConvBwdWeightTwoStageWmmaV3Factory:: - Instance{}; - } - else if constexpr(BwdWmmaAlgorithm) - { - return typename ConvBwdWeightWmmaFactory::Instance{}; - } - else if constexpr(BwdMultiDWmmaV3Algorithm) - { - return typename ConvBwdWeightMultiDWmmaV3Factory:: - Instance{}; - } else { - static_assert( - false, - "No suitable backward weight convolution kernel factory found for the provided " - "ALGORITHM. 
The ALGORITHM must satisfy requirements for one of: Reference, Tile, " - "XDL, XDL V3, Two-Stage XDL, DL, Multi-D XDL, WMMA V3, Two-Stage " - "WMMA V3, WMMA, or Multi-D WMMA V3 variant."); + static_assert(BwdWarpGemmOrDL, + "Backward convolution: Algorithm must specify either DL, XDL or WMMA."); } } else diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp index 0ff410d7311..965b8f92f48 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp @@ -38,8 +38,7 @@ struct ConvFwdLargeTensorFactory static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; - static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -74,17 +73,17 @@ struct ConvFwdLargeTensorFactory typename Ops::OutElementwiseOp, SPECIALIZATION.conv_spec, SPECIALIZATION.gemm_spec, - ALGORITHM.num_gemm_k_prefetch_stages, + ALGORITHM.gemm_pipeline.num_gemm_k_prefetch_stages, BLOCK.block_size, BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.ak1, - GRIDWISE_GEMM.bk1, - XDL_PARAMS.m_per_xdl, - XDL_PARAMS.n_per_xdl, - XDL_PARAMS.m_xdl_per_wave, - XDL_PARAMS.n_xdl_per_wave, + A_BLOCK_TRANSFER.global_memory_vector_load_size, + B_BLOCK_TRANSFER.global_memory_vector_load_size, + XDL_PARAMS.gemm_m_per_instruction, + XDL_PARAMS.gemm_n_per_instruction, + XDL_PARAMS.gemm_m_iters_per_wave, + XDL_PARAMS.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git 
a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp index dd2fa65eaee..ad6f91b0a4d 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp @@ -41,9 +41,8 @@ struct ConvFwdXdlV3Factory static constexpr internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION, .gemm_spec = GEMM_SPECIALIZATION}; - static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; - static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -83,12 +82,12 @@ struct ConvFwdXdlV3Factory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.ak1, - GRIDWISE_GEMM.bk1, - XDL_PARAMS.m_per_xdl, - XDL_PARAMS.n_per_xdl, - XDL_PARAMS.m_xdl_per_wave, - XDL_PARAMS.n_xdl_per_wave, + A_BLOCK_TRANSFER.global_memory_vector_load_size, + B_BLOCK_TRANSFER.global_memory_vector_load_size, + XDL_PARAMS.gemm_m_per_instruction, + XDL_PARAMS.gemm_n_per_instruction, + XDL_PARAMS.gemm_m_iters_per_wave, + XDL_PARAMS.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp index 2d6f7c394b9..6b277951779 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp @@ -38,7 +38,7 @@ struct ConvFwdWmmaFactory static constexpr auto LOOP_SCHEDULER = 
internal::SetLoopScheduler(); static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto WARP_GEMM = ALGORITHM.warp_gemm; static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION = internal::SetGridwiseGemmPipelineVersion(); static constexpr auto A_BLOCK_TRANSFER = @@ -57,6 +57,11 @@ struct ConvFwdWmmaFactory static_assert(AccessOrderLimits3D); static_assert(AccessOrderLimits3D); + static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size == + B_BLOCK_TRANSFER.global_memory_vector_load_size, + "A nd B block transfer vector load size need to be the same"); + static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size; + // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< SPATIAL_DIM, @@ -75,16 +80,16 @@ struct ConvFwdWmmaFactory typename Ops::OutElementwiseOp, SPECIALIZATION.conv_spec, SPECIALIZATION.gemm_spec, - ALGORITHM.num_gemm_k_prefetch_stages, + ALGORITHM.gemm_pipeline.num_gemm_k_prefetch_stages, BLOCK.block_size, BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.k1, - GRIDWISE_GEMM.m_per_wmma, - GRIDWISE_GEMM.n_per_wmma, - GRIDWISE_GEMM.m_wmma_per_wave, - GRIDWISE_GEMM.n_wmma_per_wave, + GMEM_VECTOR_LOAD_SIZE, + WARP_GEMM.gemm_m_per_instruction, + WARP_GEMM.gemm_n_per_instruction, + WARP_GEMM.gemm_m_iters_per_wave, + WARP_GEMM.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp index e03e0359699..fa0df7fee8a 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp @@ -38,8 +38,7 @@ struct ConvFwdXdlFactory static constexpr 
auto LOOP_SCHEDULER = internal::SetLoopScheduler(); static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; - static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -74,17 +73,17 @@ struct ConvFwdXdlFactory typename Ops::OutElementwiseOp, SPECIALIZATION.conv_spec, SPECIALIZATION.gemm_spec, - ALGORITHM.num_gemm_k_prefetch_stages, + ALGORITHM.gemm_pipeline.num_gemm_k_prefetch_stages, BLOCK.block_size, BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.ak1, - GRIDWISE_GEMM.bk1, - XDL_PARAMS.m_per_xdl, - XDL_PARAMS.n_per_xdl, - XDL_PARAMS.m_xdl_per_wave, - XDL_PARAMS.n_xdl_per_wave, + A_BLOCK_TRANSFER.global_memory_vector_load_size, + B_BLOCK_TRANSFER.global_memory_vector_load_size, + XDL_PARAMS.gemm_m_per_instruction, + XDL_PARAMS.gemm_n_per_instruction, + XDL_PARAMS.gemm_m_iters_per_wave, + XDL_PARAMS.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, @@ -106,7 +105,7 @@ struct ConvFwdXdlFactory typename Types::InComputeType, typename Types::WeiComputeType, LOOP_SCHEDULER, - ALGORITHM.num_conv_groups_to_merge>; + ALGORITHM.gemm_pipeline.num_conv_groups_to_merge>; }; } // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp index d873a4b9033..6259e239a9c 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp @@ -16,38 +16,40 @@ struct BlockTransfer ck::Array thread_cluster_dims{}; ck::Array thread_cluster_order{}; ck::Array src_access_order{}; - size_t src_vector_dim = 
0; - size_t src_scalar_per_vector = 0; - size_t lds_dst_scalar_per_vector = 0; - bool is_direct_load = false; - bool lds_padding = false; + size_t global_memory_vector_load_size = 0; + size_t src_vector_dim = 0; + size_t src_scalar_per_vector = 0; + size_t lds_dst_scalar_per_vector = 0; + bool is_direct_load = false; + bool lds_padding = false; }; template constexpr BlockTransfer<> SetFwdConvBlockTransfer() { - auto& block_xfer = TRANSFER.block_transfer; - auto& block_order = TRANSFER.block_transfer_access_order; + auto& block_xfer = TRANSFER.thread_cluster; + auto& block_order = TRANSFER.thread_cluster_access_order; auto& src_order = TRANSFER.src_access_order; auto& lds_cfg = TRANSFER.lds_transfer; return BlockTransfer<>{ - .thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1}, - .thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2]}, - .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]}, - .src_vector_dim = lds_cfg.src_vector_dim, - .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, - .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, - .is_direct_load = lds_cfg.is_direct_load, - .lds_padding = lds_cfg.lds_padding, + .thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1}, + .thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2]}, + .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]}, + .global_memory_vector_load_size = lds_cfg.global_memory_vector_load_size, + .src_vector_dim = lds_cfg.src_vector_dim, + .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, + .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, + .is_direct_load = lds_cfg.is_direct_load, + .lds_padding = lds_cfg.lds_padding, }; } template constexpr auto SetBwdConvBlockTransfer() { - auto& block_xfer = TRANSFER.block_transfer; - auto& block_order = TRANSFER.block_transfer_access_order; + auto& block_xfer = 
TRANSFER.thread_cluster; + auto& block_order = TRANSFER.thread_cluster_access_order; auto& src_order = TRANSFER.src_access_order; auto& lds_cfg = TRANSFER.lds_transfer; @@ -58,36 +60,38 @@ constexpr auto SetBwdConvBlockTransfer() if constexpr(array_length == 3) { return BlockTransfer<3>{ - .thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1}, - .thread_cluster_order = {block_order.order[0], - block_order.order[1], - block_order.order[2]}, - .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]}, - .src_vector_dim = lds_cfg.src_vector_dim, - .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, - .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, - .lds_padding = lds_cfg.lds_padding, + .thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1}, + .thread_cluster_order = {block_order.order[0], + block_order.order[1], + block_order.order[2]}, + .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]}, + .global_memory_vector_load_size = lds_cfg.global_memory_vector_load_size, + .src_vector_dim = lds_cfg.src_vector_dim, + .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, + .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, + .lds_padding = lds_cfg.lds_padding, }; } else if constexpr(array_length == 4) { return BlockTransfer<4>{ - .thread_cluster_dims = {block_xfer.k_batch_size, - block_xfer.k0, - block_xfer.m_n, - block_xfer.k1}, - .thread_cluster_order = {block_order.order[0], - block_order.order[1], - block_order.order[2], - block_order.order[3]}, - .src_access_order = {src_order.order[0], - src_order.order[1], - src_order.order[2], - src_order.order[3]}, - .src_vector_dim = lds_cfg.src_vector_dim, - .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, - .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, - .lds_padding = lds_cfg.lds_padding, + .thread_cluster_dims = {block_xfer.k_batch_size, + block_xfer.k0, + block_xfer.m_n, + 
block_xfer.k1}, + .thread_cluster_order = {block_order.order[0], + block_order.order[1], + block_order.order[2], + block_order.order[3]}, + .src_access_order = {src_order.order[0], + src_order.order[1], + src_order.order[2], + src_order.order[3]}, + .global_memory_vector_load_size = lds_cfg.global_memory_vector_load_size, + .src_vector_dim = lds_cfg.src_vector_dim, + .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, + .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, + .lds_padding = lds_cfg.lds_padding, }; } else @@ -108,17 +112,17 @@ struct CBlockTransfer template constexpr CBlockTransfer SetCBlockTransfer() { - auto& thread_cluster_dims = ALGORITHM.transfer.c.thread_cluster_dims; + auto& thread_cluster_dims = ALGORITHM.transfer.c.thread_cluster; auto& epilogue_config = ALGORITHM.transfer.c.epilogue; return CBlockTransfer{ .m_xdl_per_wave_per_shuffle = epilogue_config.m_xdl_per_wave_per_shuffle, .n_xdl_per_wave_per_shuffle = epilogue_config.n_per_wave_per_shuffle, .thread_cluster_dims = { - thread_cluster_dims.m_block, - thread_cluster_dims.m_wave_per_xdl, - thread_cluster_dims.n_block, - thread_cluster_dims.n_wave_per_xdl, + thread_cluster_dims.gemm_m_block_size, + thread_cluster_dims.gemm_m_per_block, + thread_cluster_dims.gemm_n_block_size, + thread_cluster_dims.gemm_n_per_block, }, .scalar_per_vector = epilogue_config.scalar_per_vector, }; diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp index 9ed1eebc3c0..f123edbfa95 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp @@ -31,6 +31,8 @@ ConvSpec(CONV_ENUM, GEMM_ENUM) -> ConvSpec; struct BlockGemmSpec { + size_t num_conv_groups_to_merge{1}; + size_t num_gemm_k_prefetch_stages{1}; ck::BlockGemmPipelineVersion 
pipeline_version; ck::BlockGemmPipelineScheduler scheduler; }; @@ -38,7 +40,7 @@ struct BlockGemmSpec template consteval BlockGemmSpec SetBlockGemm() { - constexpr auto& BG = ALGORITHM.block_gemm_pipeline; + constexpr auto& BG = ALGORITHM.gemm_pipeline; ck::BlockGemmPipelineScheduler scheduler; ck::BlockGemmPipelineVersion version; @@ -63,13 +65,16 @@ consteval BlockGemmSpec SetBlockGemm() default: throw "Unknown PipelineVersion"; } - return BlockGemmSpec{.pipeline_version = version, .scheduler = scheduler}; + return BlockGemmSpec{.num_conv_groups_to_merge = BG.num_conv_groups_to_merge, + .num_gemm_k_prefetch_stages = BG.num_gemm_k_prefetch_stages, + .pipeline_version = version, + .scheduler = scheduler}; } template consteval ck::LoopScheduler SetLoopScheduler() { - constexpr auto loop_scheduler = ALGORITHM.loop_scheduler; + constexpr auto loop_scheduler = ALGORITHM.gemm_pipeline.scheduler; using ck_loop_sched = ck::LoopScheduler; switch(loop_scheduler) { @@ -83,7 +88,7 @@ consteval ck::LoopScheduler SetLoopScheduler() template consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion() { - constexpr auto pipeline_version = ALGORITHM.pipeline_version; + constexpr auto pipeline_version = ALGORITHM.gemm_pipeline.pipeline_version; using ck_pipeline = ck::PipelineVersion; switch(pipeline_version) { diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index c4cca05e524..4d09ad8cb4e 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -232,13 +232,42 @@ enum class PipelineScheduler enum class ConvAlgorithmSpecialization { - LARGE_TENSOR, - REFERENCE, // GPU reference implementation for validation, - TWO_STAGE, - MULTIPLE_D + NONE = 0, + LARGE_TENSOR = 1 << 0, + REFERENCE = 1 << 1, // GPU reference implementation for validation, + TWO_STAGE = 1 << 2, + MULTIPLE_D = 1 << 3, + PIPELINE_V3 = 1 << 4 }; -// to_string methods 
for enum classes +constexpr ConvAlgorithmSpecialization operator|(ConvAlgorithmSpecialization lhs, + ConvAlgorithmSpecialization rhs) +{ + using T = std::underlying_type_t; + return static_cast(static_cast(lhs) | static_cast(rhs)); +} + +constexpr ConvAlgorithmSpecialization operator&(ConvAlgorithmSpecialization lhs, + ConvAlgorithmSpecialization rhs) +{ + using T = std::underlying_type_t; + return static_cast(static_cast(lhs) & static_cast(rhs)); +} + +// Enable direct boolean conversion for flag checks +constexpr bool operator!(ConvAlgorithmSpecialization spec) +{ + using T = std::underlying_type_t; + return static_cast(spec) == 0; +} + +enum class MatrixInstructionType +{ + XDL, + WMMA +}; + +// toString methods for enum classes inline std::string_view to_string(DataType dt) { using enum DataType; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle_v3.cpp index 404d1dbacdb..6f9f086fb75 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle_v3.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle_v3.cpp @@ -24,7 +24,7 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeightMultiple .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) - .with_block_gemm(cku::BlockGemmDesc_v1_intrawave); + .with_gemm_pipeline(cku::BlockGemmDesc_v1_intrawave); using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp index 782f33f8450..9ff27f6f997 100644 --- 
a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp @@ -19,12 +19,12 @@ constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 2, .weight = {.config = {.layout = GKYXC}}, .output = {.config = {.layout = NGKHW}}}; -constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3{} +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_CShuffle_V3{} .with_thread_block(cku::ThreadBlock_64_32x32x32) .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) - .with_block_gemm(cku::BlockGemmDesc_v1_intrawave) + .with_gemm_pipeline(cku::BlockGemmDesc_v1_intrawave) .with_num_conv_groups_to_merge(2) .with_transpose_params(2, 2); diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp index a2a877dbcd4..57b85fa8c78 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp @@ -19,12 +19,12 @@ constexpr auto SIGNATURE = .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; -constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle{} +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_CShuffle_V3{} .with_thread_block(cku::ThreadBlock_64_32x32x32) .with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave) .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) - .with_block_gemm(cku::BlockGemmDesc_v2_intrawave) + 
.with_gemm_pipeline(cku::BlockGemmDesc_v2_intrawave) .with_num_conv_groups_to_merge(2) .with_transpose_params(2, 4); diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp index ff350ac8049..136c8fea40d 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp @@ -19,13 +19,13 @@ constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 3, .weight = {.config = {.layout = GKZYXC}}, .output = {.config = {.layout = NGKDHW}}}; -constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle{} - .with_thread_block(cku::ThreadBlock_64_32x32x32) - .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) - .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) - .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) - .with_prefetch_config(1, ckb::PipelineScheduler::DEFAULT) - .with_gridwise_gemm_pipeline(ckb::PipelineVersion::V1); +constexpr auto ALGORITHM = + cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle{} + .with_thread_block(cku::ThreadBlock_64_32x32x32) + .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) + .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) + .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) + .with_gemm_pipeline(ckb::PipelineVersion::V1, ckb::PipelineScheduler::DEFAULT); using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp index 60f7d5bd643..b91d8de810a 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp @@ -25,7 +25,7 @@ constexpr auto ALGORITHM = 
.with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) .with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0) - .with_block_gemm(cku::BlockGemmDesc_v1_intrawave) + .with_gemm_pipeline(cku::BlockGemmDesc_v1_intrawave) .with_transpose_params(4, 4); using Builder = ckb::ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp index 4ad97209e5e..ccdad77e392 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp @@ -25,7 +25,7 @@ constexpr auto ALGORITHM = .with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave) .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) .with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0) - .with_block_gemm(cku::BlockGemmDesc_v2_intrawave); + .with_gemm_pipeline(cku::BlockGemmDesc_v2_intrawave); using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp index 8d85370b268..ef271cd9879 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp @@ -29,13 +29,13 @@ TEST(FwdConvInstances, .output = {.config = {.layout = NGKW}, .operation = {.elementwise_operation = SCALE}}}; constexpr auto FwdConvAlgorithm = - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle_V3{} .with_thread_block(ThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(Transfer_4x64x1) .with_fwd_specializations(ConvSpecialization::FILTER_1X1_STRIDE1_PAD0, GemmSpecialization::MNKPadding) - 
.with_block_gemm(BlockGemmDesc_v2_intrawave); + .with_gemm_pipeline(BlockGemmDesc_v2_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp index d3ace110c4b..4a30766bdd5 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp @@ -26,12 +26,11 @@ TEST(FwdConvInstances, .output = {.config = {.layout = NWGK}}}; constexpr auto FwdConvAlgorithm = - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle{} .with_thread_block(ThreadBlock_64_64x32x32) .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) .with_transfer(Transfer_4x16x1) .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, PipelineScheduler::DEFAULT) .with_num_conv_groups_to_merge(2); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp index 06d200429c4..ad36a78c1fd 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp @@ -28,14 +28,13 @@ TEST(FwdConvInstances, .output = {.config = {.layout = GNWK}}}; constexpr auto FwdConvAlgorithm = - ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle{} + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle{} .with_thread_block(ThreadBlock_128_64x64x64) - .with_gemm_config(GemmParams_Wmma_2x1_per_wave) - .with_transfer(Transfer_4x32x1) + .with_gemm_config(GemmParams_Wmma_16x16_2x2_per_wave) + .with_transfer(Transfer_4x32x1_vector_load_16_generic) .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, PipelineScheduler::DEFAULT) - .with_num_conv_groups_to_merge(2) - 
.with_gridwise_gemm_pipeline(PipelineVersion::V1); + .with_num_gemm_k_prefetch_stages(1) + .with_gemm_pipeline(PipelineVersion::V1, PipelineScheduler::DEFAULT); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp index 610e2fad5fe..9ed5758b4e8 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp @@ -26,12 +26,12 @@ TEST(FwdConvInstances, .output = {.config = {.layout = NHWGK}}}; constexpr auto FwdConvAlgorithm = - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle_V3{} .with_thread_block(ThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(Transfer_4x64x1) .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_block_gemm(BlockGemmDesc_v1_intrawave); + .with_gemm_pipeline(BlockGemmDesc_v1_intrawave); using Builder = ConvBuilder; @@ -62,14 +62,13 @@ TEST(FwdConvInstances, .weight = {.config = {.layout = GKYXC}}, .output = {.config = {.layout = NHWGK}}}; - constexpr auto FwdConvAlgorithm = - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} - .with_thread_block(ThreadBlock_256_256x256x32) - .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_transfer(Transfer_4x64x1) - .with_fwd_specializations(ConvSpecialization::FILTER_3x3, - GemmSpecialization::MNKPadding) - .with_block_gemm(BlockGemmDesc_v5_intrawave); + constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle_V3{} + .with_thread_block(ThreadBlock_256_256x256x32) + .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) + .with_transfer(Transfer_4x64x1) + .with_fwd_specializations(ConvSpecialization::FILTER_3x3, + GemmSpecialization::MNKPadding) + .with_gemm_pipeline(BlockGemmDesc_v5_intrawave); using Builder = 
ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp index 23edef54369..be2fdd689af 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp @@ -31,12 +31,11 @@ TEST(FwdConvInstances, .with_auxiliary_operand_configs()}}; constexpr auto FwdConvAlgorithm = - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle{} .with_thread_block(ThreadBlock_64_64x32x32) .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) .with_transfer(Transfer_4x16x1) .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, PipelineScheduler::DEFAULT) .with_num_conv_groups_to_merge(1); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp index 3e5e39191ee..1ea09918aac 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp @@ -24,13 +24,13 @@ constexpr auto SIGNATURE = .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; -constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle_V3{} .with_thread_block(cku::ThreadBlock_256_256x256x32) .with_gemm_config(cku::FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(cku::Transfer_4x64x1) .with_fwd_specializations(ckb::ConvSpecialization::DEFAULT, ckb::GemmSpecialization::MNKPadding) - .with_block_gemm(cku::BlockGemmDesc_v3_intrawave); + .with_gemm_pipeline(cku::BlockGemmDesc_v3_intrawave); using 
Builder = ckb::ConvBuilder; using Instance = Builder::Instance; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp index bb35c53ba06..c26ff9a9cc3 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp @@ -25,13 +25,13 @@ TEST(FwdConvInstances, .output = {.config = {.layout = NGKHW}}}; constexpr auto FwdConvAlgorithm = - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle_V3{} .with_thread_block(ThreadBlock_256_128x128x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(Transfer_4x64x1) .with_fwd_specializations(ConvSpecialization::FILTER_1X1_STRIDE1_PAD0, GemmSpecialization::MNKPadding) - .with_block_gemm(BlockGemmDesc_v4_intrawave); + .with_gemm_pipeline(BlockGemmDesc_v4_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp index b117e693fe3..a83c9c76551 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp @@ -26,12 +26,11 @@ TEST(FwdConvInstances, .output = {.config = {.layout = NHWGK}}}; constexpr auto FwdConvAlgorithm = - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle{} .with_thread_block(ThreadBlock_256_256x128x32) .with_gemm_config(FwdGemmParams_Xdl_4x2_per_wave) .with_transfer(Transfer_4x64x1_fp8) .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, PipelineScheduler::DEFAULT) .with_num_conv_groups_to_merge(1); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp 
b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp index 97bc0a00e5d..a4dd6171d88 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp @@ -30,7 +30,6 @@ TEST(FwdConvInstances, .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) .with_transfer(Transfer_4x16x1) .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, PipelineScheduler::DEFAULT) .with_num_conv_groups_to_merge(1); using Builder = ConvBuilder; @@ -67,7 +66,6 @@ TEST( .with_transfer(Transfer_4x16x1) .with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, PipelineScheduler::DEFAULT) .with_num_conv_groups_to_merge(1); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp index 9e6ca00e581..34f3e286468 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp @@ -26,12 +26,12 @@ TEST(FwdConvInstances, .output = {.config = {.layout = GNDHWK}}}; constexpr auto FwdConvAlgorithm = - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle_V3{} .with_thread_block(ThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(Transfer_4x64x1) .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_block_gemm(BlockGemmDesc_v3_intrawave); + .with_gemm_pipeline(BlockGemmDesc_v3_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp index 56d4b8be590..07c399a9795 100644 --- 
a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp @@ -26,13 +26,13 @@ TEST(FwdConvInstances, .output = {.config = {.layout = NDHWGK}}}; constexpr auto FwdConvAlgorithm = - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle_V3{} .with_thread_block(ThreadBlock_256_128x128x32) .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) .with_transfer(Transfer_4x64x1) .with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0, GemmSpecialization::MNKPadding) - .with_block_gemm(BlockGemmDesc_v4_intrawave); + .with_gemm_pipeline(BlockGemmDesc_v4_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp index df8339241bc..d33ac55db3c 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp @@ -26,13 +26,13 @@ TEST(FwdConvInstances, .output = {.config = {.layout = NGKDHW}}}; constexpr auto FwdConvAlgorithm = - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle_V3{} .with_thread_block(ThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(Transfer_4x64x1) .with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0, GemmSpecialization::MNKPadding) - .with_block_gemm(BlockGemmDesc_v1_intrawave); + .with_gemm_pipeline(BlockGemmDesc_v1_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 617686fda14..6989887a264 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -28,53 +28,28 @@ struct ThreadBlock }; 
static_assert(ckb::ThreadBlockDescriptor); -struct XdlParams +struct WarpGemmParams { - size_t m_per_xdl = 0; - size_t n_per_xdl = 0; - size_t m_xdl_per_wave = 0; - size_t n_xdl_per_wave = 0; + MatrixInstructionType matrix_instruction; + size_t gemm_m_per_instruction = 0; + size_t gemm_n_per_instruction = 0; + size_t gemm_m_iters_per_wave = 0; + size_t gemm_n_iters_per_wave = 0; }; -static_assert(ckb::GridwiseXdlGemmDescriptor); +static_assert(ckb::WarpGemmDescriptor); -// Describe gridwise XDL GEMM parameters. -struct GridwiseFwdXdlGemm +struct GemmPipeline { - // NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!! - size_t ak1 = 0; - size_t bk1 = 0; - XdlParams xdl_params; -}; -static_assert(ckb::GridwiseFwdXdlGemmDescriptor); - -struct GridwiseBwdXdlGemm -{ - size_t k1 = 0; - XdlParams xdl_params; -}; -static_assert(ckb::GridwiseBwdXdlGemmDescriptor); - -// Describe gridwise WMMA GEMM parameters. -struct GridwiseWmmaGemm -{ - size_t k1 = 0; - size_t m_per_wmma = 0; - size_t n_per_wmma = 0; - size_t m_wmma_per_wave = 0; - size_t n_wmma_per_wave = 0; -}; -static_assert(ckb::GridwiseWmmaGemmDescriptor); - -struct BlockGemmPipeline -{ - PipelineVersion pipeline_version; - PipelineScheduler scheduler; + size_t num_gemm_k_prefetch_stages{1}; + size_t num_conv_groups_to_merge{1}; + PipelineVersion pipeline_version{PipelineVersion::V1}; + PipelineScheduler scheduler{PipelineScheduler::DEFAULT}; }; -static_assert(ckb::BlockGemmPipelineDescriptor); +static_assert(ckb::GemmPipelineDescriptor); -// Describe Aand B block transfer thread cluster lengths. -template -struct BlockTransfer +// Describe input tensor thread cluster lengths. 
+template +struct InputThreadCluster { size_t k0; size_t m_n; @@ -82,29 +57,30 @@ struct BlockTransfer size_t k_batch_size; }; -// Specialization for ThreadSliceLength == 3 +// Specialization for ThreadClusterRank == 3 template <> -struct BlockTransfer<3> +struct InputThreadCluster<3> { size_t k0; size_t m_n; size_t k1; }; -static_assert(ckb::BlockTransferDescriptor, 3>); -static_assert(ckb::BlockTransferDescriptor, 4>); +static_assert(ckb::InputTileThreadClusterDescriptor, 3>); +static_assert(ckb::InputTileThreadClusterDescriptor, 4>); // Describe C block transfer thread cluster lengths. -struct ThreadCluster +struct OutputThreadCluster { - size_t m_block; - size_t m_wave_per_xdl; - size_t n_block; - size_t n_wave_per_xdl; + size_t gemm_m_block_size; + size_t gemm_m_per_block; + size_t gemm_n_block_size; + size_t gemm_n_per_block; }; -static_assert(ThreadClusterDescriptor); +static_assert(OutputTileThreadClusterDescriptor); struct LdsTransfer { + size_t global_memory_vector_load_size; size_t src_vector_dim; size_t src_scalar_per_vector; size_t lds_dst_scalar_per_vector; @@ -121,35 +97,35 @@ struct Epilogue }; static_assert(EpilogueDescriptor); -template +template struct AccessOrder { - std::array order; + std::array order; }; static_assert(AccessOrderDescriptor>); static_assert(AccessOrderDescriptor>); -template -struct InputTransfer +template +struct InputTileTransfer { - BlockTransfer block_transfer; + InputThreadCluster thread_cluster; LdsTransfer lds_transfer; - AccessOrder block_transfer_access_order; - AccessOrder src_access_order; + AccessOrder thread_cluster_access_order; + AccessOrder src_access_order; }; -struct OutputTransfer +struct OutputTileTransfer { - ThreadCluster thread_cluster_dims; + OutputThreadCluster thread_cluster; Epilogue epilogue; }; -template -struct Transfer +template +struct InputOutputTileTransfer { - InputTransfer a; - InputTransfer b; - OutputTransfer c; + InputTileTransfer a; + InputTileTransfer b; + OutputTileTransfer c; }; // 
DL-specific descriptors @@ -199,25 +175,15 @@ struct ThreadBlock_ ThreadBlock thread_block; }; -struct FwdXdlGemm_ -{ - GridwiseFwdXdlGemm gridwise_gemm; -}; - -struct BwdXdlGemm_ +struct WarpGemm_ { - GridwiseBwdXdlGemm gridwise_gemm; + WarpGemmParams warp_gemm; }; -struct WmmaGemm_ +template +struct InputOutputTileTransfer_ { - GridwiseWmmaGemm gridwise_gemm; -}; - -template -struct Transfer_ -{ - Transfer transfer; + InputOutputTileTransfer transfer; }; struct ConvSpecializationFwd_ @@ -231,31 +197,15 @@ struct ConvSpecializationBwdWeight_ ConvSpecialization bwd_weight_specialization; }; -struct Prefetch_ -{ - size_t num_gemm_k_prefetch_stages; - PipelineScheduler loop_scheduler; -}; - struct TransposeParams_ { size_t max_transpose_transfer_src_scalar_per_vector{1}; size_t max_transpose_transfer_dst_scalar_per_vector{1}; }; -struct GemmBatchOptions_ -{ - size_t num_conv_groups_to_merge{1}; -}; - -struct BlockGemm_ +struct GemmPipeline_ { - BlockGemmPipeline block_gemm_pipeline; -}; - -struct GridGemm_ -{ - PipelineVersion pipeline_version; + GemmPipeline gemm_pipeline; }; struct DlThreadConfig_ @@ -282,22 +232,10 @@ struct DlTransfer_ DlTransfer transfer; }; -struct TwoStageSpecialization_ +template +struct AlgorithmSpecialization_ { - static constexpr ConvAlgorithmSpecialization specialization = - ConvAlgorithmSpecialization::TWO_STAGE; -}; - -struct MultipleDSpecialization_ -{ - static constexpr ConvAlgorithmSpecialization specialization = - ConvAlgorithmSpecialization::MULTIPLE_D; -}; - -struct LargeTensorSpecialization_ -{ - static constexpr ConvAlgorithmSpecialization specialization = - ConvAlgorithmSpecialization::LARGE_TENSOR; + static constexpr ConvAlgorithmSpecialization specialization = Specialization; }; // Specify thread block dimensions for a GEMM (CK Tile). @@ -386,30 +324,16 @@ struct ConvAlgorithmTemplate : Components... 
constexpr auto with_gemm_config(const GemmConfig& gemm) const { auto result = *this; - if constexpr(std::is_base_of_v) - { - result.gridwise_gemm = gemm; - } - else if constexpr(std::is_base_of_v) - { - result.gridwise_gemm = gemm; - } - else if constexpr(std::is_base_of_v) - { - result.gridwise_gemm = gemm; - } - else - { - static_assert(false, "Unrecognized GemmConfig type"); - } + static_assert(std::is_base_of_v); + result.warp_gemm = gemm; return result; } template constexpr auto with_transfer(const T& t) const { - static_assert(std::is_base_of_v, ConvAlgorithmTemplate> || - std::is_base_of_v, ConvAlgorithmTemplate>); + static_assert(std::is_base_of_v, ConvAlgorithmTemplate> || + std::is_base_of_v, ConvAlgorithmTemplate>); auto result = *this; result.transfer = t; return result; @@ -433,15 +357,6 @@ struct ConvAlgorithmTemplate : Components... return result; } - constexpr auto with_prefetch_config(size_t k_prefetch_stages, PipelineScheduler scheduler) const - { - static_assert(std::is_base_of_v); - auto result = *this; - result.num_gemm_k_prefetch_stages = k_prefetch_stages; - result.loop_scheduler = scheduler; - return result; - } - constexpr auto with_transpose_params(size_t max_src_scalar_per_vector, size_t max_dst_scalar_per_vector) const { @@ -454,26 +369,51 @@ struct ConvAlgorithmTemplate : Components... 
constexpr auto with_num_conv_groups_to_merge(size_t num_groups_to_merge) const { - static_assert(std::is_base_of_v); - auto result = *this; - result.num_conv_groups_to_merge = num_groups_to_merge; + static_assert(std::is_base_of_v); + auto result = *this; + result.gemm_pipeline.num_conv_groups_to_merge = num_groups_to_merge; return result; } - template - constexpr auto with_block_gemm(const BG& bg) const + constexpr auto with_num_gemm_k_prefetch_stages(size_t num_prefetch_stages) const { - static_assert(std::is_base_of_v); - auto result = *this; - result.block_gemm_pipeline = bg; + static_assert(std::is_base_of_v); + auto result = *this; + result.gemm_pipeline.num_gemm_k_prefetch_stages = num_prefetch_stages; + return result; + } + + template + constexpr auto with_gemm_pipeline(const PL& pl) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.gemm_pipeline = pl; return result; } - constexpr auto with_gridwise_gemm_pipeline(const PipelineVersion plv) const + constexpr auto with_gemm_pipeline(const PipelineVersion plv) const { - static_assert(std::is_base_of_v); - auto result = *this; - result.pipeline_version = plv; + static_assert(std::is_base_of_v); + auto result = *this; + result.gemm_pipeline.pipeline_version = plv; + return result; + } + + constexpr auto with_gemm_pipeline(const PipelineScheduler sch) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.gemm_pipeline.scheduler = sch; + return result; + } + + constexpr auto with_gemm_pipeline(const PipelineVersion plv, const PipelineScheduler sch) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.gemm_pipeline.pipeline_version = plv; + result.gemm_pipeline.scheduler = sch; return result; } @@ -553,29 +493,24 @@ struct ConvAlgorithmTemplate : Components... 
// Fwd algorithm types -using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = - ConvAlgorithmTemplate, - ConvSpecializationFwd_, - Prefetch_, - GemmBatchOptions_>; +using enum ckb::ConvAlgorithmSpecialization; -using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = +// Covers both XDL and WMMA variants for generic fwd convolution +using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle = ConvAlgorithmTemplate, + WarpGemm_, + InputOutputTileTransfer_<>, ConvSpecializationFwd_, - BlockGemm_>; + GemmPipeline_, + AlgorithmSpecialization_<>>; -using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle = +using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle_V3 = ConvAlgorithmTemplate, + WarpGemm_, + InputOutputTileTransfer_<>, ConvSpecializationFwd_, - GridGemm_, - Prefetch_, - GemmBatchOptions_>; + GemmPipeline_, + AlgorithmSpecialization_>; using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = ConvAlgorithmTemplate, + WarpGemm_, + InputOutputTileTransfer_<>, ConvSpecializationFwd_, - Prefetch_, - GemmBatchOptions_, - LargeTensorSpecialization_>; + GemmPipeline_, + AlgorithmSpecialization_>; // CK Tile algorithm using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate; // Reference algorithm descriptor - for GPU reference validation -// This is a simple algorithm that requires no complex configuration, -// just a specialization marker to identify it as a reference implementation. 
-struct ConvAlgorithm_Reference -{ - static constexpr auto specialization = ckb::ConvAlgorithmSpecialization::REFERENCE; - // GPU reference uses simple algorithm, no tile configuration needed -}; +using ConvAlgorithm_Reference = ConvAlgorithmTemplate>; // Bwd weight algorithm types using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle = ConvAlgorithmTemplate, + WarpGemm_, + InputOutputTileTransfer_<4>, + ConvSpecializationBwdWeight_, + TransposeParams_, + AlgorithmSpecialization_<>>; + +using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle = + ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, - TransposeParams_>; + GemmPipeline_, + AlgorithmSpecialization_<>>; -using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle = +// Covers both XDL and WMMA variants +using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_CShuffle_V3 = ConvAlgorithmTemplate, + WarpGemm_, + InputOutputTileTransfer_<>, ConvSpecializationBwdWeight_, - BlockGemm_, + GemmPipeline_, TransposeParams_, - GemmBatchOptions_, - TwoStageSpecialization_>; + AlgorithmSpecialization_>; using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 = ConvAlgorithmTemplate, + WarpGemm_, + InputOutputTileTransfer_<>, ConvSpecializationBwdWeight_, - BlockGemm_>; + GemmPipeline_, + AlgorithmSpecialization_>; + +using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 = + ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_, + GemmPipeline_, + TransposeParams_, + AlgorithmSpecialization_>; using ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl = ConvAlgorithmTemplate, - ConvSpecializationBwdWeight_, - MultipleDSpecialization_>; - -using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 = - ConvAlgorithmTemplate, + WarpGemm_, + InputOutputTileTransfer_<4>, ConvSpecializationBwdWeight_, - BlockGemm_, - TransposeParams_>; + AlgorithmSpecialization_>; using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 = ConvAlgorithmTemplate, + WarpGemm_, + 
InputOutputTileTransfer_<>, ConvSpecializationBwdWeight_, - BlockGemm_, + GemmPipeline_, TransposeParams_, - GemmBatchOptions_, - TwoStageSpecialization_>; - -using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle = - ConvAlgorithmTemplate, - ConvSpecializationBwdWeight_, - GridGemm_, - Prefetch_>; + AlgorithmSpecialization_>; using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3 = ConvAlgorithmTemplate, + WarpGemm_, + InputOutputTileTransfer_<>, ConvSpecializationBwdWeight_, - BlockGemm_, - MultipleDSpecialization_>; + GemmPipeline_, + AlgorithmSpecialization_>; } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index 9e8008ccf02..c8cfbe6b196 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -117,53 +117,60 @@ static_assert(!ckb::ConvSignatureDescriptor transfer{ + ckb::test::InputOutputTileTransfer<> transfer{ .a = { - .block_transfer = {.k0 = 1, .m_n = 128, .k1 = 2}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 2, - .lds_dst_scalar_per_vector = 2, - .is_direct_load = false, - .lds_padding = false}, - .block_transfer_access_order = {.order = {0, 1, 2}}, + .thread_cluster = {.k0 = 1, .m_n = 128, .k1 = 2}, + .lds_transfer = {.global_memory_vector_load_size = 8, + .src_vector_dim = 2, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 2, + .is_direct_load = false, + .lds_padding = false}, + .thread_cluster_access_order = {.order = {0, 1, 2}}, .src_access_order = {.order = {0, 1, 2}}, }, .b = { - .block_transfer = {.k0 = 1, .m_n = 128, .k1 = 2}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 2, - .lds_dst_scalar_per_vector = 2, - .is_direct_load = false, - .lds_padding = false}, - .block_transfer_access_order = {.order = {0, 1, 2}}, + .thread_cluster = {.k0 = 1, .m_n = 128, .k1 = 2}, + .lds_transfer = 
{.global_memory_vector_load_size = 8, + .src_vector_dim = 2, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 2, + .is_direct_load = false, + .lds_padding = false}, + .thread_cluster_access_order = {.order = {0, 1, 2}}, .src_access_order = {.order = {0, 1, 2}}, }, .c = { - .thread_cluster_dims = - {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, - .epilogue = {.m_xdl_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 2}, + .thread_cluster = {.gemm_m_block_size = 1, + .gemm_m_per_block = 32, + .gemm_n_block_size = 1, + .gemm_n_per_block = 8}, + .epilogue = {.m_xdl_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 2}, }, }; ckb::ConvSpecialization fwd_specialization = ckb::ConvSpecialization::DEFAULT; ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default; - ckb::test::BlockGemmPipeline block_gemm_pipeline{.pipeline_version = ckb::PipelineVersion::V4, - .scheduler = - ckb::PipelineScheduler::INTRAWAVE}; + ckb::test::GemmPipeline gemm_pipeline{.pipeline_version = ckb::PipelineVersion::V4, + .scheduler = ckb::PipelineScheduler::INTRAWAVE}; }; static_assert(ckb::ConvAlgorithmDescriptor); diff --git a/experimental/builder/test/unit_conv_tuning_params.cpp b/experimental/builder/test/unit_conv_tuning_params.cpp index 90057429309..26217dc22e1 100644 --- a/experimental/builder/test/unit_conv_tuning_params.cpp +++ b/experimental/builder/test/unit_conv_tuning_params.cpp @@ -17,9 +17,11 @@ TEST(ConvTuningParams, AssignsBlockGemmParams) { struct BlockGemm { + size_t num_conv_groups_to_merge = 1; + size_t num_gemm_k_prefetch_stages = 1; ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V3; ckb::PipelineScheduler scheduler = ckb::PipelineScheduler::INTRAWAVE; - } block_gemm_pipeline; + } gemm_pipeline; } kAlgorithm; constexpr auto block_gemm = SetBlockGemm(); @@ -31,7 +33,10 @@ TEST(ConvTuningParams, AssignsLoopSchedulerParam) { constexpr struct 
Algorithm { - ckb::PipelineScheduler loop_scheduler = ckb::PipelineScheduler::INTERWAVE; + struct GemmPipeline + { + ckb::PipelineScheduler scheduler = ckb::PipelineScheduler::INTERWAVE; + } gemm_pipeline; } kAlgorithm; constexpr auto loop_scheduler = SetLoopScheduler(); @@ -42,7 +47,10 @@ TEST(ConvTuningParams, AssignsGridwiseGemmPipelineVersion) { constexpr struct Algorithm { - ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V4; + struct GemmPipeline + { + ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V4; + } gemm_pipeline; } kAlgorithm; constexpr auto pipeline_version = SetGridwiseGemmPipelineVersion(); diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index 3b83ead2d0d..4209d708cdf 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -12,6 +12,7 @@ namespace ck_tile::builder::test_utils { using namespace ck_tile::builder; using namespace test; +// Test configs for DL algorithms constexpr DlThreadConfig DlThreadConfig_16x2x4x4x1{ .k0_per_block = 16, .k1 = 2, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1}; @@ -50,238 +51,322 @@ constexpr DlTransfer<5> DlTransfer5D{.a = DlBlockTransfer_1x8x1x1x1, .src_dst_vector_dim = 5, .dst_scalar_per_vector = 1}}; -constexpr Transfer<> Transfer_4x64x1{ +// XLD/WMMA test configs +constexpr InputOutputTileTransfer<> Transfer_4x64x1{ .a = { - .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 2, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = false}, - .block_transfer_access_order = {1, 0, 2}, + .thread_cluster = {.k0 = 4, .m_n = 64, .k1 = 1}, + .lds_transfer = {.global_memory_vector_load_size = 8, + .src_vector_dim = 2, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = 
false}, + .thread_cluster_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .b = { - .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = false}, - .block_transfer_access_order = {1, 0, 2}, + .thread_cluster = {.k0 = 4, .m_n = 64, .k1 = 1}, + .lds_transfer = {.global_memory_vector_load_size = 8, + .src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = false}, + .thread_cluster_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .c = { - .thread_cluster_dims = - {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, - .epilogue = {.m_xdl_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 4}, + .thread_cluster = {.gemm_m_block_size = 1, + .gemm_m_per_block = 32, + .gemm_n_block_size = 1, + .gemm_n_per_block = 8}, + .epilogue = {.m_xdl_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 4}, }, }; -constexpr Transfer<4> BwdTransfer_4x64x1{ +constexpr InputOutputTileTransfer<4> BwdTransfer_4x64x1{ .a = { - .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 2, - .lds_dst_scalar_per_vector = 4, - .is_direct_load = false, - .lds_padding = true}, - .block_transfer_access_order = {0, 3, 1, 2}, + .thread_cluster = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1}, + .lds_transfer = {.global_memory_vector_load_size = 8, + .src_vector_dim = 2, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 4, + .is_direct_load = false, + .lds_padding = true}, + .thread_cluster_access_order = {0, 3, 1, 2}, .src_access_order = {0, 2, 1, 3}, }, .b = { - .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 2, - 
.lds_dst_scalar_per_vector = 4, - .is_direct_load = false, - .lds_padding = true}, - .block_transfer_access_order = {0, 3, 1, 2}, + .thread_cluster = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1}, + .lds_transfer = {.global_memory_vector_load_size = 8, + .src_vector_dim = 2, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 4, + .is_direct_load = false, + .lds_padding = true}, + .thread_cluster_access_order = {0, 3, 1, 2}, .src_access_order = {0, 2, 1, 3}, }, .c = { - .thread_cluster_dims = - {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, - .epilogue = {.m_xdl_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, + .thread_cluster = {.gemm_m_block_size = 1, + .gemm_m_per_block = 32, + .gemm_n_block_size = 1, + .gemm_n_per_block = 8}, + .epilogue = {.m_xdl_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, }, }; -constexpr Transfer<> BwdTransfer_4x8x1_4x16x1_v3{ +constexpr InputOutputTileTransfer<> BwdTransfer_4x8x1_4x16x1_v3{ .a = { - .block_transfer = {.k0 = 4, .m_n = 8, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 1, - .src_scalar_per_vector = 2, - .lds_dst_scalar_per_vector = 2, - .is_direct_load = false, - .lds_padding = false}, - .block_transfer_access_order = {2, 0, 1}, + .thread_cluster = {.k0 = 4, .m_n = 8, .k1 = 1}, + .lds_transfer = {.global_memory_vector_load_size = 8, + .src_vector_dim = 1, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 2, + .is_direct_load = false, + .lds_padding = false}, + .thread_cluster_access_order = {2, 0, 1}, .src_access_order = {1, 0, 2}, }, .b = { - .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 1, - .src_scalar_per_vector = 2, - .lds_dst_scalar_per_vector = 2, - .is_direct_load = false, - .lds_padding = false}, - .block_transfer_access_order = {2, 0, 1}, + .thread_cluster = {.k0 = 4, .m_n = 16, .k1 = 1}, + .lds_transfer = {.global_memory_vector_load_size = 8, + 
.src_vector_dim = 1, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 2, + .is_direct_load = false, + .lds_padding = false}, + .thread_cluster_access_order = {2, 0, 1}, .src_access_order = {1, 0, 2}, }, .c = { - .thread_cluster_dims = - {.m_block = 1, .m_wave_per_xdl = 8, .n_block = 1, .n_wave_per_xdl = 8}, - .epilogue = {.m_xdl_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 2}, + .thread_cluster = {.gemm_m_block_size = 1, + .gemm_m_per_block = 8, + .gemm_n_block_size = 1, + .gemm_n_per_block = 8}, + .epilogue = {.m_xdl_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 2}, }, }; -constexpr Transfer<> Transfer_4x64x1_fp8{ +constexpr InputOutputTileTransfer<> Transfer_4x64x1_fp8{ .a = { - .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, - .block_transfer_access_order = {1, 0, 2}, + .thread_cluster = {.k0 = 4, .m_n = 64, .k1 = 1}, + .lds_transfer = {.global_memory_vector_load_size = 8, + .src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .thread_cluster_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .b = { - .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, - .block_transfer_access_order = {1, 0, 2}, + .thread_cluster = {.k0 = 4, .m_n = 64, .k1 = 1}, + .lds_transfer = {.global_memory_vector_load_size = 8, + .src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .thread_cluster_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .c = { - .thread_cluster_dims = - {.m_block = 1, .m_wave_per_xdl = 32, 
.n_block = 1, .n_wave_per_xdl = 8}, - .epilogue = {.m_xdl_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, + .thread_cluster = {.gemm_m_block_size = 1, + .gemm_m_per_block = 32, + .gemm_n_block_size = 1, + .gemm_n_per_block = 8}, + .epilogue = {.m_xdl_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, }, }; -constexpr Transfer<> Transfer_4x16x1{ +constexpr InputOutputTileTransfer<> Transfer_4x16x1{ .a = { - .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, - .block_transfer_access_order = {1, 0, 2}, + .thread_cluster = {.k0 = 4, .m_n = 16, .k1 = 1}, + .lds_transfer = {.global_memory_vector_load_size = 8, + .src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .thread_cluster_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .b = { - .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, - .block_transfer_access_order = {1, 0, 2}, + .thread_cluster = {.k0 = 4, .m_n = 16, .k1 = 1}, + .lds_transfer = {.global_memory_vector_load_size = 8, + .src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .thread_cluster_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .c = { - .thread_cluster_dims = - {.m_block = 1, .m_wave_per_xdl = 16, .n_block = 1, .n_wave_per_xdl = 4}, - .epilogue = {.m_xdl_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, + .thread_cluster = {.gemm_m_block_size = 1, + .gemm_m_per_block = 16, + .gemm_n_block_size = 1, + .gemm_n_per_block = 4}, + .epilogue = 
{.m_xdl_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, }, }; -constexpr Transfer<> Transfer_4x32x1{ +constexpr InputOutputTileTransfer<> Transfer_4x32x1{ .a = { - .block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 16, - .lds_dst_scalar_per_vector = 16, - .is_direct_load = false, - .lds_padding = true}, - .block_transfer_access_order = {1, 0, 2}, + .thread_cluster = {.k0 = 4, .m_n = 32, .k1 = 1}, + .lds_transfer = {.global_memory_vector_load_size = 8, + .src_vector_dim = 2, + .src_scalar_per_vector = 16, + .lds_dst_scalar_per_vector = 16, + .is_direct_load = false, + .lds_padding = true}, + .thread_cluster_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .b = { - .block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 16, - .lds_dst_scalar_per_vector = 16, - .is_direct_load = false, - .lds_padding = true}, - .block_transfer_access_order = {1, 0, 2}, + .thread_cluster = {.k0 = 4, .m_n = 32, .k1 = 1}, + .lds_transfer = {.global_memory_vector_load_size = 8, + .src_vector_dim = 2, + .src_scalar_per_vector = 16, + .lds_dst_scalar_per_vector = 16, + .is_direct_load = false, + .lds_padding = true}, + .thread_cluster_access_order = {1, 0, 2}, .src_access_order = {1, 0, 2}, }, .c = { - .thread_cluster_dims = - {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 4}, - .epilogue = {.m_xdl_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, + .thread_cluster = {.gemm_m_block_size = 1, + .gemm_m_per_block = 32, + .gemm_n_block_size = 1, + .gemm_n_per_block = 4}, + .epilogue = {.m_xdl_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, }, }; -constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_4x4_per_wave{ - .k1 = 8, - .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}}; - -constexpr 
GridwiseBwdXdlGemm BwdGemmParams_Xdl_1x1_per_wave{ - .k1 = 8, - .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 1, .n_xdl_per_wave = 1}}; - -constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_4x4_per_wave{ - .ak1 = 8, - .bk1 = 8, - .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}}; - -constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_4x2_per_wave{ - .ak1 = 8, - .bk1 = 8, - .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2}}; - -constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_2x2_per_wave{ - .ak1 = 8, - .bk1 = 8, - .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}}; - -constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_2x1_per_wave{ - .ak1 = 8, - .bk1 = 8, - .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1}}; - -constexpr GridwiseWmmaGemm GemmParams_Wmma_2x1_per_wave{ - .k1 = 8, .m_per_wmma = 32, .n_per_wmma = 32, .m_wmma_per_wave = 2, .n_wmma_per_wave = 1}; +constexpr InputOutputTileTransfer<> Transfer_4x32x1_vector_load_16_generic{ + .a = + { + .thread_cluster = {.k0 = 4, .m_n = 32, .k1 = 1}, + .lds_transfer = {.global_memory_vector_load_size = 16, + .src_vector_dim = 2, + .src_scalar_per_vector = 1, + .lds_dst_scalar_per_vector = 16, + .is_direct_load = false, + .lds_padding = true}, + .thread_cluster_access_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, + }, + .b = + { + .thread_cluster = {.k0 = 4, .m_n = 32, .k1 = 1}, + .lds_transfer = {.global_memory_vector_load_size = 16, + .src_vector_dim = 2, + .src_scalar_per_vector = 1, + .lds_dst_scalar_per_vector = 16, + .is_direct_load = false, + .lds_padding = true}, + .thread_cluster_access_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, + }, + .c = + { + .thread_cluster = {.gemm_m_block_size = 1, + .gemm_m_per_block = 32, + .gemm_n_block_size = 1, + .gemm_n_per_block = 4}, + .epilogue = {.m_xdl_per_wave_per_shuffle = 1, + 
.n_per_wave_per_shuffle = 1, + .scalar_per_vector = 1}, + }, +}; -constexpr GridwiseWmmaGemm GemmParams_Wmma_16x16_2x1_per_wave{ - .k1 = 8, .m_per_wmma = 16, .n_per_wmma = 16, .m_wmma_per_wave = 2, .n_wmma_per_wave = 1}; +constexpr WarpGemmParams BwdGemmParams_Xdl_4x4_per_wave{.matrix_instruction = + MatrixInstructionType::XDL, + .gemm_m_per_instruction = 32, + .gemm_n_per_instruction = 32, + .gemm_m_iters_per_wave = 4, + .gemm_n_iters_per_wave = 4}; + +constexpr WarpGemmParams BwdGemmParams_Xdl_1x1_per_wave{.matrix_instruction = + MatrixInstructionType::XDL, + .gemm_m_per_instruction = 32, + .gemm_n_per_instruction = 32, + .gemm_m_iters_per_wave = 1, + .gemm_n_iters_per_wave = 1}; + +constexpr WarpGemmParams FwdGemmParams_Xdl_4x4_per_wave{.matrix_instruction = + MatrixInstructionType::XDL, + .gemm_m_per_instruction = 32, + .gemm_n_per_instruction = 32, + .gemm_m_iters_per_wave = 4, + .gemm_n_iters_per_wave = 4}; + +constexpr WarpGemmParams FwdGemmParams_Xdl_4x2_per_wave{.matrix_instruction = + MatrixInstructionType::XDL, + .gemm_m_per_instruction = 32, + .gemm_n_per_instruction = 32, + .gemm_m_iters_per_wave = 4, + .gemm_n_iters_per_wave = 2}; + +constexpr WarpGemmParams FwdGemmParams_Xdl_2x2_per_wave{.matrix_instruction = + MatrixInstructionType::XDL, + .gemm_m_per_instruction = 32, + .gemm_n_per_instruction = 32, + .gemm_m_iters_per_wave = 2, + .gemm_n_iters_per_wave = 2}; + +constexpr WarpGemmParams FwdGemmParams_Xdl_2x1_per_wave{.matrix_instruction = + MatrixInstructionType::XDL, + .gemm_m_per_instruction = 32, + .gemm_n_per_instruction = 32, + .gemm_m_iters_per_wave = 2, + .gemm_n_iters_per_wave = 1}; + +constexpr WarpGemmParams GemmParams_Wmma_16x16_2x1_per_wave{.matrix_instruction = + MatrixInstructionType::WMMA, + .gemm_m_per_instruction = 16, + .gemm_n_per_instruction = 16, + .gemm_m_iters_per_wave = 2, + .gemm_n_iters_per_wave = 1}; + +constexpr WarpGemmParams GemmParams_Wmma_16x16_2x2_per_wave{.matrix_instruction = + MatrixInstructionType::WMMA, + 
.gemm_m_per_instruction = 16, + .gemm_n_per_instruction = 16, + .gemm_m_iters_per_wave = 2, + .gemm_n_iters_per_wave = 2}; constexpr ThreadBlock ThreadBlock_256_256x256x32{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; @@ -310,19 +395,19 @@ constexpr ThreadBlock ThreadBlock_128_128x128x32{.block_size = 128, constexpr ThreadBlock ThreadBlock_128_64x64x64{.block_size = 128, .tile_size = {.m = 64, .n = 64, .k = 64}}; -constexpr BlockGemmPipeline BlockGemmDesc_v1_intrawave = { - .pipeline_version = PipelineVersion::V1, .scheduler = PipelineScheduler::INTRAWAVE}; +constexpr GemmPipeline BlockGemmDesc_v1_intrawave = {.pipeline_version = PipelineVersion::V1, + .scheduler = PipelineScheduler::INTRAWAVE}; -constexpr BlockGemmPipeline BlockGemmDesc_v2_intrawave = { - .pipeline_version = PipelineVersion::V2, .scheduler = PipelineScheduler::INTRAWAVE}; +constexpr GemmPipeline BlockGemmDesc_v2_intrawave = {.pipeline_version = PipelineVersion::V2, + .scheduler = PipelineScheduler::INTRAWAVE}; -constexpr BlockGemmPipeline BlockGemmDesc_v3_intrawave = { - .pipeline_version = PipelineVersion::V3, .scheduler = PipelineScheduler::INTRAWAVE}; +constexpr GemmPipeline BlockGemmDesc_v3_intrawave = {.pipeline_version = PipelineVersion::V3, + .scheduler = PipelineScheduler::INTRAWAVE}; -constexpr BlockGemmPipeline BlockGemmDesc_v4_intrawave = { - .pipeline_version = PipelineVersion::V4, .scheduler = PipelineScheduler::INTRAWAVE}; +constexpr GemmPipeline BlockGemmDesc_v4_intrawave = {.pipeline_version = PipelineVersion::V4, + .scheduler = PipelineScheduler::INTRAWAVE}; -constexpr BlockGemmPipeline BlockGemmDesc_v5_intrawave = { - .pipeline_version = PipelineVersion::V5, .scheduler = PipelineScheduler::INTRAWAVE}; +constexpr GemmPipeline BlockGemmDesc_v5_intrawave = {.pipeline_version = PipelineVersion::V5, + .scheduler = PipelineScheduler::INTRAWAVE}; } // namespace ck_tile::builder::test_utils diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp 
b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index 23f4cf33648..2f66a633fad 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -86,42 +86,25 @@ inline std::string to_string(ThreadBlock t) } template <> -inline std::string to_string(GridwiseBwdXdlGemm t) +inline std::string to_string(WarpGemmParams t) { std::ostringstream oss; - oss << t.k1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl << "," - << t.xdl_params.m_xdl_per_wave << "," << t.xdl_params.n_xdl_per_wave; + oss << t.gemm_m_per_instruction << "," << t.gemm_n_per_instruction << "," + << t.gemm_m_iters_per_wave << "," << t.gemm_n_iters_per_wave; return oss.str(); } template <> -inline std::string to_string(GridwiseFwdXdlGemm t) +inline std::string to_string(GemmPipeline t) { std::ostringstream oss; - oss << t.ak1 << "," << t.bk1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl - << "," << t.xdl_params.m_xdl_per_wave << "," << t.xdl_params.n_xdl_per_wave; - return oss.str(); -} - -template <> -inline std::string to_string(GridwiseWmmaGemm t) -{ - std::ostringstream oss; - oss << t.k1 << "," << t.m_per_wmma << "," << t.n_per_wmma << "," << t.m_wmma_per_wave << "," - << t.n_wmma_per_wave; - return oss.str(); -} - -template <> -inline std::string to_string(BlockGemmPipeline t) -{ - std::ostringstream oss; - oss << to_string(t.scheduler) << "," << to_string(t.pipeline_version); + oss << t.num_gemm_k_prefetch_stages << "," << t.num_conv_groups_to_merge << "," + << to_string(t.scheduler) << "," << to_string(t.pipeline_version); return oss.str(); } template -inline std::string to_string(BlockTransfer t) +inline std::string to_string(InputThreadCluster t) { if constexpr(ThreadClusterRank == 4) { @@ -139,19 +122,19 @@ inline std::string to_string(BlockTransfer t) } template <> -inline std::string to_string(ThreadCluster t) +inline std::string to_string(OutputThreadCluster 
t) { - return array_to_seq( - std::array{t.m_block, t.m_wave_per_xdl, t.n_block, t.n_wave_per_xdl}); + return array_to_seq(std::array{ + t.gemm_m_block_size, t.gemm_m_per_block, t.gemm_n_block_size, t.gemm_n_per_block}); } template <> inline std::string to_string(LdsTransfer t) { std::ostringstream oss; - oss << t.src_vector_dim << "," << t.src_scalar_per_vector << "," << t.lds_dst_scalar_per_vector - << "," << (t.lds_padding ? "true" : "false") << "," - << (t.is_direct_load ? "true" : "false"); + oss << t.global_memory_vector_load_size << "," << t.src_vector_dim << "," + << t.src_scalar_per_vector << "," << t.lds_dst_scalar_per_vector << "," + << (t.lds_padding ? "true" : "false") << "," << (t.is_direct_load ? "true" : "false"); return oss.str(); } @@ -162,10 +145,10 @@ inline std::string to_string(AccessOrder t) } template -inline std::string to_string(InputTransfer t) +inline std::string to_string(InputTileTransfer t) { std::ostringstream oss; - oss << to_string(t.block_transfer) << "," << to_string(t.block_transfer_access_order) << "," + oss << to_string(t.thread_cluster) << "," << to_string(t.thread_cluster_access_order) << "," << to_string(t.src_access_order) << "," << t.lds_transfer.src_vector_dim << "," << t.lds_transfer.src_scalar_per_vector << "," << t.lds_transfer.lds_dst_scalar_per_vector << "," << (t.lds_transfer.lds_padding ? 
"true" : "false"); @@ -173,16 +156,16 @@ inline std::string to_string(InputTransfer t) } template <> -inline std::string to_string(OutputTransfer t) +inline std::string to_string(OutputTileTransfer t) { std::ostringstream oss; oss << t.epilogue.m_xdl_per_wave_per_shuffle << "," << t.epilogue.n_per_wave_per_shuffle << "," - << to_string(t.thread_cluster_dims) << "," << t.epilogue.scalar_per_vector; + << to_string(t.thread_cluster) << "," << t.epilogue.scalar_per_vector; return oss.str(); } template -inline std::string to_string(Transfer t) +inline std::string to_string(InputOutputTileTransfer t) { std::ostringstream oss; oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c); @@ -272,25 +255,13 @@ inline std::string to_string(ThreadBlock_ t) } template <> -inline std::string to_string(FwdXdlGemm_ t) -{ - return to_string(t.gridwise_gemm); -} - -template <> -inline std::string to_string(BwdXdlGemm_ t) -{ - return to_string(t.gridwise_gemm); -} - -template <> -inline std::string to_string(WmmaGemm_ t) +inline std::string to_string(WarpGemm_ t) { - return to_string(t.gridwise_gemm); + return to_string(t.warp_gemm); } template -inline std::string to_string(Transfer_ t) +inline std::string to_string(InputOutputTileTransfer_ t) { return to_string(t.transfer); } @@ -312,17 +283,9 @@ inline std::string to_string(ConvSpecializationBwd } template <> -inline std::string to_string(Prefetch_ t) -{ - std::ostringstream oss; - oss << t.num_gemm_k_prefetch_stages << "," << to_string(t.loop_scheduler); - return oss.str(); -} - -template <> -inline std::string to_string(BlockGemm_ t) +inline std::string to_string(GemmPipeline_ t) { - return to_string(t.block_gemm_pipeline); + return to_string(t.gemm_pipeline); } template <> @@ -352,32 +315,38 @@ inline std::string to_string>(DlTransfer_<5> t) // Template specializations for algorithm types template <> -inline std::string to_string( - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle t) +inline std::string 
to_string( + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle t) { std::ostringstream oss; - oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); - return oss.str(); -} - -template <> -inline std::string to_string( - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 t) -{ - std::ostringstream oss; - oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + if(t.warp_gemm.matrix_instruction == MatrixInstructionType::WMMA) + { + oss << to_string(static_cast(t)) << "," + << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," + << to_string(static_cast(t)) << "," + << to_string(static_cast>(t)); + } + else + { + oss << to_string(static_cast(t)) << "," + << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," + << t.transfer.b.lds_transfer.global_memory_vector_load_size << "," + << to_string(static_cast(t)) << "," + << to_string(static_cast>(t)); + } return oss.str(); } template <> -inline std::string to_string( - ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle t) +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_CShuffle_V3 t) { std::ostringstream oss; - oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + oss << to_string(static_cast(t)) << "," + << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," + << t.transfer.b.lds_transfer.global_memory_vector_load_size << "," + << to_string(static_cast(t)) << "," + << to_string(static_cast>(t)); return oss.str(); } @@ -398,8 +367,11 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + oss << to_string(static_cast(t)) << "," + << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," + << t.transfer.b.lds_transfer.global_memory_vector_load_size << "," + << to_string(static_cast(t)) << "," + << 
to_string(static_cast>(t)); return oss.str(); } @@ -408,8 +380,10 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + oss << to_string(static_cast(t)) << "," + << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," + << to_string(static_cast(t)) << "," + << to_string(static_cast>(t)); return oss.str(); } @@ -418,8 +392,10 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + oss << to_string(static_cast(t)) << "," + << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," + << to_string(static_cast(t)) << "," + << to_string(static_cast>(t)); return oss.str(); } @@ -428,8 +404,10 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + oss << to_string(static_cast(t)) << "," + << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," + << to_string(static_cast(t)) << "," + << to_string(static_cast>(t)); return oss.str(); } @@ -438,8 +416,10 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + oss << to_string(static_cast(t)) << "," + << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," + << to_string(static_cast(t)) << "," + << to_string(static_cast>(t)); return oss.str(); } @@ -448,28 +428,23 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); - return oss.str(); -} - -template <> -inline std::string to_string( - ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 t) -{ - std::ostringstream oss; - oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + oss << to_string(static_cast(t)) << "," + << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," + << to_string(static_cast(t)) << "," + << to_string(static_cast>(t)); return oss.str(); } +// 
Covers both XDL and WMMA versions template <> -inline std::string to_string( - ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle t) +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_CShuffle_V3 t) { std::ostringstream oss; - oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + oss << to_string(static_cast(t)) << "," + << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," + << to_string(static_cast(t)) << "," + << to_string(static_cast>(t)); return oss.str(); } @@ -490,8 +465,10 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + oss << to_string(static_cast(t)) << "," + << t.transfer.a.lds_transfer.global_memory_vector_load_size << "," + << to_string(static_cast(t)) << "," + << to_string(static_cast>(t)); return oss.str(); }