diff --git a/example/62_convnd_activ/convinvscale/CMakeLists.txt b/example/62_convnd_activ/convinvscale/CMakeLists.txt index 9748f50e519..cb3aca0f18d 100644 --- a/example/62_convnd_activ/convinvscale/CMakeLists.txt +++ b/example/62_convnd_activ/convinvscale/CMakeLists.txt @@ -5,4 +5,11 @@ if (NOT GPU_TARGETS MATCHES "gfx11") add_custom_target(example_convnd_activ_xdl_convinvscale) add_example_executable(example_convnd_fwd_xdl_convinvscale_fp8 convnd_fwd_xdl_convinvscale_fp8.cpp) add_example_dependencies(example_convnd_activ_xdl_convinvscale example_convnd_fwd_xdl_convinvscale_fp8) -endif() \ No newline at end of file +endif() + +# WMMA +if (GPU_TARGETS MATCHES "gfx12") + add_custom_target(example_convnd_activ_wmma_convinvscale) + add_example_executable(example_convnd_fwd_wmma_convinvscale_fp8 convnd_fwd_wmma_convinvscale_fp8.cpp) + add_example_dependencies(example_convnd_activ_wmma_convinvscale example_convnd_fwd_wmma_convinvscale_fp8) +endif() diff --git a/example/62_convnd_activ/convinvscale/convnd_fwd_wmma_convinvscale_fp8.cpp b/example/62_convnd_activ/convinvscale/convnd_fwd_wmma_convinvscale_fp8.cpp new file mode 100644 index 00000000000..2ef6d17807c --- /dev/null +++ b/example/62_convnd_activ/convinvscale/convnd_fwd_wmma_convinvscale_fp8.cpp @@ -0,0 +1,97 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "convnd_fwd_convinvscale_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = float; +using DsDataType = ck::Tuple<>; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::f8_t; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ConvInvscale; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template <ck::index_t NDimSpatial, + typename InLayout, + typename WeiLayout, + typename DsLayout, + typename OutLayout> +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3< + NDimSpatial, // NDimSpatial + InLayout, // ALayout + WeiLayout, // BLayout + DsLayout, // DsLayout (empty tuple for ConvInvScale) + OutLayout, // ELayout + InDataType, // ADataType + WeiDataType, // BDataType + AccDataType, // AccDataType + CShuffleDataType, // CShuffleDataType + DsDataType, // DsDataType (empty tuple) + OutDataType, // EDataType + InElementOp, // AElementwiseOperation + WeiElementOp, // BElementwiseOperation + OutElementOp, // CDEElementwiseOperation + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 64, // BlockSize + 64, // MPerBlock + 64, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 16, // MPerWmma + 16, // NPerWmma + 4, // MRepeat + 2, // NRepeat + S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector
+ 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMRepeatPerShuffle + 1, // CShuffleNRepeatPerShuffle + S<1, 16, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 1, // CDEBlockTransferScalarPerVector_NPerBlock + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + AComputeDataType, // AComputeDataType + BComputeDataType, // BComputeDataType + 1>; // NumGroupsToMerge + +#include "run_convnd_fwd_convinvscale_example.inc" + +int main(int argc, char* argv[]) +{ + if(!ck::is_gfx12_supported()) + { + std::cout << "This kernel supports gfx12 only" << std::endl; + + return 0; + } + return run_convnd_fwd_example(argc, argv) ? 0 : 1; +} diff --git a/example/62_convnd_activ/convscale/CMakeLists.txt b/example/62_convnd_activ/convscale/CMakeLists.txt index 705160e01d0..ba63f59bcd3 100644 --- a/example/62_convnd_activ/convscale/CMakeLists.txt +++ b/example/62_convnd_activ/convscale/CMakeLists.txt @@ -15,3 +15,19 @@ if (NOT GPU_TARGETS MATCHES "gfx11") add_example_executable(example_convnd_fwd_xdl_convscale_bf8_fp8 convnd_fwd_xdl_convscale_bf8_fp8.cpp) add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_bf8_fp8) endif() + +# WMMA +if (GPU_TARGETS MATCHES "gfx12") + add_custom_target(example_convnd_activ_wmma_convscale) + add_example_executable(example_convnd_fwd_wmma_convscale_fp8 convnd_fwd_wmma_convscale_fp8.cpp) + add_example_dependencies(example_convnd_activ_wmma_convscale example_convnd_fwd_wmma_convscale_fp8) + + add_example_executable(example_convnd_fwd_wmma_convscale_bf8 convnd_fwd_wmma_convscale_bf8.cpp) + add_example_dependencies(example_convnd_activ_wmma_convscale example_convnd_fwd_wmma_convscale_bf8) + + add_example_executable(example_convnd_fwd_wmma_convscale_fp8_bf8 convnd_fwd_wmma_convscale_fp8_bf8.cpp) + add_example_dependencies(example_convnd_activ_wmma_convscale example_convnd_fwd_wmma_convscale_fp8_bf8) + + add_example_executable(example_convnd_fwd_wmma_convscale_bf8_fp8 convnd_fwd_wmma_convscale_bf8_fp8.cpp) + add_example_dependencies(example_convnd_activ_wmma_convscale example_convnd_fwd_wmma_convscale_bf8_fp8) +endif() diff --git a/example/62_convnd_activ/convscale/convnd_fwd_wmma_convscale_bf8.cpp b/example/62_convnd_activ/convscale/convnd_fwd_wmma_convscale_bf8.cpp new file mode 100644 index 00000000000..96c44536895 --- /dev/null +++ b/example/62_convnd_activ/convscale/convnd_fwd_wmma_convscale_bf8.cpp @@ -0,0 +1,97 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT + +#include "convnd_fwd_convscale_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp" + +using InDataType = ck::bf8_t; +using WeiDataType = ck::bf8_t; +using AccDataType = float; +using CShuffleDataType = float; +using DsDataType = ck::Tuple<>; +using OutDataType = ck::f8_t; +using AComputeDataType = InDataType; +using BComputeDataType = AComputeDataType; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ConvScale; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template <ck::index_t NDimSpatial, + typename InLayout, + typename WeiLayout, + typename DsLayout, + typename OutLayout> +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3< + NDimSpatial, // NDimSpatial + InLayout, // ALayout + WeiLayout, // BLayout + DsLayout, // DsLayout (empty tuple for ConvScale) + OutLayout, // ELayout + InDataType, // ADataType + WeiDataType, // BDataType + AccDataType, // AccDataType + CShuffleDataType, // CShuffleDataType + DsDataType, // DsDataType (empty tuple) + OutDataType, // EDataType + InElementOp, // AElementwiseOperation + WeiElementOp, // BElementwiseOperation + OutElementOp, // CDEElementwiseOperation + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 64, // BlockSize + 64, // MPerBlock + 64, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 16, // MPerWmma + 16, // NPerWmma + 4, // MRepeat + 2, // NRepeat + S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMRepeatPerShuffle + 1, // CShuffleNRepeatPerShuffle + S<1, 16, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 1, // CDEBlockTransferScalarPerVector_NPerBlock + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + AComputeDataType, // AComputeDataType + BComputeDataType, // BComputeDataType + 1>; // NumGroupsToMerge + +#include "run_convnd_fwd_convscale_example.inc" + +int main(int argc, char* argv[]) +{ + if(!ck::is_gfx12_supported()) + { + std::cout << "This kernel supports gfx12 only" << std::endl; + + return 0; + } + return run_convnd_fwd_example(argc, argv) ? 0 : 1; +} diff --git a/example/62_convnd_activ/convscale/convnd_fwd_wmma_convscale_bf8_fp8.cpp b/example/62_convnd_activ/convscale/convnd_fwd_wmma_convscale_bf8_fp8.cpp new file mode 100644 index 00000000000..d5ad65c3f49 --- /dev/null +++ b/example/62_convnd_activ/convscale/convnd_fwd_wmma_convscale_bf8_fp8.cpp @@ -0,0 +1,97 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT + +#include "convnd_fwd_convscale_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp" + +using InDataType = ck::bf8_t; +using WeiDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = float; +using DsDataType = ck::Tuple<>; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::bf8_t; +using BComputeDataType = ck::f8_t; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ConvScale; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template <ck::index_t NDimSpatial, + typename InLayout, + typename WeiLayout, + typename DsLayout, + typename OutLayout> +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3< + NDimSpatial, // NDimSpatial + InLayout, // ALayout + WeiLayout, // BLayout + DsLayout, // DsLayout (empty tuple for ConvScale) + OutLayout, // ELayout + InDataType, // ADataType + WeiDataType, // BDataType + AccDataType, // AccDataType + CShuffleDataType, // CShuffleDataType + DsDataType, // DsDataType (empty tuple) + OutDataType, // EDataType + InElementOp, // AElementwiseOperation + WeiElementOp, // BElementwiseOperation + OutElementOp, // CDEElementwiseOperation + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 64, // BlockSize + 64, // MPerBlock + 64, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 16, // MPerWmma + 16, // NPerWmma + 4, // MRepeat + 2, // NRepeat + S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMRepeatPerShuffle + 1, // CShuffleNRepeatPerShuffle + S<1, 16, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 1, // CDEBlockTransferScalarPerVector_NPerBlock + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + AComputeDataType, // AComputeDataType + BComputeDataType, // BComputeDataType + 1>; // NumGroupsToMerge + +#include "run_convnd_fwd_convscale_example.inc" + +int main(int argc, char* argv[]) +{ + if(!ck::is_gfx12_supported()) + { + std::cout << "This kernel supports gfx12 only" << std::endl; + + return 0; + } + return run_convnd_fwd_example(argc, argv) ? 0 : 1; +} diff --git a/example/62_convnd_activ/convscale/convnd_fwd_wmma_convscale_fp8.cpp b/example/62_convnd_activ/convscale/convnd_fwd_wmma_convscale_fp8.cpp new file mode 100644 index 00000000000..56d0523825a --- /dev/null +++ b/example/62_convnd_activ/convscale/convnd_fwd_wmma_convscale_fp8.cpp @@ -0,0 +1,97 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT + +#include "convnd_fwd_convscale_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = float; +using DsDataType = ck::Tuple<>; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::f8_t; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ConvScale; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template <ck::index_t NDimSpatial, + typename InLayout, + typename WeiLayout, + typename DsLayout, + typename OutLayout> +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3< + NDimSpatial, // NDimSpatial + InLayout, // ALayout + WeiLayout, // BLayout + DsLayout, // DsLayout (empty tuple for ConvScale) + OutLayout, // ELayout + InDataType, // ADataType + WeiDataType, // BDataType + AccDataType, // AccDataType + CShuffleDataType, // CShuffleDataType + DsDataType, // DsDataType (empty tuple) + OutDataType, // EDataType + InElementOp, // AElementwiseOperation + WeiElementOp, // BElementwiseOperation + OutElementOp, // CDEElementwiseOperation + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 64, // BlockSize + 64, // MPerBlock + 64, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 16, // MPerWmma + 16, // NPerWmma + 4, // MRepeat + 2, // NRepeat + S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMRepeatPerShuffle + 1, // CShuffleNRepeatPerShuffle + S<1, 16, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 1, // CDEBlockTransferScalarPerVector_NPerBlock + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + AComputeDataType, // AComputeDataType + BComputeDataType, // BComputeDataType + 1>; // NumGroupsToMerge + +#include "run_convnd_fwd_convscale_example.inc" + +int main(int argc, char* argv[]) +{ + if(!ck::is_gfx12_supported()) + { + std::cout << "This kernel supports gfx12 only" << std::endl; + + return 0; + } + return run_convnd_fwd_example(argc, argv) ? 0 : 1; +} diff --git a/example/62_convnd_activ/convscale/convnd_fwd_wmma_convscale_fp8_bf8.cpp b/example/62_convnd_activ/convscale/convnd_fwd_wmma_convscale_fp8_bf8.cpp new file mode 100644 index 00000000000..00551f1c7ee --- /dev/null +++ b/example/62_convnd_activ/convscale/convnd_fwd_wmma_convscale_fp8_bf8.cpp @@ -0,0 +1,97 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT + +#include "convnd_fwd_convscale_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::bf8_t; +using AccDataType = float; +using CShuffleDataType = float; +using DsDataType = ck::Tuple<>; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::bf8_t; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ConvScale; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template <ck::index_t NDimSpatial, + typename InLayout, + typename WeiLayout, + typename DsLayout, + typename OutLayout> +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3< + NDimSpatial, // NDimSpatial + InLayout, // ALayout + WeiLayout, // BLayout + DsLayout, // DsLayout (empty tuple for ConvScale) + OutLayout, // ELayout + InDataType, // ADataType + WeiDataType, // BDataType + AccDataType, // AccDataType + CShuffleDataType, // CShuffleDataType + DsDataType, // DsDataType (empty tuple) + OutDataType, // EDataType + InElementOp, // AElementwiseOperation + WeiElementOp, // BElementwiseOperation + OutElementOp, // CDEElementwiseOperation + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 64, // BlockSize + 64, // MPerBlock + 64, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 16, // MPerWmma + 16, // NPerWmma + 4, // MRepeat + 2, // NRepeat + S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMRepeatPerShuffle + 1, // CShuffleNRepeatPerShuffle + S<1, 16, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 1, // CDEBlockTransferScalarPerVector_NPerBlock + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + AComputeDataType, // AComputeDataType + BComputeDataType, // BComputeDataType + 1>; // NumGroupsToMerge + +#include "run_convnd_fwd_convscale_example.inc" + +int main(int argc, char* argv[]) +{ + if(!ck::is_gfx12_supported()) + { + std::cout << "This kernel supports gfx12 only" << std::endl; + + return 0; + } + return run_convnd_fwd_example(argc, argv) ? 0 : 1; +}
diff --git a/example/62_convnd_activ/convscale_add/CMakeLists.txt b/example/62_convnd_activ/convscale_add/CMakeLists.txt index e8f1488eb74..d0226f91399 100644 --- a/example/62_convnd_activ/convscale_add/CMakeLists.txt +++ b/example/62_convnd_activ/convscale_add/CMakeLists.txt @@ -5,4 +5,11 @@ if (NOT GPU_TARGETS MATCHES "gfx11") add_custom_target(example_convnd_activ_xdl_convscale_add) add_example_executable(example_convnd_fwd_xdl_convscale_add_fp8 convnd_fwd_xdl_convscale_add_fp8.cpp) add_example_dependencies(example_convnd_activ_xdl_convscale_add example_convnd_fwd_xdl_convscale_add_fp8) -endif() \ No newline at end of file +endif() + +# WMMA +if (GPU_TARGETS MATCHES "gfx12") + add_custom_target(example_convnd_activ_wmma_convscale_add) + add_example_executable(example_convnd_fwd_wmma_convscale_add_fp8 convnd_fwd_wmma_convscale_add_fp8.cpp) + add_example_dependencies(example_convnd_activ_wmma_convscale_add example_convnd_fwd_wmma_convscale_add_fp8) +endif() diff --git a/example/62_convnd_activ/convscale_add/convnd_fwd_wmma_convscale_add_fp8.cpp b/example/62_convnd_activ/convscale_add/convnd_fwd_wmma_convscale_add_fp8.cpp new file mode 100644 index 00000000000..f332dffc6ef --- /dev/null +++ b/example/62_convnd_activ/convscale_add/convnd_fwd_wmma_convscale_add_fp8.cpp @@ -0,0 +1,98 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/utility/tuple.hpp" +#include "convnd_fwd_convscale_add_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = float; +using DsDataType = float; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::f8_t; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ConvScaleAdd; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template <ck::index_t NDimSpatial, + typename InLayout, + typename WeiLayout, + typename OutLayout> +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3< + NDimSpatial, // NDimSpatial + InLayout, // ALayout + WeiLayout, // BLayout + ck::Tuple<OutLayout>, // DsLayout + OutLayout, // ELayout + InDataType, // ADataType + WeiDataType, // BDataType + AccDataType, // AccDataType + CShuffleDataType, // CShuffleDataType + ck::Tuple<DsDataType>, // DsDataType + OutDataType, // EDataType + InElementOp, // AElementwiseOperation + WeiElementOp, // BElementwiseOperation + OutElementOp, // CDEElementwiseOperation + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 64, // BlockSize + 64, // MPerBlock + 64, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 16, // MPerWmma + 16, // NPerWmma + 4, // MRepeat + 2, // NRepeat + S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector
+ 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMRepeatPerShuffle + 1, // CShuffleNRepeatPerShuffle + S<1, 16, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 1, // CDEBlockTransferScalarPerVector_NPerBlock + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + AComputeDataType, // AComputeDataType + BComputeDataType, // BComputeDataType + 1>; // NumGroupsToMerge + +#include "run_convnd_fwd_convscale_add_example.inc" + +int main(int argc, char* argv[]) +{ + if(!ck::is_gfx12_supported()) + { + std::cout << "This kernel supports gfx12 only" << std::endl; + + return 0; + } + return run_convnd_fwd_example(argc, argv) ? 0 : 1; +} diff --git a/example/62_convnd_activ/convscale_reduce/CMakeLists.txt b/example/62_convnd_activ/convscale_reduce/CMakeLists.txt index 0cbf17b2ec0..ee2e06f9398 100644 --- a/example/62_convnd_activ/convscale_reduce/CMakeLists.txt +++ b/example/62_convnd_activ/convscale_reduce/CMakeLists.txt @@ -8,4 +8,11 @@ if (NOT GPU_TARGETS MATCHES "gfx11") add_example_executable(example_convnd_fwd_xdl_convscale_amax_fp8 convnd_fwd_xdl_convscale_amax_fp8.cpp) add_example_dependencies(example_convnd_activ_xdl_convscale_reduce example_convnd_fwd_xdl_convscale_amax_fp8) -endif() \ No newline at end of file +endif() + +# WMMA +if (GPU_TARGETS MATCHES "gfx12") + add_custom_target(example_convnd_activ_wmma_convscale_reduce) + add_example_executable(example_convnd_fwd_wmma_convscale_amax_fp8 convnd_fwd_wmma_convscale_amax_fp8.cpp) + add_example_dependencies(example_convnd_activ_wmma_convscale_reduce example_convnd_fwd_wmma_convscale_amax_fp8) +endif() diff --git a/example/62_convnd_activ/convscale_reduce/convnd_fwd_wmma_convscale_amax_fp8.cpp b/example/62_convnd_activ/convscale_reduce/convnd_fwd_wmma_convscale_amax_fp8.cpp new file mode 100644 index 00000000000..a4053ca675c --- /dev/null +++ b/example/62_convnd_activ/convscale_reduce/convnd_fwd_wmma_convscale_amax_fp8.cpp @@ -0,0 +1,93 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT + +#include "convnd_fwd_convscale_reduce_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = float; +using ConvOutDataType = float; // data type of convolution result +using OutDataType = ck::f8_t; // data type of final result +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::f8_t; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ConvScale; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template <ck::index_t NDimSpatial, + typename InLayout, + typename WeiLayout, + typename OutLayout> +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3< + NDimSpatial, // NDimSpatial + InLayout, // ALayout + WeiLayout, // BLayout + ck::Tuple<>, // DsLayout + OutLayout, // ELayout + InDataType, // ADataType + WeiDataType, // BDataType + AccDataType, // AccDataType + CShuffleDataType, // CShuffleDataType + ck::Tuple<>, // DsDataType + ConvOutDataType, // EDataType + InElementOp, // AElementwiseOperation + WeiElementOp, // BElementwiseOperation + OutElementOp, // CDEElementwiseOperation + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 64, // BlockSize + 64, // MPerBlock + 64, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 16, // MPerWmma + 16, // NPerWmma + 4, // MRepeat + 2, // NRepeat + S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMRepeatPerShuffle + 1, // CShuffleNRepeatPerShuffle + S<1, 16, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 1, // CDEBlockTransferScalarPerVector_NPerBlock + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + AComputeDataType, // AComputeDataType + BComputeDataType, // BComputeDataType + 1>; // NumGroupsToMerge + +#include "run_convnd_fwd_example.inc" + +int main(int argc, char* argv[]) +{ + if(!ck::is_gfx12_supported()) + { + std::cout << "This kernel supports gfx12 only" << std::endl; + + return 0; + } + return run_convnd_fwd_example(argc, argv) ? 0 : 1; +}
diff --git a/example/62_convnd_activ/convscale_relu/CMakeLists.txt b/example/62_convnd_activ/convscale_relu/CMakeLists.txt index 307a4102a6a..27fcdc01588 100644 --- a/example/62_convnd_activ/convscale_relu/CMakeLists.txt +++ b/example/62_convnd_activ/convscale_relu/CMakeLists.txt @@ -6,3 +6,10 @@ if (NOT GPU_TARGETS MATCHES "gfx11") add_example_executable(example_convnd_fwd_xdl_convscale_relu_fp8 convnd_fwd_xdl_convscale_relu_fp8.cpp) add_example_dependencies(example_convnd_activ_xdl_convscale_relu example_convnd_fwd_xdl_convscale_relu_fp8) endif() + +# WMMA +if (GPU_TARGETS MATCHES "gfx12") + add_custom_target(example_convnd_activ_wmma_convscale_relu) + add_example_executable(example_convnd_fwd_wmma_convscale_relu_fp8 convnd_fwd_wmma_convscale_relu_fp8.cpp) + add_example_dependencies(example_convnd_activ_wmma_convscale_relu example_convnd_fwd_wmma_convscale_relu_fp8) +endif() diff --git a/example/62_convnd_activ/convscale_relu/convnd_fwd_wmma_convscale_relu_fp8.cpp b/example/62_convnd_activ/convscale_relu/convnd_fwd_wmma_convscale_relu_fp8.cpp new file mode 100644 index 00000000000..4b1787aa27a --- /dev/null +++ b/example/62_convnd_activ/convscale_relu/convnd_fwd_wmma_convscale_relu_fp8.cpp @@ -0,0 +1,97 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "convnd_fwd_convscale_relu_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = float; +using DsDataType = ck::Tuple<>; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::f8_t; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ConvScaleRelu; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template <ck::index_t NDimSpatial, + typename InLayout, + typename WeiLayout, + typename DsLayout, + typename OutLayout> +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3< + NDimSpatial, // NDimSpatial + InLayout, // ALayout + WeiLayout, // BLayout + DsLayout, // DsLayout (empty tuple for ConvScaleRelu) + OutLayout, // ELayout + InDataType, // ADataType + WeiDataType, // BDataType + AccDataType, // AccDataType + CShuffleDataType, // CShuffleDataType + DsDataType, // DsDataType (empty tuple) + OutDataType, // EDataType + InElementOp, // AElementwiseOperation + WeiElementOp, // BElementwiseOperation + OutElementOp, // CDEElementwiseOperation + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 64, // BlockSize + 64, // MPerBlock + 64, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 16, // MPerWmma + 16, // NPerWmma + 4, // MRepeat + 2, // NRepeat + S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1
+ 1, // BBlockLdsExtraN + 1, // CShuffleMRepeatPerShuffle + 1, // CShuffleNRepeatPerShuffle + S<1, 16, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 1, // CDEBlockTransferScalarPerVector_NPerBlock + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + AComputeDataType, // AComputeDataType + BComputeDataType, // BComputeDataType + 1>; // NumGroupsToMerge + +#include "run_convnd_fwd_convscale_relu_example.inc" + +int main(int argc, char* argv[]) +{ + if(!ck::is_gfx12_supported()) + { + std::cout << "This kernel supports gfx12 only" << std::endl; + + return 0; + } + return run_convnd_fwd_example(argc, argv) ? 0 : 1; +} diff --git a/example/62_convnd_activ/dynamic_unary/CMakeLists.txt b/example/62_convnd_activ/dynamic_unary/CMakeLists.txt index 9efc48f9054..9de3514fc37 100644 --- a/example/62_convnd_activ/dynamic_unary/CMakeLists.txt +++ b/example/62_convnd_activ/dynamic_unary/CMakeLists.txt @@ -37,4 +37,43 @@ add_example_executable(example_convnd_fwd_xdl_dynamic_passthrough_fp16 convnd_fw add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_passthrough_fp16) # Logistic add_example_executable(example_convnd_fwd_xdl_dynamic_logistic_fp16 convnd_fwd_xdl_dynamic_logistic_fp16.cpp) -add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_logistic_fp16) \ No newline at end of file +add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_logistic_fp16) + +# WMMA +add_custom_target(example_convnd_activ_dynamic_unary_wmma) +# Abs +add_example_executable(example_convnd_fwd_wmma_dynamic_abs_fp16 convnd_fwd_wmma_dynamic_abs_fp16.cpp) +add_example_dependencies(example_convnd_activ_dynamic_unary_wmma example_convnd_fwd_wmma_dynamic_abs_fp16) +# Relu +add_example_executable(example_convnd_fwd_wmma_dynamic_relu_fp16 convnd_fwd_wmma_dynamic_relu_fp16.cpp) +add_example_dependencies(example_convnd_activ_dynamic_unary_wmma example_convnd_fwd_wmma_dynamic_relu_fp16) +# Sigmoid +add_example_executable(example_convnd_fwd_wmma_dynamic_sigmoid_fp16 convnd_fwd_wmma_dynamic_sigmoid_fp16.cpp) +add_example_dependencies(example_convnd_activ_dynamic_unary_wmma example_convnd_fwd_wmma_dynamic_sigmoid_fp16) +# Tanh +add_example_executable(example_convnd_fwd_wmma_dynamic_tanh_fp16 convnd_fwd_wmma_dynamic_tanh_fp16.cpp) +add_example_dependencies(example_convnd_activ_dynamic_unary_wmma example_convnd_fwd_wmma_dynamic_tanh_fp16) +# Pow +add_example_executable(example_convnd_fwd_wmma_dynamic_pow_fp16 convnd_fwd_wmma_dynamic_pow_fp16.cpp) +add_example_dependencies(example_convnd_activ_dynamic_unary_wmma example_convnd_fwd_wmma_dynamic_pow_fp16) +# Elu +add_example_executable(example_convnd_fwd_wmma_dynamic_elu_fp16 convnd_fwd_wmma_dynamic_elu_fp16.cpp) +add_example_dependencies(example_convnd_activ_dynamic_unary_wmma example_convnd_fwd_wmma_dynamic_elu_fp16) +# Swish +add_example_executable(example_convnd_fwd_wmma_dynamic_swish_fp16 convnd_fwd_wmma_dynamic_swish_fp16.cpp) +add_example_dependencies(example_convnd_activ_dynamic_unary_wmma example_convnd_fwd_wmma_dynamic_swish_fp16) +# Clipped Relu +add_example_executable(example_convnd_fwd_wmma_dynamic_clippedrelu_fp16 convnd_fwd_wmma_dynamic_clippedrelu_fp16.cpp) +add_example_dependencies(example_convnd_activ_dynamic_unary_wmma example_convnd_fwd_wmma_dynamic_clippedrelu_fp16) +# Leaky Relu
+add_example_executable(example_convnd_fwd_wmma_dynamic_leakyrelu_fp16 convnd_fwd_wmma_dynamic_leakyrelu_fp16.cpp) +add_example_dependencies(example_convnd_activ_dynamic_unary_wmma example_convnd_fwd_wmma_dynamic_leakyrelu_fp16) +# Soft Relu +add_example_executable(example_convnd_fwd_wmma_dynamic_softrelu_fp16 convnd_fwd_wmma_dynamic_softrelu_fp16.cpp) +add_example_dependencies(example_convnd_activ_dynamic_unary_wmma example_convnd_fwd_wmma_dynamic_softrelu_fp16) +# PassThrough +add_example_executable(example_convnd_fwd_wmma_dynamic_passthrough_fp16 convnd_fwd_wmma_dynamic_passthrough_fp16.cpp) +add_example_dependencies(example_convnd_activ_dynamic_unary_wmma example_convnd_fwd_wmma_dynamic_passthrough_fp16) +# Logistic +add_example_executable(example_convnd_fwd_wmma_dynamic_logistic_fp16 convnd_fwd_wmma_dynamic_logistic_fp16.cpp) +add_example_dependencies(example_convnd_activ_dynamic_unary_wmma example_convnd_fwd_wmma_dynamic_logistic_fp16) diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_activ_dynamic_unary_wmma_common.hpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_activ_dynamic_unary_wmma_common.hpp new file mode 100644 index 00000000000..126f00bbfa4 --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_activ_dynamic_unary_wmma_common.hpp @@ -0,0 +1,244 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include <array> +#include <cstdlib> +#include <iostream> +#include <stdexcept> + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +constexpr ck::index_t NDimSpatial = 3; +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using OutDataType = ck::half_t; +using AComputeDataType = ck::half_t; +using BComputeDataType = ck::half_t; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +// Use correct tensor layouts for WMMA (matching working tests) +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using DynamicElementOp = ck::tensor_operation::element_wise::DynamicUnaryOp; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceGroupedConvNDActivInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3< + NDimSpatial, // NDimSpatial + InLayout, // ALayout + WeiLayout, // BLayout + ck::Tuple<>, // DsLayout + OutLayout, // ELayout + InDataType, // ADataType
+ WeiDataType, // BDataType + AccDataType, // AccDataType + CShuffleDataType, // CShuffleDataType + ck::Tuple<>, // DsDataType + OutDataType, // EDataType + InElementOp, // AElementwiseOperation + WeiElementOp, // BElementwiseOperation + DynamicElementOp, // CDEElementwiseOperation + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 64, // BlockSize + 64, // MPerBlock + 64, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 16, // MPerWmma + 16, // NPerWmma + 4, // MRepeat + 2, // NRepeat + S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMRepeatPerShuffle + 1, // CShuffleNRepeatPerShuffle + S<1, 16, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 1, // CDEBlockTransferScalarPerVector_NPerBlock + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + AComputeDataType, // AComputeDataType + BComputeDataType, // BComputeDataType + 1>; // NumGroupsToMerge + +template <typename DeviceConvNDFwdInstance, typename OutElementOp> +bool run_grouped_conv(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + Tensor<InDataType> in(in_g_n_c_wis_desc); + Tensor<WeiDataType> wei(wei_g_k_c_xs_desc); + Tensor<OutDataType> out_host(out_g_n_k_wos_desc); + Tensor<OutDataType> out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-2, 2}); + wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-2, 2}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3<InDataType>{-1.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.05, 0.05}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + + std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{}; + std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{}; + std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{}; + std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{}; + std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_lengths{}; + std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_strides{}; + std::array<ck::index_t, NDimSpatial> conv_filter_strides{}; + std::array<ck::index_t, NDimSpatial> conv_filter_dilations{}; + std::array<ck::index_t, NDimSpatial> input_left_pads{}; + std::array<ck::index_t, NDimSpatial> input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
+ copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // do Conv + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + std::array<const void*, 0>{}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{{}}, + std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{{}}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error("The device op with the specified compilation parameters does " + "not support this convolution problem."); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast<float>(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp>(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + out_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op); + + ref_invoker.Run(ref_argument); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device, out_host, "Error: incorrect results!", 1e-3, 0.1); + } + + return true; +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_abs_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_abs_fp16.cpp new file mode 100644 index 00000000000..b24545e99c4 --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_abs_fp16.cpp @@ -0,0 +1,12 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "convnd_fwd_activ_dynamic_unary_wmma_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + ck::tensor_operation::element_wise::UnaryAbs out_element_op; + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_clippedrelu_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_clippedrelu_fp16.cpp new file mode 100644 index 00000000000..06cda3c7be0 --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_clippedrelu_fp16.cpp @@ -0,0 +1,12 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT + +#include "convnd_fwd_activ_dynamic_unary_wmma_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + ck::tensor_operation::element_wise::ClippedRelu out_element_op(6.f); + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_elu_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_elu_fp16.cpp new file mode 100644 index 00000000000..c327798db58 --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_elu_fp16.cpp @@ -0,0 +1,12 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "convnd_fwd_activ_dynamic_unary_wmma_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + ck::tensor_operation::element_wise::Elu out_element_op(2.f); + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_leakyrelu_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_leakyrelu_fp16.cpp new file mode 100644 index 00000000000..4e32c47ccd5 --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_leakyrelu_fp16.cpp @@ -0,0 +1,12 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "convnd_fwd_activ_dynamic_unary_wmma_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + ck::tensor_operation::element_wise::LeakyRelu out_element_op(0.2f); + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_logistic_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_logistic_fp16.cpp new file mode 100644 index 00000000000..e55ed7f66e8 --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_logistic_fp16.cpp @@ -0,0 +1,12 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "convnd_fwd_activ_dynamic_unary_wmma_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + ck::tensor_operation::element_wise::Logistic out_element_op(1.0f); + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_passthrough_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_passthrough_fp16.cpp new file mode 100644 index 00000000000..b2045afadb4 --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_passthrough_fp16.cpp @@ -0,0 +1,12 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "convnd_fwd_activ_dynamic_unary_wmma_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + ck::tensor_operation::element_wise::PassThrough out_element_op; + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_pow_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_pow_fp16.cpp new file mode 100644 index 00000000000..73d68a58169 --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_pow_fp16.cpp @@ -0,0 +1,12 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "convnd_fwd_activ_dynamic_unary_wmma_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + ck::tensor_operation::element_wise::Power out_element_op(4.f, 1.f, 2.f); + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_relu_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_relu_fp16.cpp new file mode 100644 index 00000000000..e708d3f72bc --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_relu_fp16.cpp @@ -0,0 +1,12 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "convnd_fwd_activ_dynamic_unary_wmma_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + ck::tensor_operation::element_wise::Relu out_element_op; + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_sigmoid_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_sigmoid_fp16.cpp new file mode 100644 index 00000000000..f1dbf8eac5f --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_sigmoid_fp16.cpp @@ -0,0 +1,12 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "convnd_fwd_activ_dynamic_unary_wmma_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + ck::tensor_operation::element_wise::Sigmoid out_element_op; + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_softrelu_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_softrelu_fp16.cpp new file mode 100644 index 00000000000..5fefbeeb5ac --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_softrelu_fp16.cpp @@ -0,0 +1,12 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "convnd_fwd_activ_dynamic_unary_wmma_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + ck::tensor_operation::element_wise::SoftRelu out_element_op; + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_swish_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_swish_fp16.cpp new file mode 100644 index 00000000000..8d8947280b3 --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_swish_fp16.cpp @@ -0,0 +1,12 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "convnd_fwd_activ_dynamic_unary_wmma_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + ck::tensor_operation::element_wise::Swish out_element_op; + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_tanh_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_tanh_fp16.cpp new file mode 100644 index 00000000000..fef6d85e629 --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_wmma_dynamic_tanh_fp16.cpp @@ -0,0 +1,12 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "convnd_fwd_activ_dynamic_unary_wmma_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + ck::tensor_operation::element_wise::TanH out_element_op; + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/multi_AB/CMakeLists.txt b/example/62_convnd_activ/multi_AB/CMakeLists.txt index fc637cdbc6b..0ed9a770837 100644 --- a/example/62_convnd_activ/multi_AB/CMakeLists.txt +++ b/example/62_convnd_activ/multi_AB/CMakeLists.txt @@ -11,3 +11,12 @@ add_example_executable(example_conv_fwd_xdl_scaleadd_ab_bf16 conv_fwd_xdl_scalea add_example_dependencies(example_convnd_activ_multi_ab_xdl example_conv_fwd_xdl_scaleadd_ab_bf16) add_example_executable(example_conv_fwd_xdl_scaleadd_ab_int8 conv_fwd_xdl_scaleadd_ab_int8.cpp) add_example_dependencies(example_convnd_activ_multi_ab_xdl example_conv_fwd_xdl_scaleadd_ab_int8) + +add_custom_target(example_convnd_activ_multi_ab_wmma_cshufflev3) +# ScaleAdd on A and B +add_example_executable(example_conv_fwd_wmma_cshufflev3_scaleadd_ab_fp16 conv_fwd_wmma_cshufflev3_scaleadd_ab_fp16.cpp) +add_example_dependencies(example_convnd_activ_multi_ab_wmma_cshufflev3 example_conv_fwd_wmma_cshufflev3_scaleadd_ab_fp16) +add_example_executable(example_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16 conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16.cpp) +add_example_dependencies(example_convnd_activ_multi_ab_wmma_cshufflev3 example_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16) +add_example_executable(example_conv_fwd_wmma_cshufflev3_scaleadd_ab_int8 conv_fwd_wmma_cshufflev3_scaleadd_ab_int8.cpp) +add_example_dependencies(example_convnd_activ_multi_ab_wmma_cshufflev3 example_conv_fwd_wmma_cshufflev3_scaleadd_ab_int8) diff --git a/example/62_convnd_activ/multi_AB/conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16.cpp b/example/62_convnd_activ/multi_AB/conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16.cpp new file mode 100644 index 00000000000..746f56a6ac3 --- /dev/null +++ b/example/62_convnd_activ/multi_AB/conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16.cpp @@ -0,0 +1,27 @@ +// Copyright 
(c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#define EXAMPLE_USE_WMMA +#include "convnd_fwd_activ_multi_ab_common.hpp" + +using DataType = ck::bhalf_t; +using AccDataType = float; +using InDataType = DataType; +using WeiDataType = DataType; +using OutDataType = DataType; +using ADataTypes = ck::Tuple<InDataType, InDataType>; +using BDataTypes = ck::Tuple<WeiDataType, WeiDataType>; + +using InElementOp = ck::tensor_operation::element_wise::ScaleAdd; +using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance<ADataTypes, + BDataTypes, + DataType, + AccDataType, + InElementOp, + WeiElementOp>; + +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/multi_AB/conv_fwd_wmma_cshufflev3_scaleadd_ab_fp16.cpp b/example/62_convnd_activ/multi_AB/conv_fwd_wmma_cshufflev3_scaleadd_ab_fp16.cpp new file mode 100644 index 00000000000..2de837452d1 --- /dev/null +++ b/example/62_convnd_activ/multi_AB/conv_fwd_wmma_cshufflev3_scaleadd_ab_fp16.cpp @@ -0,0 +1,27 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#define EXAMPLE_USE_WMMA +#include "convnd_fwd_activ_multi_ab_common.hpp" + +using DataType = ck::half_t; +using AccDataType = float; +using InDataType = DataType; +using WeiDataType = DataType; +using OutDataType = DataType; +using ADataTypes = ck::Tuple<InDataType, InDataType>; +using BDataTypes = ck::Tuple<WeiDataType, WeiDataType>; + +using InElementOp = ck::tensor_operation::element_wise::ScaleAdd; +using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance<ADataTypes, + BDataTypes, + DataType, + AccDataType, + InElementOp, + WeiElementOp>; + +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/multi_AB/conv_fwd_wmma_cshufflev3_scaleadd_ab_int8.cpp b/example/62_convnd_activ/multi_AB/conv_fwd_wmma_cshufflev3_scaleadd_ab_int8.cpp new file mode 100644 index 00000000000..037311d17fe --- /dev/null +++ b/example/62_convnd_activ/multi_AB/conv_fwd_wmma_cshufflev3_scaleadd_ab_int8.cpp @@ -0,0 +1,27 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT + +#define EXAMPLE_USE_WMMA +#include "convnd_fwd_activ_multi_ab_common.hpp" + +using DataType = int8_t; +using AccDataType = int32_t; +using InDataType = DataType; +using WeiDataType = DataType; +using OutDataType = DataType; +using ADataTypes = ck::Tuple<InDataType, InDataType>; +using BDataTypes = ck::Tuple<WeiDataType, WeiDataType>; + +using InElementOp = ck::tensor_operation::element_wise::ScaleAdd; +using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance<ADataTypes, + BDataTypes, + DataType, + AccDataType, + InElementOp, + WeiElementOp>; + +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/multi_AB/convnd_fwd_activ_multi_ab_common.hpp b/example/62_convnd_activ/multi_AB/convnd_fwd_activ_multi_ab_common.hpp index a9aa4645d31..ba6a2ea88be 100644 --- a/example/62_convnd_activ/multi_AB/convnd_fwd_activ_multi_ab_common.hpp +++ b/example/62_convnd_activ/multi_AB/convnd_fwd_activ_multi_ab_common.hpp @@ -9,7 +9,11 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#ifdef EXAMPLE_USE_WMMA +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp" +#else #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" +#endif #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" @@ -41,6 +45,62 @@ static constexpr auto ConvSpec = static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +#ifdef EXAMPLE_USE_WMMA +template <typename InDataTypes, + typename WeiDataTypes, + typename DataType, + typename AccDataType, + typename InElementOp, + typename WeiElementOp> +using DeviceGroupedConvNDMultiABFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataTypes, + WeiDataTypes, + AccDataType, + DataType, + ck::Tuple<>, + DataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 16, // MPerWmma + 16, // NPerWmma + 4, // MWmmaPerWave + 4, // NWmmaPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1>; +#else template , 4>; +#endif namespace { template && + init_method != 2) + { + std::cout << "Running SoftRelu op with int initialization. Risk of overflow.\n\n";
Risk of overflow.\n\n"; + } + const auto in_element_op = InElementOp{}; const auto wei_element_op = WeiElementOp{}; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp new file mode 100644 index 00000000000..df128c10b9c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp @@ -0,0 +1,2353 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/library/utility/numeric.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +/* + * \brief Wrapper function of GridwiseGemm Wmma Cshuffle V3 to realize grouped forward convolution. + * + * \tparam ComputePtrOffset Class that computes the base pointer offsets of A, B, D and E + * matrices for groups or splitN. Currently it works for identical strides, but this can be extended + * to other layouts. The returned offset can be either \p index_t or \p long_index_t. If it returns + * \p long_index_t, we are not subject to the 2GB limitations. + * + * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in the id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). 
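 *
 * A minimal sketch of the offset computation this implies (ComputePtrOffsetRef
 * is a hypothetical illustration, assuming the identical-stride case described
 * above):
 *
 *   struct ComputePtrOffsetRef
 *   {
 *       long_index_t BatchStrideA_;
 *       long_index_t GetAPtrOffset(index_t g) const { return g * BatchStrideA_; }
 *   };
 *
 * Returning long_index_t keeps g * BatchStrideA_ in 64-bit arithmetic, which
 * is what lifts the 2GB limitation mentioned above.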
+ */
+template <typename GridwiseGemm,
+          typename AGridDesc_AK0_M_AK1,
+          typename BGridDesc_BK0_N_BK1,
+          typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+          typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
+          typename ComputePtrOffset,
+          bool HasMainKBlockLoop,
+          InMemoryDataOperationEnum EGlobalMemoryDataOperation,
+          index_t MinimumOccupancy,
+          TailNumber TailNum>
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+#endif
+    kernel_grouped_conv_fwd_wmma_cshuffle_v3(
+        typename GridwiseGemm::Argument karg,
+        const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+        const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+        const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            ds_grid_desc_mblock_mperblock_nblock_nperblock,
+        const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+            e_grid_desc_mblock_mperblock_nblock_nperblock,
+        const ComputePtrOffset compute_ptr_offset_of_batch,
+        const ComputePtrOffset compute_ptr_offset_of_n,
+        const index_t num_k_per_block)
+{
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
+#if defined(__gfx11__)
+    // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
+    using e_data_type = remove_cvref_t>;
+    if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
+                   (std::is_same_v || std::is_same_v)))
+    {
+#endif
+        __shared__ char p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte<
+            typename GridwiseGemm::EpilogueCShuffle>()];
+
+        auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
+
+        GridwiseGemm::template Run(p_shared,
+                                   a_grid_desc_ak0_m_ak1,
+                                   b_grid_desc_bk0_n_bk1,
+                                   ds_grid_desc_mblock_mperblock_nblock_nperblock,
+                                   e_grid_desc_mblock_mperblock_nblock_nperblock,
+                                   compute_ptr_offset_of_batch,
+                                   compute_ptr_offset_of_n,
+                                   num_k_per_block,
+                                   karg,
+                                   epilogue_args);
+#if defined(__gfx11__)
+    }
+#endif
+#else
+    ignore = karg;
+    ignore = a_grid_desc_ak0_m_ak1;
+    ignore = b_grid_desc_bk0_n_bk1;
+    ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
+    ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
+    ignore = compute_ptr_offset_of_batch;
+    ignore = compute_ptr_offset_of_n;
+    ignore = num_k_per_block;
+#endif // End of if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
+}
+
+} // namespace
+
+template <typename T>
+using is_tuple = decltype(std::declval<T&>().IsTuple());
+
+//
+// @brief Device Convolution operation.
+//
+// Supports:
+// @li Forward convolution with up to 3 spatial dimensions
+// @li Input tensor in GNWC data format
+// @li Weight tensor in GKXC data format
+// @li Output tensor in GNWK data format
+//
+// 1D:
+// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
+// 2D:
+// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
+// 3D:
+// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
+//
+template ::value,
+                                Number<0>,
+                                ADataType>()), // ComputeType is InputType by default (first
+                                               // in tuple for MultiAB), unpack if tuple was
+                                               // passed
+          typename BComputeDataType = AComputeDataType,
+          index_t NumGroupsToMerge  = 1>
+struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
+    : public DeviceGroupedConvFwdMultipleABD
+{
+    using DeviceOp = DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3;
+
+    static_assert(NumGroupsToMerge >= 1);
+
+    static constexpr bool isMultiA  = is_detected<is_tuple, ADataType>::value;
+    static constexpr bool isMultiB  = is_detected<is_tuple, BDataType>::value;
+    static constexpr bool isMultiAB = isMultiA || isMultiB;
+    static constexpr bool isMultiD  = DsDataType::Size() > 0;
+
+    // Note: this combination is not expected to occur in practice.
+    static constexpr bool isMultiABD = isMultiA && isMultiB && isMultiD;
+
+    // NGCHW is not supported for multiAB.
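// (The static_assert just below enforces this: the NGCHW/NGCDHW transpose
// path assumes exactly one A and one B tensor.)
//
// For reference, the detection idiom behind isMultiA/isMultiB above, as a
// hedged sketch: is_detected<is_tuple, T>::value is true exactly when T has a
// member IsTuple(), i.e. when T is a ck::Tuple packing several tensors.
//
//   static_assert(is_detected<is_tuple, ck::Tuple<float, float>>::value);  // multi
//   static_assert(!is_detected<is_tuple, float>::value);                   // single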
+ static_assert(!(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) || + !(isMultiA || isMultiB)); + + static constexpr index_t NumATensor = GetNumABTensors(); + static constexpr index_t NumBTensor = GetNumABTensors(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + // TODO: This parameter is no longer supported by Gridwise! + // static constexpr bool DoElementwiseBeforeCShuffle = + // !isMultiD && is_same_v && + // !is_same_v; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr bool isATensorColMajor = + (ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) && + (ABlockTransferSrcVectorDim == 1) && (NumGroupsToMerge == 1) && + (is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()); + + static constexpr bool NeedTransposeKernel = + (isATensorColMajor == false) && (is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()); + + static constexpr bool CTranspose = (NeedTransposeKernel == false) && (isMultiAB == false) && + (is_same_v || + is_same_v); + + // Generate vector size for C & Ds + using CDEBlockTransferScalarPerVectors = + typename uniform_sequence_gen::type; + + using ConvToGemmFwdTransformer = TransformConvFwdToGemm; + + using ComputePtrOffset = ComputePtrOffsetOfStridedBatch; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + static constexpr index_t ClusterLengthNPerBlock = + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); + + static constexpr auto conv_ngchw_to_nhwgc_transformer = + TransformConvNGCHWToNHWGC{}; + + template + static auto + MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + + { + namespace ctc = tensor_layout::convolution; + using Layout = std::conditional_t< + is_NGCHW_NGKHW() && NeedTransposeKernel, + ctc::NHWGC, + std::conditional_t() && NeedTransposeKernel, + ctc::NDHWGC, + ALay>>; + + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + + const auto M = in_gemmm_gemmk_desc.GetLength(I0); + const auto K = in_gemmm_gemmk_desc.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(in_gemmm_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static auto + MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + namespace ctc = tensor_layout::convolution; + using Layout = std::conditional_t< + is_NGCHW_NGKHW() && NeedTransposeKernel, + ctc::GKYXC, + std::conditional_t() && NeedTransposeKernel, + ctc::GKZYXC, + BLay>>; + + const auto wei_gemmnraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeBDescriptor_N_K(); + + const auto wei_gemmn_gemmk_desc = + matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); + + const auto N = wei_gemmn_gemmk_desc.GetLength(I0); + const auto K = wei_gemmn_gemmk_desc.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(wei_gemmn_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + 
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+    }
+
+    template <typename ELay>
+    static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
+    {
+        namespace ctc = tensor_layout::convolution;
+        using Layout = std::conditional_t<
+            is_NGCHW_NGKHW() && NeedTransposeKernel,
+            ctc::NHWGK,
+            std::conditional_t() && NeedTransposeKernel,
+                               ctc::NDHWGK,
+                               ELay>>;
+
+        const auto out_gemmmraw_gemmnraw_desc =
+            conv_to_gemm_transformer.template MakeCDescriptor_M_N();
+
+        if constexpr(CTranspose)
+        {
+            constexpr auto matrix_padder_trans =
+                MatrixPadder{NPerBlock, MPerBlock, KPerBlock};
+            return matrix_padder_trans.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
+        }
+        else
+        {
+            return matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
+        }
+    }
+
+    // Shape of Ds and E must be aligned. Strides can be different.
+    // Pass e_g_n_k_wos_lengths for logical broadcast.
+    static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
+    {
+        return generate_tuple(
+            [&](auto i) {
+                using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
+
+                return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer);
+            },
+            Number<NumDTensor>{});
+    }
+
+    // Gridwise always expects a tuple of data types.
+    using GemmAsDataType = std::conditional_t, ADataType>;
+    using GemmBsDataType = std::conditional_t, BDataType>;
+
+    // Use the appropriate gridwise gemm.
+    // Note / TODO: After the convolution has been converted to gemm, the 2D tensor descriptors
+    // will in general not be RowMajor or ColumnMajor but have a more complex layout. For now we
+    // just pass RCR (or CRC for CTranspose) to the gridwise gemm. This is currently only used to
+    // determine the LDS block descriptors, *IF* we are not using extraM and extraN. It seems like
+    // we are able to freely set these anyway without affecting results, but RCR (or CRC for
+    // CTranspose) is supposedly the most accurate (and perhaps most performant). The 2D layouts
+    // are also used in the gridwise CheckValidity() function, where it determines some vector
+    // access checks and MNPerBlock if we are not using padding. We may not actually need these
+    // checks but keep them for now.
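// For orientation, the implicit-GEMM view assumed here (matching the class
// comment's out = in * wei formulas, 2D case):
//
//   GemmM = N * Ho * Wo    (output pixels)
//   GemmN = K              (output channels)
//   GemmK = C * Y * X      (input channels times filter taps)
//
// so A (the im2col view of the input) reads as a RowMajor [GemmM, GemmK]
// matrix, B (the weights, stored as [K, Y*X*C]) as a ColumnMajor
// [GemmK, GemmN] matrix, and E as a RowMajor [GemmM, GemmN] matrix; this is
// the "RCR" convention passed below (CRC when CTranspose is active).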
+ using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + tensor_layout::gemm::RowMajor, // See Note above + tensor_layout::gemm::ColumnMajor, // See Note above + DsLayout, + tensor_layout::gemm::RowMajor, // See Note above + GemmAsDataType, + GemmBsDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, // AThreadTransferSrcResetCoordinateAfterRun + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, // BThreadTransferSrcResetCoordinateAfterRun + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + AComputeDataType, + BComputeDataType, + false, // PermuteA + false, // PermuteB + false, // IsBPreShuffled + true>; // ForceThreadTileTransfer + + // TODO: Previously available template param DoElementwiseBeforeCShuffle! + + // In case of CTranspose we swap the following template parameters: + // DataType, ElementWiseOp, PerBlock, K1, PerWmma, Repeat, All block transfer params. + using GridwiseGemmSwappedParams = GridwiseGemm_wmma_cshuffle_v3< + tensor_layout::gemm::ColumnMajor, // See Note above + tensor_layout::gemm::RowMajor, // See Note above + + DsLayout, + tensor_layout::gemm::ColumnMajor, // See Note above + + GemmBsDataType, + GemmAsDataType, + + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + + BElementwiseOperation, + AElementwiseOperation, + + CDEElementwiseOperation, + GemmSpec, + BlockSize, + + NPerBlock, + MPerBlock, + + KPerBlock, + + BK1, + AK1, + + NPerWmma, + MPerWmma, + + NRepeat, + MRepeat, + + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, // BThreadTransferSrcResetCoordinateAfterRun + BBlockLdsExtraN, + + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, // AThreadTransferSrcResetCoordinateAfterRun + ABlockLdsExtraM, + + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + + BComputeDataType, + AComputeDataType, // TODO: Swapped these but will probably never get verified because the + // only mixed precision instances are not NCHW. + + false, // PermuteB + false, // PermuteA + false, // IsBPreShuffled + true>; // ForceThreadTileTransfer + + using GridwiseGemmCTranspose = + std::conditional_t; + + // If ADataTypes or BDataTypes is tuple, user has to pass std::array with pointers. 
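// Call-shape sketch for the two cases (pointer names are hypothetical):
//
//   // isMultiA == true, e.g. ADataType = ck::Tuple<half_t, half_t>:
//   std::array<const void*, 2> p_as = {p_a0, p_a1}; // one pointer per A tensor
//
//   // isMultiA == false: a bare pointer is passed instead
//   const void* p_as = p_a;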
+ using APointers = + std::conditional_t&, const void*>; + using BPointers = + std::conditional_t&, const void*>; + + // desc for problem definition + constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer; + using EGridDesc_M_N = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using DsGridDesc_M_N = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemmCTranspose::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}, 1, 1))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemmCTranspose::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + EGridDesc_M_N{}, 1, 1))>; + + using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt; + + using NGCHWTransposeDescType = + remove_cvref_t({}, {}))>; + using NHWGCTransposeDescType = + remove_cvref_t({}, {}))>; + + using GKCYXTransposeDescType = + remove_cvref_t({}, {}))>; + using GKYXCTransposeDescType = + remove_cvref_t({}, {}))>; + + static constexpr index_t ElementwiseBlocksize = ClusterLengthNPerBlock * ClusterLengthNPerBlock; + + using GridwiseElementwiseInputTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + element_wise::PassThrough, + ElementwiseBlocksize, + NPerBlock, + NPerBlock, + NPerBlock / ClusterLengthNPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<1, 0>, + Sequence, + Sequence, + I1, + I0>; + + using GridwiseElementwiseWeightTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + element_wise::PassThrough, + ElementwiseBlocksize, + NPerBlock, + NPerBlock, + NPerBlock / ClusterLengthNPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<1, 0>, + Sequence<1>, + Sequence, + I0, + I1>; + + using GridwiseElementwiseOutputTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + element_wise::PassThrough, + ElementwiseBlocksize, + NPerBlock, + NPerBlock, + NPerBlock / ClusterLengthNPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<1, 0>, + Sequence, + Sequence, + I0, + I1>; + + // desc for blockwise copy + using AGridDesc_AK0_M_AK1 = remove_cvref_t( + dummy_conv_to_gemm_transformer))>; + using BGridDesc_BK0_N_BK1 = remove_cvref_t( + dummy_conv_to_gemm_transformer))>; + + // Argument + struct Argument : public BaseArgument + { + Argument(APointers p_as, + BPointers p_bs, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + : p_as_grid_{}, + p_bs_grid_{}, + p_ds_grid_{p_ds}, + p_e_grid_{static_cast(p_e)}, + a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, + a_g_n_c_wis_strides_{NeedTransposeKernel + ? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides) + : a_g_n_c_wis_strides}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{NeedTransposeKernel + ? 
conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides) + : b_g_k_c_xs_strides}, + ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths}, + ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides}, + e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, + e_g_n_k_wos_strides_{NeedTransposeKernel + ? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides) + : e_g_n_k_wos_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + num_group_{a_g_n_c_wis_lengths_[0]}, + conv_to_gemm_transformer_{a_g_n_c_wis_lengths_, + a_g_n_c_wis_strides_, + b_g_k_c_xs_lengths_, + b_g_k_c_xs_strides_, + e_g_n_k_wos_lengths_, + e_g_n_k_wos_strides_, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_}, + conv_N_per_block_{conv_to_gemm_transformer_.N_}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{ + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, + a_grid_desc_ak0_m_ak1_{ + MakeAGridDescriptor_AK0_M_AK1(conv_to_gemm_transformer_)}, + b_grid_desc_bk0_n_bk1_{ + MakeBGridDescriptor_BK0_N_BK1(conv_to_gemm_transformer_)}, + compute_ptr_offset_of_groups_{}, + compute_ptr_offset_of_n_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + // A/B/E Batch/N Stride + compute_ptr_offset_of_groups_.BatchStrideA_ = + CTranspose ? b_g_k_c_xs_strides_[0] * NumGroupsToMerge + : a_g_n_c_wis_strides_[0] * NumGroupsToMerge; + compute_ptr_offset_of_groups_.BatchStrideB_ = + CTranspose ? a_g_n_c_wis_strides_[0] * NumGroupsToMerge + : b_g_k_c_xs_strides_[0] * NumGroupsToMerge; + compute_ptr_offset_of_n_.BatchStrideA_ = + CTranspose ? 0 : a_g_n_c_wis_strides_[1] * conv_N_per_block_; + compute_ptr_offset_of_n_.BatchStrideB_ = + CTranspose ? a_g_n_c_wis_strides_[1] * conv_N_per_block_ : 0; + + // Deal with the awkward APointers / BPointers types and convert to variable length + // array of const void pointers. 
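// (For the strides chosen above: a workgroup with group index g and N-chunk
// index n then starts reading at, sketched for the input tensor,
//
//    p_in + g * in_stride_G * NumGroupsToMerge
//         + n * in_stride_N * conv_N_per_block_
//
// while the weight tensor only advances with g, since weights have no N
// dimension; under CTranspose the two operands swap GEMM roles, which is why
// the A/B stride assignments above are mirrored.)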
+ if constexpr(isMultiA) + { + p_as_grid_ = p_as; + } + else + { + p_as_grid_[0] = p_as; + } + if constexpr(isMultiB) + { + p_bs_grid_ = p_bs; + } + else + { + p_bs_grid_[0] = p_bs; + } + + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + // D batch stride + compute_ptr_offset_of_groups_.BatchStrideDs_(i) = + ds_g_n_k_wos_strides_[i][0] * NumGroupsToMerge; + compute_ptr_offset_of_n_.BatchStrideDs_(i) = + ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_; + + ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths_, + a_g_n_c_wis_strides_, + b_g_k_c_xs_lengths_, + b_g_k_c_xs_strides_, + e_g_n_k_wos_lengths_, + ds_g_n_k_wos_strides_[i], + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_}; + + // D desc + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_d); + }); + + compute_ptr_offset_of_groups_.BatchStrideE_ = + e_g_n_k_wos_strides_[0] * NumGroupsToMerge; + compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_; + + if constexpr(NeedTransposeKernel) + { + // Use not modified base strides + a_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides); + a_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides); + + b_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides); + b_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides); + + e_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides); + e_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides); + + elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{ + a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)}; + elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{ + b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)}; + elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapElementwise{ + e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)}; + } + + { + const index_t GemmM = a_grid_desc_ak0_m_ak1_.GetLength(I1); + const index_t GemmN = b_grid_desc_bk0_n_bk1_.GetLength(I1); + const auto MBlock = CTranspose ? GridwiseGemmCTranspose::CalculateMBlock(GemmN) + : GridwiseGemmCTranspose::CalculateMBlock(GemmM); + const auto NBlock = CTranspose ? 
GridwiseGemmCTranspose::CalculateNBlock(GemmM) + : GridwiseGemmCTranspose::CalculateNBlock(GemmN); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemmCTranspose::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_, MBlock, NBlock); + + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemmCTranspose::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_, MBlock, NBlock); + } + } + + std::size_t GetWorkspaceATensorSizeBytes() const + { + if constexpr(NeedTransposeKernel) + { + const long_index_t a_acum = ck::accumulate_n( + a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); + // Align to 128B + return math::integer_divide_ceil(sizeof(ADataType) * a_acum, 128) * 128; + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceBTensorSizeBytes() const + { + if constexpr(NeedTransposeKernel && + (is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW())) + { + const long_index_t b_acum = ck::accumulate_n( + b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); + // Align to 128B + return math::integer_divide_ceil(sizeof(BDataType) * b_acum, 128) * 128; + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceETensorSizeBytes() const + { + if constexpr(NeedTransposeKernel) + { + const long_index_t e_accum = ck::accumulate_n( + e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); + return sizeof(EDataType) * e_accum; + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceSizeBytes() const + { + return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() + + GetWorkspaceETensorSizeBytes(); + } + + // Calculate rotating memory buffer sizes ahead of time. The convolution to gemm + // transformer doesn't always lead to correct GetElementSpaceSize() results for the + // Tensor descriptor, so we have to do a bunch of ad-hoc corrections. There might be + // a better way to do this. + auto GetRotMemAsTensorSizeBytes() const + { + std::array size_as_buffers; + ck::index_t eff_num_group = num_group_ / NumGroupsToMerge; + + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType_single = remove_cvref_t>; + if constexpr(is_same_v || + is_same_v || + is_same_v) + { + size_as_buffers[i] = + (a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() + + (num_group_ - NumGroupsToMerge) * (a_g_n_c_wis_strides_[0])) * + sizeof(ADataType_single) / GridwiseGemm::APackedSize; + } + else + { + if(CTranspose && a_g_n_c_wis_lengths_[I1] > 1) + { + size_as_buffers[i] = (a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() + + (eff_num_group - 1) * (a_g_n_c_wis_strides_[0])) * + sizeof(ADataType_single) / GridwiseGemm::APackedSize; + } + else + { + size_as_buffers[i] = a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() * + eff_num_group * sizeof(ADataType_single) / + GridwiseGemm::APackedSize; + } + } + }); + + return size_as_buffers; + } + + auto GetRotMemBsTensorSizeBytes() const + { + std::array size_bs_buffers; + ck::index_t eff_num_group = num_group_ / NumGroupsToMerge; + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType_single = remove_cvref_t>; + size_bs_buffers[i] = b_grid_desc_bk0_n_bk1_.GetElementSpaceSize() * eff_num_group * + sizeof(BDataType_single) / GridwiseGemm::BPackedSize; + }); + + return size_bs_buffers; + } + + auto GetRotMemDsTensorSizeBytes() const + { + std::array size_ds_buffers; + ck::index_t eff_num_group = num_group_ / NumGroupsToMerge; + + // TODO: Ds packed size consideration? 
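// Sizing sketch for the loop below. In the general branch, the descriptor
// covers a single group, so the buffer size is
//
//   bytes_i = ds_grid_desc_m_n_[i].GetElementSpaceSize()
//             * (num_group_ / NumGroupsToMerge) * sizeof(DDataType);
//
// In the special-cased layouts the descriptor already spans consecutive
// groups, so only the remaining (num_group_ - NumGroupsToMerge) group strides
// are added on top instead of multiplying.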
+ static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + using DLayout = remove_cvref_t>; + + if constexpr(is_same_v || + is_same_v || + is_same_v) + { + size_ds_buffers[i] = + (ds_grid_desc_m_n_[i].GetElementSpaceSize() + + (num_group_ - NumGroupsToMerge) * ds_g_n_k_wos_strides_[i][0]) * + sizeof(DDataType); + } + else + { + if(CTranspose && ds_g_n_k_wos_lengths_[i][I1] > 1) + { + size_ds_buffers[i] = (ds_grid_desc_m_n_[i].GetElementSpaceSize() + + (eff_num_group - 1) * (ds_g_n_k_wos_strides_[i][0])) * + sizeof(DDataType); + } + else + { + size_ds_buffers[i] = ds_grid_desc_m_n_[i].GetElementSpaceSize() * + eff_num_group * sizeof(DDataType); + } + } + }); + + return size_ds_buffers; + } + + void Print() const + { + std::cout << "A[AK0, M, AK1]: " << a_grid_desc_ak0_m_ak1_ << std::endl; + std::cout << "B[BK0, N, BK1]: " << b_grid_desc_bk0_n_bk1_ << std::endl; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); + std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + } + + // private: + std::array p_as_grid_; + std::array p_bs_grid_; + const std::array p_ds_grid_; + EDataType* p_e_grid_; + + // for checking IsSupportedArgument() + std::array a_g_n_c_wis_lengths_; + std::array a_g_n_c_wis_strides_; + std::array b_g_k_c_xs_lengths_; + std::array b_g_k_c_xs_strides_; + std::array, NumDTensor> ds_g_n_k_wos_lengths_; + std::array, NumDTensor> ds_g_n_k_wos_strides_; + std::array e_g_n_k_wos_lengths_; + std::array e_g_n_k_wos_strides_; + std::array conv_filter_strides_; + std::array conv_filter_dilations_; + std::array input_left_pads_; + std::array input_right_pads_; + + // tensor descriptors for problem definiton + index_t num_group_; + + ConvToGemmFwdTransformer conv_to_gemm_transformer_; + index_t conv_N_per_block_; + + // tensor descriptors for block/thread-wise copy + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // for computing batch offset + ComputePtrOffset compute_ptr_offset_of_groups_; + ComputePtrOffset compute_ptr_offset_of_n_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // block-to-e-tile map + Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_, + elementwise_block_2_ctile_map_transpose_b_, elementwise_block_2_ctile_map_transpose_e_; + + NGCHWTransposeDescType a_in_transpose_desc_, e_out_transpose_desc_; + NHWGCTransposeDescType a_out_transpose_desc_, e_in_transpose_desc_; + GKCYXTransposeDescType b_in_transpose_desc_; + GKYXCTransposeDescType b_out_transpose_desc_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float RunGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("\033[035mCTranspose %d extraM %d extraN %d\033[0m\n", + CTranspose, + ABlockLdsExtraM, + BBlockLdsExtraN); + } + + float ave_time = 0; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 
1 : 2; + + const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1); + const index_t GemmK = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + const index_t num_workgroups_per_Conv_N = + arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = + CTranspose + ? GridwiseGemmCTranspose::CalculateGridSize(GemmN, GemmM, I1 /*arg.KBatch*/) + : GridwiseGemmCTranspose::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/); + + gdy = arg.num_group_ / NumGroupsToMerge; + gdz = num_workgroups_per_Conv_N; + + index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock; + const bool has_main_k_block_loop = + GridwiseGemmCTranspose::CalculateHasMainKBlockLoop(K_split); + const TailNumber tail_num = GridwiseGemmCTranspose::CalculateKBlockLoopTailNum(K_split); + + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("\033[092mnum_loop %d, has_main_k_loop %d, tail_num %d, G chunks %d, N " + "chunks %d\033[0m\n", + K_split / KPerBlock, + has_main_k_block_loop, + static_cast(tail_num), + gdy, + gdz); + } + + std::array p_as_grid = arg.p_as_grid_; + std::array p_bs_grid = arg.p_bs_grid_; + EDataType* p_e_grid = arg.p_e_grid_; + + // Transpose A and B, or just A. Not compatible with multi-AB. + if constexpr(NeedTransposeKernel) + { + static_assert(NumATensor == 1, "Num A Tensor should be 1\n"); + static_assert(NumBTensor == 1, "Num B Tensor should be 1\n"); + + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + p_as_grid[0] = type_convert(arg.p_workspace_); + p_bs_grid[0] = type_convert( + type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType)); + p_e_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); + } + else if constexpr(is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW()) + { + p_as_grid[0] = type_convert(arg.p_workspace_); + p_e_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); + } + } + + const auto Run = [&](const auto& kernel) { + if constexpr(CTranspose) + { + static_assert(NumATensor == 1, "Num A Tensor should be 1\n"); + static_assert(NumBTensor == 1, "Num B Tensor should be 1\n"); + + typename GridwiseGemmCTranspose::Argument gemm_arg{ + p_bs_grid, // p_bs_grid + p_as_grid, // p_as_grid + arg.p_ds_grid_, + p_e_grid, + + GemmN, + GemmM, + + GemmK, + // No need to set strides, we pass descs to kernel + {}, // StrideBs + {}, // StrideAs + {}, // StrideDs + I0, // StrideE + I1, // kbatch + arg.b_element_op_, + arg.a_element_op_, + arg.cde_element_op_}; + // TODO: No is_reduce argument, defaults to false. 
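// Why the swapped order above: with CTranspose the kernel evaluates
// E^T = B^T * A^T, so the weights take the "A" slot and the input the "B"
// slot. Hence p_bs_grid/p_as_grid, GemmN/GemmM and b_element_op_/a_element_op_
// are passed mirrored, while the Stride* fields stay empty because the
// descriptors are handed to the kernel directly.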
+ + if(stream_config.flush_cache && !NeedTransposeKernel) + { + typename GridwiseGemmCTranspose::Argument gemm_arg_ = gemm_arg; + + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("\033[032mUsing rotating memory num group %d eff %d!\033[0m\n", + arg.num_group_, + arg.num_group_ / NumGroupsToMerge); + } + + ck::utility::RotatingMemWrapperMultiABD< + typename GridwiseGemmCTranspose::Argument, + GemmBsDataType, + GemmAsDataType, + DsDataType> + rotating_mem(gemm_arg_, + stream_config.rotating_count, + arg.GetRotMemBsTensorSizeBytes(), + arg.GetRotMemAsTensorSizeBytes(), + arg.GetRotMemDsTensorSizeBytes()); + + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear E mem + + // TODO: The calculation of the E buffer size may not be correct in all + // cases, for example if the memory is not contiguous due to padding or + // unusual strides. Investigate when implementing splitK. It may be + // safer to use GetElementSpaceSize(). + + // if(arg_.KBatch > 1) + // HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_e_grid, + // 0, + // arg_.M * arg_.N * + // sizeof(EDataType), + // stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg_, + arg.b_grid_desc_bk0_n_bk1_, + arg.a_grid_desc_ak0_m_ak1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_, + KPerBlock); // TODO: splitK consideration (num_k_per_block) + } + else + { + ave_time += launch_and_time_kernel( + stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + arg.b_grid_desc_bk0_n_bk1_, + arg.a_grid_desc_ak0_m_ak1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_, + KPerBlock); // TODO: splitK consideration (num_k_per_block) + } + } + else + { + typename GridwiseGemm::Argument gemm_arg{ + p_as_grid, // p_as_grid + p_bs_grid, // p_bs_grid + arg.p_ds_grid_, + p_e_grid, + GemmM, + GemmN, + GemmK, + // No need to set strides, we pass descs to kernel + {}, // StrideAs + {}, // StrideBs + {}, // StrideDs + I0, // StrideE + I1, // kbatch + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_}; + // TODO: No is_reduce argument, defaults to false. + + if(stream_config.flush_cache && !NeedTransposeKernel) + { + typename GridwiseGemm::Argument gemm_arg_ = gemm_arg; + + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("\033[032mUsing rotating memory num group %d eff %d!\033[0m\n", + arg.num_group_, + arg.num_group_ / NumGroupsToMerge); + } + + ck::utility::RotatingMemWrapperMultiABD + rotating_mem(gemm_arg_, + stream_config.rotating_count, + arg.GetRotMemAsTensorSizeBytes(), + arg.GetRotMemBsTensorSizeBytes(), + arg.GetRotMemDsTensorSizeBytes()); + + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + + // TODO: The calculation of the E buffer size may not be correct in all + // cases, for example if the memory is not contiguous due to padding or + // unusual strides. Investigate when implementing splitK. It may be + // safer to use GetElementSpaceSize(). 
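// A hedged sketch of that safer variant (kept commented out, like the
// original reset below; assumes GetElementSpaceSize() also covers any
// padding gaps):
//
//   if(gemm_arg_.KBatch > 1)
//       HIP_CHECK_ERROR(hipMemsetAsync(gemm_arg_.p_e_grid,
//                                      0,
//                                      arg.e_grid_desc_m_n_.GetElementSpaceSize() *
//                                          sizeof(EDataType),
//                                      stream_config.stream_id_));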
+ + // if(arg_.KBatch > 1) + // HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_e_grid, + // 0, + // arg_.M * arg_.N * + // sizeof(EDataType), + // stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_, + KPerBlock); // TODO: splitK consideration (num_k_per_block) + } + else + { + ave_time += launch_and_time_kernel( + stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_, + KPerBlock); // TODO: splitK consideration (num_k_per_block) + } + } + }; + + auto CreateAndRunKernel = [&](auto has_main_k_block_loop_, auto tail_number_) { + constexpr bool has_loop = decltype(has_main_k_block_loop_)::value; + constexpr TailNumber tn = tail_number_; + + if constexpr(CTranspose) + { + const auto kernel = kernel_grouped_conv_fwd_wmma_cshuffle_v3< + GridwiseGemmCTranspose, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffset, + has_loop, // HasMainKBlockLoop + InMemoryDataOperationEnum::Set, + minimum_occupancy, + tn>; // TailNumber TailNum = TailNumber::Full + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_fwd_wmma_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffset, + has_loop, // HasMainKBlockLoop + InMemoryDataOperationEnum::Set, + minimum_occupancy, + tn>; // TailNumber TailNum = TailNumber::Full + Run(kernel); + } + }; + + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(has_main_k_block_loop && tail_num == TailNumber::Full) + { + CreateAndRunKernel(std::integral_constant{}, + std::integral_constant{}); + } + else if(!has_main_k_block_loop && tail_num == TailNumber::Full) + { + CreateAndRunKernel(std::integral_constant{}, + std::integral_constant{}); + } + else + { + printf("Invalid has_main_k_block_loop and tail_num combination for V1!\n"); + return 0.0f; + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(has_main_k_block_loop && tail_num == TailNumber::Full) + { + CreateAndRunKernel(std::integral_constant{}, + std::integral_constant{}); + } + else if(!has_main_k_block_loop && tail_num == TailNumber::Even) + { + CreateAndRunKernel(std::integral_constant{}, + std::integral_constant{}); + } + else if(!has_main_k_block_loop && tail_num == TailNumber::Odd) + { + CreateAndRunKernel(std::integral_constant{}, + std::integral_constant{}); + } + else + { + printf("Invalid has_main_k_block_loop and tail_num combination for V3!\n"); + return 0.0f; + } + } + else + { + printf("Invalid pipeline version!\n"); + return 0.0f; + } + + return ave_time; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float avg_time = 0.f; + if 
constexpr(!isMultiABD) + { + // At least transpose A from NGCHW to NHWGC, and if necessary transpose B from GKCYX + // to GKYXC. + if constexpr(NeedTransposeKernel) + { + static_assert(NumATensor == 1, "Num A Tensor should be 1\n"); + static_assert(NumBTensor == 1, "Num B Tensor should be 1\n"); + + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("\033[32mPerforming transpose forward\033[0m\n"); + } + const index_t a_grid_size = + arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize( + arg.a_in_transpose_desc_); + const index_t b_grid_size = + (is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + ? arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize( + arg.b_in_transpose_desc_) + : 0; // Dont run transpose B if not needed + + ADataType* p_a_out_grid = type_convert(arg.p_workspace_); + BDataType* p_b_out_grid = + type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); + + auto kernel_transpose = + kernel_elementwise_dual, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + Block2TileMapElementwise, + element_wise::PassThrough>; + + avg_time += launch_and_time_kernel( + stream_config, + kernel_transpose, + dim3(a_grid_size + b_grid_size), + dim3(ElementwiseBlocksize), + 0, + make_tuple(arg.a_in_transpose_desc_), + make_tuple(arg.b_in_transpose_desc_), + make_tuple(arg.a_out_transpose_desc_), + make_tuple(arg.b_out_transpose_desc_), + make_tuple(type_convert(arg.p_as_grid_[0])), + make_tuple(type_convert(arg.p_bs_grid_[0])), + make_tuple(p_a_out_grid), + make_tuple(p_b_out_grid), + arg.elementwise_block_2_ctile_map_transpose_a_, + arg.elementwise_block_2_ctile_map_transpose_b_, + element_wise::PassThrough{}, + a_grid_size); + } + + avg_time += RunGemm(arg, stream_config); + + // Transpose result back to NGCHW + if constexpr(NeedTransposeKernel) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("\033[32mPerforming transpose back\033[0m\n"); + } + const index_t grid_size = + arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize( + arg.e_in_transpose_desc_); + + const EDataType* p_e_in_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); + + EDataType* p_e_out_grid = arg.p_e_grid_; + + auto kernel_transpose = kernel_elementwise, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + element_wise::PassThrough>; + + avg_time += + launch_and_time_kernel(stream_config, + kernel_transpose, + dim3(grid_size), + dim3(ElementwiseBlocksize), + 0, + make_tuple(arg.e_in_transpose_desc_), + make_tuple(arg.e_out_transpose_desc_), + make_tuple(p_e_in_grid), + make_tuple(p_e_out_grid), + arg.elementwise_block_2_ctile_map_transpose_e_, + element_wise::PassThrough{}); + } + } + return avg_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + namespace ctc = tensor_layout::convolution; + + const index_t G = arg.b_g_k_c_xs_lengths_[I0]; + const index_t K = arg.b_g_k_c_xs_lengths_[I1]; + const index_t C = arg.b_g_k_c_xs_lengths_[I2]; + + const index_t input_spatial_acum = ck::accumulate_n( + arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); + const index_t output_spatial_acum = ck::accumulate_n( + arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 
1, std::multiplies<>()); + + // Move this to runtime check to align Conv instances + // with Conv Multiple D instances + if constexpr(isMultiABD) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "The MultiABD is not supported!" << " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1 && + BlkGemmPipelineVer != BlockGemmPipelineVersion::v3) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported pipeline version!" << " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported())) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Current device does not support wmma instructions!" << " In " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + + // check ConvolutionForwardSpecialization + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t SpatialDim = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t ConvStride = arg.conv_filter_strides_[i]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "The input parameters do not align with specialization " + "Filter1x1Stride1Pad0!" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t SpatialDim = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "The input parameters do not align with specialization " + "Filter1x1Pad0!" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter3x3) + { + if(C != 1) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "When using 3x3 ConvSpec C must be 1!" << " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t filter_spatial_dim = arg.b_g_k_c_xs_lengths_[i + I3]; + + if(filter_spatial_dim != I3) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Filter spatial dims do not match 3x3 ConvSpec!" << " In " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + } + + if constexpr(NumGroupsToMerge > 1) + { + if(!(C == 1)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "When using mergegroups C must be 1!" 
<< " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + if(G % NumGroupsToMerge != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Number of groups must be devisable by NumGroupsToMerge!" << " In " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + if constexpr(!(is_NSpatialGC_GKSpatial_NSpatialGK() || + is_NGCSpatial_GKSpatial_NGKSpatial() || + is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW())) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported layout in combination with mergegroups!" << " In " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + // check vector access of A + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + NeedTransposeKernel) + { + // TODO: This check originally said "ABlockTransferSrcVectorDim == 2", basically + // blocking all instances with a value of 1. I've tried some though and they work just + // fine. So I changed it to allow a value of 1 for now but there might be cases where + // this does not work. + // Check access per C + if(!(ABlockTransferSrcVectorDim <= 2 && C % ABlockTransferSrcScalarPerVector == 0)) + { + // If not possible, check access per G + if(!(ABlockTransferSrcVectorDim == 1 && (C == 1 || NumGroupsToMerge == 1) && + (is_NSpatialGC_GKSpatial_NSpatialGK() || + is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) && + G % ABlockTransferSrcScalarPerVector == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[A Layout] The number of input channels is not a multiple of " + "ABlockTransferSrcScalarPerVector!" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + } + else if constexpr(is_same_v || is_same_v) + { + static_assert(NeedTransposeKernel == false); + static_assert(NumGroupsToMerge == 1); + + if constexpr(ABlockTransferSrcScalarPerVector != 1) + { + if(ABlockTransferSrcVectorDim != 1) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "ABlockTransferSrcVectorDim must be 1!" << " In " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + if(input_spatial_acum % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[A Layout] The number of input channels is not a multiple of " + "ABlockTransferSrcScalarPerVector!" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + } + else + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported A Layout!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + + // check vector access of B + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v) + + { + if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[B Layout] The number of input channels is not a multiple of " + "BBlockTransferSrcScalarPerVector!" 
+ << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported B Layout!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + + // Check vector access of Ds + bool valid = 1; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v) + { + if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[D Layout] D tensor number " << i + << " has a K size which is not a multiple of " + "CDEBlockTransferScalarPerVector_NPerBlock!" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + valid = 0; + } + + if constexpr(is_same_v) + { + // G and K must be the same + if(arg.ds_g_n_k_wos_lengths_[i][0] != arg.e_g_n_k_wos_lengths_[0] || + arg.ds_g_n_k_wos_lengths_[i][2] != arg.e_g_n_k_wos_lengths_[2]) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[D Layout] D tensor number " << i + << " shape does not match E shape! (GK case)" << " In " + << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + valid = 0; + } + } + else + { + // E and D must have the same shape + for(index_t d = 0; d < NDimSpatial + 3; d++) + { + if(arg.ds_g_n_k_wos_lengths_[i][d] != arg.e_g_n_k_wos_lengths_[d]) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[D Layout] D tensor number " << i + << " shape does not match E shape! (generic case)" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + valid = 0; + } + } + } + } + else + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[D Layout] D tensor number " << i << " has an unknown layout!" 
+ << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + valid = 0; + } + }); + + if(!valid) + return false; + + if constexpr(NeedTransposeKernel) + { + if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[NGCHW Layout] The G * C is not a multiple of " + "CDEBlockTransferScalarPerVector_NPerBlock" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + + if((G * K) % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[NGCHW Layout] The G * K is not a multiple of " + "CDEBlockTransferScalarPerVector_NPerBlock" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[NGCHW Layout] The input_spatial_acum is not a multiple of " + "CDEBlockTransferScalarPerVector_NPerBlock" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[NGCHW Layout] The output_spatial_acum is not a multiple of " + "CDEBlockTransferScalarPerVector_NPerBlock" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(!arg.p_workspace_) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout + << "Warning: Workspace for " + "DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3::Argument is not " + "allocated, use SetWorkSpacePointer." + << std::endl; + } + return false; + } + + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + if(!(arg.a_out_transpose_desc_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + arg.e_in_transpose_desc_.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[NGCHW Layout] One of the transposed vectors is exceeding 2GB " + "memory size!" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + // check vector access of E + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v) + { + if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[E Layout] The K is not a multiple of " + "CDEBlockTransferScalarPerVector_NPerBlock" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported E Layout!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + + // Gridwise gemm v3 doesn't verify descriptors size + if(!arg.conv_to_gemm_transformer_.AreDescriptorsSmallerThan2GB()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout + << "[conv_to_gemm_transformer_] One of the descriptors is bigger than 2GB!" 
+ << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + + // check Gridwise GEMM + const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1); + const index_t GemmK = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + if constexpr(CTranspose) + { + typename GridwiseGemmCTranspose::Argument gemm_arg{{nullptr}, + {nullptr}, + {}, + nullptr, + GemmN, + GemmM, + GemmK, + {I0}, + {I0}, + {}, + I0, + I1 /*KBatch*/, + arg.b_element_op_, + arg.a_element_op_, + arg.cde_element_op_}; + // TODO: No is_reduce argument, defaults to false. + + return GridwiseGemmCTranspose::CheckValidity(gemm_arg, true); // allow_short_v3_pipe + } + else + { + typename GridwiseGemmCTranspose::Argument gemm_arg{{nullptr}, + {nullptr}, + {}, + nullptr, + GemmM, + GemmN, + GemmK, + {I0}, + {I0}, + {}, + I0, + I1 /*KBatch*/, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_}; + // TODO: No is_reduce argument, defaults to false. + + return GridwiseGemmCTranspose::CheckValidity(gemm_arg, true); // allow_short_v3_pipe + } + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + APointers p_as, + BPointers p_bs, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + return Argument{p_as, + p_bs, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto + MakeArgument(APointers p_as, + BPointers p_bs, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + std::array a_g_n_c_wis_lengths_i32; + std::array a_g_n_c_wis_strides_i32; + std::array b_g_k_c_xs_lengths_i32; + std::array b_g_k_c_xs_strides_i32; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i32; + std::array, NumDTensor> ds_g_n_k_wos_strides_i32; + std::array e_g_n_k_wos_lengths_i32; + std::array e_g_n_k_wos_strides_i32; + 
std::array conv_filter_strides_i32; + std::array conv_filter_dilations_i32; + std::array input_left_pads_i32; + std::array input_right_pads_i32; + + array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i32, conv_filter_strides); + array_convert(conv_filter_dilations_i32, conv_filter_dilations); + array_convert(input_left_pads_i32, input_left_pads); + array_convert(input_right_pads_i32, input_right_pads); + + return Argument{p_as, + p_bs, + p_ds, + p_e, + a_g_n_c_wis_lengths_i32, + a_g_n_c_wis_strides_i32, + b_g_k_c_xs_lengths_i32, + b_g_k_c_xs_strides_i32, + ds_g_n_k_wos_lengths_i32, + ds_g_n_k_wos_strides_i32, + e_g_n_k_wos_lengths_i32, + e_g_n_k_wos_strides_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + APointers p_as, + BPointers p_bs, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + return std::make_unique(p_as, + p_bs, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr + MakeArgumentPointer(APointers p_as, + BPointers p_bs, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + std::array a_g_n_c_wis_lengths_i32; + std::array a_g_n_c_wis_strides_i32; + std::array b_g_k_c_xs_lengths_i32; + std::array b_g_k_c_xs_strides_i32; + std::array, NumDTensor> 
ds_g_n_k_wos_lengths_i32; + std::array, NumDTensor> ds_g_n_k_wos_strides_i32; + std::array e_g_n_k_wos_lengths_i32; + std::array e_g_n_k_wos_strides_i32; + std::array conv_filter_strides_i32; + std::array conv_filter_dilations_i32; + std::array input_left_pads_i32; + std::array input_right_pads_i32; + + array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i32, conv_filter_strides); + array_convert(conv_filter_dilations_i32, conv_filter_dilations); + array_convert(input_left_pads_i32, input_left_pads); + array_convert(input_right_pads_i32, input_right_pads); + + return std::make_unique(p_as, + p_bs, + p_ds, + p_e, + a_g_n_c_wis_lengths_i32, + a_g_n_c_wis_strides_i32, + b_g_k_c_xs_lengths_i32, + b_g_k_c_xs_strides_i32, + ds_g_n_k_wos_lengths_i32, + ds_g_n_k_wos_strides_i32, + e_g_n_k_wos_lengths_i32, + e_g_n_k_wos_strides_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << getConvForwardSpecializationString(ConvForwardSpecialization) << ", " + << MPerWmma << ", " + << NPerWmma << ", " + << MRepeat << ", " + << NRepeat << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CDEBlockTransferScalarPerVector_NPerBlock << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle << ", " + << "BlkGemmPipelineScheduler: " + << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " + << "BlkGemmPipelineVersion: " + << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", " + << NumGroupsToMerge + << ">"; + // clang-format on + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->GetWorkspaceSizeBytes(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle::Argument structure!"); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->p_workspace_ = p_workspace; + } + else + throw 
std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle::Argument structure!"); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp index 83a59d9f8ed..1f88b692134 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp @@ -198,10 +198,10 @@ struct ComputePtrOffsetOfStridedBatch(g_idx) * BatchStrideE_; } - Array BatchStrideA_; - Array BatchStrideB_; - Array BatchStrideDs_; - long_index_t BatchStrideE_; + Array BatchStrideA_{}; + Array BatchStrideB_{}; + Array BatchStrideDs_{}; + long_index_t BatchStrideE_{}; long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D }; @@ -253,10 +253,10 @@ struct ComputePtrOffsetOfStridedBatch(g_idx) * BatchStrideE_; } - long_index_t BatchStrideA_; - long_index_t BatchStrideB_; - Array BatchStrideDs_; - long_index_t BatchStrideE_; + long_index_t BatchStrideA_{}; + long_index_t BatchStrideB_{}; + Array BatchStrideDs_{}; + long_index_t BatchStrideE_{}; long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index fea01023375..3c698b05dea 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -762,6 +762,144 @@ struct GridwiseGemm_wmma_cshuffle_v3 { return Block2CTileMap{problem.M, problem.N, 4}; } + + // Run method for convolution (grid descriptors are passed as arguments, + // not generated internally) + template + __device__ static void Run(void* p_shared, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const ComputePtrOffsetOfBatch& compute_ptr_offset_of_batch, + const ComputePtrOffsetOfN& compute_ptr_offset_of_n, + [[maybe_unused]] const index_t num_k_per_block, + Argument& karg, + EpilogueArgument& epilogue_args) + { + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / karg.KBatch); + // offset base pointer for each work-group + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + const long_index_t a_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const long_index_t b_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetBPtrOffset(n_idx)); + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + + const auto ds_n_offset = 
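+ // Unlike the scalar A/B/E offsets above (broadcast from lane 0 via
+ // amd_wave_read_first_lane so they stay wavefront-uniform), the Ds offsets
+ // come back as one entry per D tensor.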
compute_ptr_offset_of_n.GetDsPtrOffset(n_idx); + + AsGridPointer p_as_grid_; + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType_ = remove_cvref_t>; + p_as_grid_(i) = + static_cast(karg.p_as_grid[i]) + a_batch_offset + a_n_offset; + }); + + BsGridPointer p_bs_grid_; + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType_ = remove_cvref_t>; + p_bs_grid_(i) = + static_cast(karg.p_bs_grid[i]) + b_batch_offset + b_n_offset; + }); + + DsGridPointer p_ds_grid_grp; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType_ = remove_cvref_t>; + p_ds_grid_grp(i) = static_cast(karg.p_ds_grid[i]) + + ds_batch_offset[i] + ds_n_offset[i]; + }); + + // Currently supporting one A and one B + const auto as_grid_desc_ak0_m_ak1 = generate_tuple( + [&](auto i) { + ignore = i; + return a_grid_desc_ak0_m_ak1; + }, + Number{}); + + const auto bs_grid_desc_bk0_n_bk1 = generate_tuple( + [&](auto i) { + ignore = i; + return b_grid_desc_bk0_n_bk1; + }, + Number{}); + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // AScale struct (Empty) + using AScale = typename BlockwiseGemmPipe::Empty; + auto a_scale_struct = AScale{}; + + // BScale struct (Empty) + using BScale = typename BlockwiseGemmPipe::Empty; + auto b_scale_struct = BScale{}; + + const index_t num_k_block_per_scale = GetKBlockPerScale(); + + Base::template Run(p_as_grid_, + p_bs_grid_, + p_ds_grid_grp, + karg.p_e_grid + e_batch_offset + e_n_offset, + p_shared, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + karg.a_element_op, + karg.b_element_op, + karg.cde_element_op, + block_m_id, + block_n_id, + num_k_block_per_scale, + a_scale_struct, + b_scale_struct, + epilogue_args); + } }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 81aa1ac9866..9cef618bf3c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -344,11 +344,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base // return block_id to C matrix tile idx (m0, n0) mapping using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + // Calculate grid size taking into account splitk (KBatch) + // 2D grid (x,z) __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) { return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); } + // Calculate grid size taking into account splitk (KBatch) and multiple groups (Batch) + // 3D grid (x,y,z) + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), KBatch, Batch); + } + __host__ static auto CalculateMPadded(index_t M) { return 
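+ // round M up to the next multiple of MPerBlock,
+ // e.g. M = 1000 with MPerBlock = 128 pads to 1024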
math::integer_least_multiple(M, MPerBlock); @@ -606,8 +615,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_base } template - __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock) + __device__ __host__ static constexpr auto + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc& ds_grid_desc_m_n, + index_t MBlock, + index_t NBlock) { return generate_tuple( [&](auto i) { @@ -706,8 +717,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_base ReduceTrait>; template - __device__ static constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - const DEGridDesc& de_grid_desc_m_n, index_t MBlock, index_t NBlock) + __host__ __device__ static constexpr auto + MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DEGridDesc& de_grid_desc_m_n, + index_t MBlock, + index_t NBlock) { const auto de_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( de_grid_desc_m_n, @@ -721,7 +734,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template - __host__ static constexpr bool CheckValidity(const Argument& karg) + __host__ static constexpr bool CheckValidity(const Argument& karg, + bool allow_short_v3_pipe = false) { static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && (NPerBlock % (NPerWmma * NRepeat)) == 0, @@ -912,14 +926,17 @@ struct GridwiseGemm_wmma_cshuffle_v3_base { if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + if(!(allow_short_v3_pipe && BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)) { - std::cout << "Pipeline validation failed: num_k_loop (" << num_k_loop - << ") <= PrefetchStages (" << BlockwiseGemmPipe::PrefetchStages - << ") for pipeline version != v1." << __FILE__ << ":" << __LINE__ - << ", in function: " << __func__ << std::endl; + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Pipeline validation failed: num_k_loop (" << num_k_loop + << ") <= PrefetchStages (" << BlockwiseGemmPipe::PrefetchStages + << ") for pipeline version != v1." << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; } - return false; } } diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp index 058cf58ec98..60734d34939 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp @@ -286,8 +286,25 @@ struct ThreadwiseTensorSliceTransfer_v7r3 static_for<0, nDst, 1>{}([&](auto i) { using elm_vector_t = typename remove_cvref_t::type; - elm_vectors(i).template AsType()(I0) = - oob_val ? elm_vectors(i).template AsType()[I0] : elm_vector_t{0}; + using scalar_t = std::remove_cvref_t< + decltype(elm_vectors(i).template AsType()[I0])>; + + // This is a bit ugly but necessary to be able to compile f8 instances for grouped + // convolution forward. For some reason for that specific type there is an ambiguity + // in the type resolution for the ternary expression. I added an explicit cast to + // disambiguate and only use it for f8 just in case it affects performance. + if constexpr(std::is_same_v) + { + elm_vectors(i).template AsType()(I0) = + oob_val ? 
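+ // Wrapping both ?: operands in elm_vector_t{...} gives them identical
+ // prvalue types, which removes the overload-resolution ambiguity that
+ // f8_t otherwise triggers here.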
elm_vector_t{elm_vectors(i).template AsType()[I0]} + : elm_vector_t{0}; + } + else + { + elm_vectors(i).template AsType()(I0) = + oob_val ? elm_vectors(i).template AsType()[I0] + : elm_vector_t{0}; + } }); elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_bilinear_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_bilinear_instance.hpp new file mode 100644 index 00000000000..57bbe649138 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_bilinear_instance.hpp @@ -0,0 +1,137 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +// BF16 instances +template +using device_grouped_conv_fwd_wmma_cshufflev3_bilinear_bf16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + 
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, 
S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 
1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 256, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, 
BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +// F16 instances +template +using device_grouped_conv_fwd_wmma_cshufflev3_bilinear_f16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 128, 
128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 64, 32, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 
2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 
2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_binary_outelementop_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_binary_outelementop_instance.hpp new file mode 100644 index 00000000000..a864804a121 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_binary_outelementop_instance.hpp @@ -0,0 +1,95 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +#ifdef CK_ENABLE_FP8 +using F8 = ck::f8_t; +#endif + +#ifdef CK_ENABLE_BF8 +using BF8 = ck::bf8_t; +#endif + +template +using S = ck::Sequence; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +#ifdef CK_ENABLE_FP8 + +template +using device_grouped_conv_fwd_wmma_cshufflev3_binary_outelementop_f8_instances = std::tuple< +// clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| 
AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#ifdef CK_ENABLE_FP8 + // generic instance + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, 
OutElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 64, 32, 8, 8, 16, 16, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 64, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 32, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 32, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, 
PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8> +#endif +#endif + // clang-format on + >; + +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_instance.hpp new file mode 100644 index 00000000000..0ad528935ab --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_instance.hpp @@ -0,0 +1,140 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; +using I8 = int8_t; + +template +using S = ck::Sequence; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using DynamicUnaryOp = ck::tensor_operation::element_wise::DynamicUnaryOp; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +template +using device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_bf16_instances = + std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| 
WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, 
BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 
64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 256, 256, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 
8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +template +using device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_f16_instances = + std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, 
PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 64, 32, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + 
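// The 96-wide NPerBlock tiles that follow rely on GemmMNKPadding to handle GEMM extents that are not tile multiples, at the cost of some padding overhead. +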
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, 
BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp new file mode 100644 index 00000000000..ff627408703 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp @@ -0,0 +1,162 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
+#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using BF16 = ck::bhalf_t;
+using F16 = ck::half_t;
+using F32 = float;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using Empty_Tuple = ck::Tuple<>;
+
+using namespace ck::tensor_layout::convolution;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using AddClamp = ck::tensor_operation::element_wise::AddClamp;
+using Clamp = ck::tensor_operation::element_wise::Clamp;
+
+static constexpr auto ConvFwdDefault =
+ ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
+
+static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0;
+
+static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
+
+static constexpr auto ConvFwdOddC =
+ ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC;
+
+static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
+
+template ,
+ typename OutElementOp = PassThrough>
+using device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances = std::tuple<
+ // clang-format off
+ //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version |
+ //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | |
+ //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | |
+ //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+ // generic instance
+ DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
+#ifndef ONE_INSTANCE_PER_LIST
+ ,
+ DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + 
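// Entries below toggle the ABlockLds/BBlockLds AddExtra flags: 1 pads the LDS tile to sidestep bank conflicts, 0 saves the extra LDS at the risk of conflicts. +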
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 
2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +template , + typename OutElementOp = PassThrough> +using device_grouped_conv_fwd_wmma_cshufflev3_f16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version | + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + 
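// As with the bf16 list, the generic instance above keeps every ScalarPerVector at 1, so it stays applicable when the fastest-varying dimension does not permit vectorized access. +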
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 
2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 
32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_outelementop_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_outelementop_instance.hpp new file mode 100644 index 00000000000..64304f9bf00 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_outelementop_instance.hpp @@ -0,0 +1,275 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +#ifdef CK_ENABLE_FP8 +using F8 = ck::f8_t; +#endif + +#ifdef CK_ENABLE_BF8 +using BF8 = ck::bf8_t; +#endif + +template +using S = ck::Sequence; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +#ifdef CK_ENABLE_FP8 + +template +using device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_instances = std::tuple< +// clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| 
Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#ifdef CK_ENABLE_FP8 + // generic instance + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 
8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 64, 32, 8, 8, 16, 16, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 64, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 32, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 32, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 
96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8> +#endif +#endif + // clang-format on + >; + +#endif + +#ifdef CK_ENABLE_BF8 + +template +using device_grouped_conv_fwd_wmma_cshufflev3_outelementop_bf8_instances = std::tuple< +// clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#ifdef CK_ENABLE_BF8 + // generic instance + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 
2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 64, 32, 8, 8, 16, 16, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 64, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 
4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 32, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 32, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8> +#endif +#endif + // clang-format on + >; + +#endif + +#if defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8) + +template +using device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_bf8_instances = std::tuple< +// clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| 
Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) + // generic instance + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, BF8> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, 
PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 64, 32, 8, 8, 16, 16, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 64, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 32, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 32, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, BF8>, + 
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, BF8> +#endif +#endif + // clang-format on + >; + +template +using device_grouped_conv_fwd_wmma_cshufflev3_outelementop_bf8_f8_instances = std::tuple< +// clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) + // generic instance + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8, F8> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8, F8>, + 
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 64, 32, 8, 8, 16, 16, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 64, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, 
BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 32, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 32, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF8, F8> +#endif +#endif + // clang-format on + >; + +#endif + +#ifdef CK_ENABLE_FP8 + +template +using device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_f8_f32_instances = std::tuple< +// clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| 
BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#ifdef CK_ENABLE_FP8 + // generic instance + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, F8> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 
2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 64, 32, 8, 8, 16, 16, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 64, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 32, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 32, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 
1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8, F8> +#endif +#endif + // clang-format on + >; + +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scale_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scale_instance.hpp new file mode 100644 index 00000000000..ce69fe6623f --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scale_instance.hpp @@ -0,0 +1,138 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; +using I8 = int8_t; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +// BF16 instances +template <index_t NDimSpatial, typename ALayout, typename BLayout, typename DsLayout, typename ELayout, ConvolutionForwardSpecialization ConvSpec> +using device_grouped_conv_fwd_wmma_cshufflev3_scale_bf16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, +
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, 
PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 
16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 256, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, 
S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +// F16 instances +template <index_t NDimSpatial, typename ALayout, typename BLayout, typename DsLayout, typename ELayout, ConvolutionForwardSpecialization ConvSpec> +using device_grouped_conv_fwd_wmma_cshufflev3_scale_f16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3,
F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 64, 32, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 
8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_instance.hpp new file mode 100644 index 00000000000..d97cd6f04c7 --- /dev/null +++ 
b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_instance.hpp @@ -0,0 +1,147 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; +using I8 = int8_t; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; +static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0; +static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +template <index_t NDimSpatial, typename ALayout, typename BLayout, typename DsLayout, typename ELayout, ConvolutionForwardSpecialization ConvSpec> +using device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances = + std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version | + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8,
BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 
0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, 
GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 256, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, 
F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +template +using device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances = + std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version | + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | | + 
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 
2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 
256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, 
ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 256, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_scaleadd_relu_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_scaleadd_relu_instance.hpp new file mode 
100644 index 00000000000..6c9fb5c4545 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_scaleadd_relu_instance.hpp @@ -0,0 +1,139 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; +using I8 = int8_t; + +template +using S = ck::Sequence; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ScaleAddScaleAddRelu = ck::tensor_operation::element_wise::ScaleAddScaleAddRelu; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +template +using device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_scaleadd_relu_bf16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, 
BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 
0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, 
ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 256, 256, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + 
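For reference, the ScaleAddScaleAddRelu operation named as the CDE elementwise functor in every instance above is applied in the C-shuffle epilogue, combining the convolution accumulator with the two extra D tensors before the store. A minimal host-side sketch of one plausible per-element reading follows; the member names and the grouping of the two scale-and-add steps are assumptions taken from the operator's name, not CK's authoritative definition in element_wise_operation.hpp.

#include <algorithm>

// Hedged reference model of ScaleAddScaleAddRelu: two scale-and-add steps,
// then a ReLU. Member names and the exact association of the scales with
// the adds are illustrative assumptions.
struct ScaleAddScaleAddReluRef
{
    float scale0; // assumed scale for the first add
    float scale1; // assumed scale for the second add

    float operator()(float c, float d0, float d1) const
    {
        const float t = scale0 * c + d0; // first scale-and-add
        const float u = scale1 * t + d1; // second scale-and-add
        return std::max(u, 0.0f);        // ReLU clamps negatives to zero
    }
};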
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +template +using device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_scaleadd_relu_f16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveNPerWmma| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, ScaleAddScaleAddRelu, 
ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 64, 32, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + 
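Each tuple in these headers repeats the compile-time trimming pattern visible above: the leading generic instance is always compiled as a safe fallback, while the tuned configurations are appended only when ONE_INSTANCE_PER_LIST is not defined, which keeps minimal or debug builds down to a single instantiation per list. A stripped-down, self-contained sketch of the pattern:

#include <tuple>

// Placeholder kernel types; in the real headers these are full
// DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 instantiations.
struct GenericInstance {};
struct TunedInstanceA {};
struct TunedInstanceB {};

using instance_list = std::tuple<
    GenericInstance // always present: the safe fallback configuration
#ifndef ONE_INSTANCE_PER_LIST
    ,
    TunedInstanceA, // tuned variants, skipped when the macro is defined
    TunedInstanceB
#endif
    >;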
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 2, 
BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index 5089ea2c1e5..6a6b8b1d600 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -24,8 +24,14 @@ #include "grouped_convolution_forward_mem_intra_xdl.inc" #endif #ifdef CK_USE_WMMA +#define CK_USE_WMMA_OLD +#ifdef CK_USE_WMMA_OLD #include "grouped_convolution_forward_wmma.inc" #endif +#include "grouped_convolution_forward_wmma_cshufflev3.inc" +#include "grouped_convolution_forward_mem_inter_wmma_cshufflev3.inc" +#include "grouped_convolution_forward_mem_intra_wmma_cshufflev3.inc" +#endif namespace ck { namespace tensor_operation { @@ -652,7 +658,9 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { @@ -766,6 +774,69 @@ struct DeviceOperationInstanceFactory && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances( + op_ptrs); + // add_device_grouped_conv2d_fwd_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances( + // op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances( + op_ptrs); + // add_device_grouped_conv2d_fwd_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( + // op_ptrs); + } +#endif + } + + // 3D + // layout NDHWGC/GKZYXC/NDHWGK + if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances( + op_ptrs); + // add_device_grouped_conv3d_fwd_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances( + // op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); + // add_device_grouped_conv3d_fwd_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + // op_ptrs); + } +#endif + } #endif // CK_USE_WMMA return op_ptrs; diff --git 
a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp index 748a5a78d9b..1f72f34c1d2 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp @@ -16,6 +16,10 @@ #include "grouped_convolution_forward_bias_clamp_xdl.inc" #endif +#ifdef CK_USE_WMMA +#include "grouped_convolution_forward_bias_clamp_wmma_cshufflev3.inc" +#endif + namespace ck { namespace tensor_operation { namespace device { @@ -269,6 +273,67 @@ struct DeviceOperationInstanceFactory && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances( + op_ptrs); + // add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( + // op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances( + op_ptrs); + // add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances( + // op_ptrs); + } +#endif + } + // layout NDHWGC/GKZYXC/NDHWGK + if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); + // add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + // op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances( + op_ptrs); + // add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances( + // op_ptrs); + } +#endif + } +#endif // CK_USE_WMMA + return op_ptrs; } }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp_wmma_cshufflev3.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp_wmma_cshufflev3.inc new file mode 100644 index 00000000000..b0c376e9ce6 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp_wmma_cshufflev3.inc @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
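The .inc header that begins here only declares the add_device_..._instances registration hooks; their definitions live in separately compiled instance .cpp files, and the DeviceOperationInstanceFactory specializations above call these hooks to fill op_ptrs. Below is a hedged sketch of how client code typically consumes the collected instances. The DeviceGroupedConvFwdMultipleABD template arguments mirror the 2D NHWGC/GKYXC/NHWGK F16 branch registered in this patch, but the Ds layout and data-type choices, the AddClamp constructor arguments, and the elided tensor setup are illustrative assumptions only.

// Sketch, not part of the patch: enumerate the registered instances and run
// the first one that supports the problem. Shapes, strides and device
// buffers are elided.
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
    2, NHWGC, GKYXC, ck::Tuple<NHWGK>, NHWGK,
    F16, F16, ck::Tuple<F16>, F16,
    PassThrough, PassThrough, AddClamp>;

auto op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

for(auto& op_ptr : op_ptrs)
{
    auto argument_ptr = op_ptr->MakeArgumentPointer(
        /* in, wei, {bias}, out, lengths, strides, dilations, pads, */
        PassThrough{}, PassThrough{}, AddClamp{/* assumed clamp bounds */});
    if(op_ptr->IsSupportedArgument(argument_ptr.get()))
    {
        op_ptr->MakeInvokerPointer()->Run(argument_ptr.get(), StreamConfig{});
        break; // the first supported instance gives a functional run
    }
}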
+ +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_BF16 + +void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddClamp>>>& instances); + +// void +// add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( +// std::vector, +// NHWGK, +// BF16, +// BF16, +// Tuple, +// BF16, +// PassThrough, +// PassThrough, +// AddClamp>>>& instances); + +void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddClamp>>>& instances); + +// void +// add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( +// std::vector, +// NDHWGK, +// BF16, +// BF16, +// Tuple, +// BF16, +// PassThrough, +// PassThrough, +// AddClamp>>>& instances); + +#endif + +#ifdef CK_ENABLE_FP16 + +void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector, + NHWGK, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + AddClamp>>>& instances); + +// void +// add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances( +// std::vector, +// NHWGK, +// F16, +// F16, +// Tuple, +// F16, +// PassThrough, +// PassThrough, +// AddClamp>>>& instances); + +void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector, + NDHWGK, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + AddClamp>>>& instances); + +// void +// add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances( +// std::vector, +// NDHWGK, +// F16, +// F16, +// Tuple, +// F16, +// PassThrough, +// PassThrough, +// AddClamp>>>& instances); + +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp index e15a32a9d78..de51c76833a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp @@ -21,6 +21,7 @@ namespace instance { using PassThrough = ck::tensor_operation::element_wise::PassThrough; using Bilinear = ck::tensor_operation::element_wise::Bilinear; +#ifdef CK_USE_XDL #ifdef CK_ENABLE_BF16 // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instances( @@ -103,6 +104,42 @@ void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_int8_instan PassThrough, Bilinear>>>& instances); #endif +#endif // CK_USE_XDL + +#ifdef CK_USE_WMMA +#ifdef CK_ENABLE_BF16 +// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Bilinear>>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances( + 
std::vector, + NDHWGK, + F16, + F16, + ck::Tuple, + F16, + PassThrough, + PassThrough, + Bilinear>>>& instances); +#endif +#endif // CK_USE_WMMA template > op_ptrs; + +#ifdef CK_USE_XDL if constexpr(NumDimSpatial == 3 && is_same_v && is_same_v && is_same_v && DLayouts::Size() == 1 && is_same_v, NDHWGK>) @@ -194,6 +233,31 @@ struct DeviceOperationInstanceFactory && + is_same_v && is_same_v && + DLayouts::Size() == 1 && is_same_v, NDHWGK>) + { +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_wmma_cshufflev3_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_wmma_cshufflev3_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances( + op_ptrs); + } +#endif + } +#endif return op_ptrs; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp index 090c99819f3..89036fed23e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp @@ -16,6 +16,10 @@ #include "grouped_convolution_forward_clamp_xdl.inc" #endif +#ifdef CK_USE_WMMA +#include "grouped_convolution_forward_clamp_wmma_cshufflev3.inc" +#endif + namespace ck { namespace tensor_operation { namespace device { @@ -266,6 +270,67 @@ struct DeviceOperationInstanceFactory && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances( + op_ptrs); + // add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( + // op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances( + op_ptrs); + // add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances( + // op_ptrs); + } +#endif + } + // layout NDHWGC/GKZYXC/NDHWGK + if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); + // add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + // op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances( + op_ptrs); + // add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances( + // op_ptrs); + } +#endif + } +#endif // CK_USE_WMMA + return op_ptrs; } }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_wmma_cshufflev3.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_wmma_cshufflev3.inc new file mode 100644 index 00000000000..8be584949f2 --- /dev/null +++ 
b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_wmma_cshufflev3.inc @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_BF16 + +void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances); + +// void +// add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( +// std::vector, +// NHWGK, +// BF16, +// BF16, +// Tuple<>, +// BF16, +// PassThrough, +// PassThrough, +// Clamp>>>& instances); + +void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances); + +// void +// add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( +// std::vector, +// NDHWGK, +// BF16, +// BF16, +// Tuple<>, +// BF16, +// PassThrough, +// PassThrough, +// Clamp>>>& instances); + +#endif + +#ifdef CK_ENABLE_FP16 + +void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector, + NHWGK, + F16, + F16, + Tuple<>, + F16, + PassThrough, + PassThrough, + Clamp>>>& instances); + +void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector, + NDHWGK, + F16, + F16, + Tuple<>, + F16, + PassThrough, + PassThrough, + Clamp>>>& instances); + +// void +// add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances( +// std::vector, +// NHWGK, +// F16, +// F16, +// Tuple<>, +// F16, +// PassThrough, +// PassThrough, +// Clamp>>>& instances); + +// void +// add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances( +// std::vector, +// NDHWGK, +// F16, +// F16, +// Tuple<>, +// F16, +// PassThrough, +// PassThrough, +// Clamp>>>& instances); + +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convinvscale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convinvscale.hpp index 86d1d61a1da..52b1288ce2e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convinvscale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convinvscale.hpp @@ -22,6 +22,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ConvInvscale = ck::tensor_operation::element_wise::ConvInvscale; #ifdef CK_ENABLE_FP8 +#ifdef CK_USE_XDL void add_device_grouped_conv3d_fwd_xdl_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instances( std::vector>>& instances); #endif +#ifdef CK_USE_WMMA_FP8 +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instances( + std::vector, + NDHWGK, + F8, + F8, + ck::Tuple<>, + F8, + PassThrough, + PassThrough, + ConvInvscale, + F8, + F8>>>& instances); +#endif +#endif + template && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_grouped_conv3d_fwd_xdl_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instances( op_ptrs); +#endif +#ifdef 
CK_USE_WMMA_FP8 + add_device_grouped_conv3d_fwd_wmma_cshufflev3_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instances( + op_ptrs); +#endif } #endif } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp index c84e2516722..720daf91329 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp @@ -20,6 +20,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ConvScale = ck::tensor_operation::element_wise::ConvScale; #ifdef CK_ENABLE_FP8 +#ifdef CK_USE_XDL void add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instances( std::vector>>& instances); #endif +#ifdef CK_USE_WMMA_FP8 +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instances( + std::vector, + NDHWGK, + F8, + F8, + ck::Tuple<>, + F8, + PassThrough, + PassThrough, + ConvScale, + F8, + F8>>>& instances); +#endif +#endif + #if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) +#ifdef CK_USE_XDL void add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instances( std::vector>>& instances); #endif +#ifdef CK_USE_WMMA_FP8 +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instances( + std::vector, + NDHWGK, + BF8, + BF8, + ck::Tuple<>, + F8, + PassThrough, + PassThrough, + ConvScale, + BF8, + BF8>>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances( + std::vector, + NDHWGK, + F8, + BF8, + ck::Tuple<>, + F8, + PassThrough, + PassThrough, + ConvScale, + F8, + BF8>>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instances( + std::vector, + NDHWGK, + BF8, + F8, + ck::Tuple<>, + F8, + PassThrough, + PassThrough, + ConvScale, + BF8, + F8>>>& instances); +#endif +#endif + template && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instances( op_ptrs); +#endif +#ifdef CK_USE_WMMA_FP8 + add_device_grouped_conv3d_fwd_wmma_cshufflev3_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instances( + op_ptrs); +#endif } #endif @@ -150,24 +228,42 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instances( op_ptrs); +#endif +#ifdef CK_USE_WMMA_FP8 + add_device_grouped_conv3d_fwd_wmma_cshufflev3_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instances( + op_ptrs); +#endif } if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances( op_ptrs); +#endif +#ifdef CK_USE_WMMA_FP8 + add_device_grouped_conv3d_fwd_wmma_cshufflev3_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances( + op_ptrs); +#endif } if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instances( op_ptrs); +#endif +#ifdef CK_USE_WMMA_FP8 + add_device_grouped_conv3d_fwd_wmma_cshufflev3_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instances( + op_ptrs); +#endif } #endif } @@ -178,6 +274,7 @@ struct DeviceOperationInstanceFactory>>& instances); #endif +#ifdef CK_USE_WMMA_FP8 
+void add_device_grouped_conv3d_fwd_wmma_cshufflev3_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances( + std::vector, + NDHWGK, + F8, + F8, + ck::Tuple<>, + F32, + PassThrough, + PassThrough, + CombConvScale, + F8, + F8>>>& instances); +#endif +#endif + template && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_grouped_conv3d_fwd_xdl_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances( op_ptrs); +#endif +#ifdef CK_USE_WMMA_FP8 + add_device_grouped_conv3d_fwd_wmma_cshufflev3_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances( + op_ptrs); +#endif } #endif } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_add.hpp index d20b5e3d25d..207ffb45ded 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_add.hpp @@ -20,6 +20,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ConvScaleAdd = ck::tensor_operation::element_wise::ConvScaleAdd; #ifdef CK_ENABLE_FP8 +#ifdef CK_USE_XDL void add_device_grouped_conv3d_fwd_xdl_convscale_add_ndhwgc_gkzyxc_ndhwgk_f8_instances( std::vector>>& instances); #endif +#ifdef CK_USE_WMMA_FP8 +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_convscale_add_ndhwgc_gkzyxc_ndhwgk_f8_instances( + std::vector, + NDHWGK, + F8, + F8, + ck::Tuple, + F8, + PassThrough, + PassThrough, + ConvScaleAdd, + F8, + F8>>>& instances); +#endif +#endif + template && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_grouped_conv3d_fwd_xdl_convscale_add_ndhwgc_gkzyxc_ndhwgk_f8_instances( op_ptrs); +#endif +#ifdef CK_USE_WMMA_FP8 + add_device_grouped_conv3d_fwd_wmma_cshufflev3_convscale_add_ndhwgc_gkzyxc_ndhwgk_f8_instances( + op_ptrs); +#endif } #endif } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp index 3320a805e18..a3af3bdc98a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp @@ -20,6 +20,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ConvScaleRelu = ck::tensor_operation::element_wise::ConvScaleRelu; #ifdef CK_ENABLE_FP8 +#ifdef CK_USE_XDL void add_device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instances( std::vector>>& instances); #endif +#ifdef CK_USE_WMMA_FP8 +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instances( + std::vector, + NDHWGK, + F8, + F8, + ck::Tuple<>, + F8, + PassThrough, + PassThrough, + ConvScaleRelu, + F8, + F8>>>& instances); +#endif +#endif + template && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instances( op_ptrs); +#endif +#ifdef CK_USE_WMMA_FP8 + add_device_grouped_conv3d_fwd_wmma_cshufflev3_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instances( + op_ptrs); +#endif } #endif } @@ -102,6 +128,7 @@ struct DeviceOperationInstanceFactory< using CombConvScaleRelu = ck::tensor_operation::element_wise::ScaleScaleRelu; #ifdef CK_ENABLE_FP8 +#ifdef CK_USE_XDL void 
add_device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances( std::vector>>& instances); #endif +#ifdef CK_USE_WMMA_FP8 +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances( + std::vector, + NDHWGK, + F8, + F8, + ck::Tuple<>, + F32, + PassThrough, + PassThrough, + CombConvScaleRelu, + F8, + F8>>>& instances); +#endif +#endif + template && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances( op_ptrs); +#endif +#ifdef CK_USE_WMMA_FP8 + add_device_grouped_conv3d_fwd_wmma_cshufflev3_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances( + op_ptrs); +#endif } #endif } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_dynamic_op.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_dynamic_op.hpp index abe35e6a24c..164150781cf 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_dynamic_op.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_dynamic_op.hpp @@ -21,6 +21,7 @@ namespace instance { using PassThrough = ck::tensor_operation::element_wise::PassThrough; using DynamicUnaryOp = ck::tensor_operation::element_wise::DynamicUnaryOp; +#ifdef CK_USE_XDL #ifdef CK_ENABLE_BF16 // grouped conv2d forward, NHWGC/GKYXC/NHWGK void add_device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_bf16_instances( @@ -150,6 +151,80 @@ void add_device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_int8_inst PassThrough, DynamicUnaryOp>>>& instances); #endif +#endif // CK_USE_XDL + +#ifdef CK_USE_WMMA +#ifdef CK_ENABLE_BF16 +// grouped conv2d forward, NHWGC/GKYXC/NHWGK +void add_device_grouped_conv2d_fwd_wmma_cshufflev3_dynamic_op_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector, + NHWGK, + BF16, + BF16, + ck::Tuple<>, + BF16, + PassThrough, + PassThrough, + DynamicUnaryOp>>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_wmma_cshufflev3_dynamic_op_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector, + NHWGK, + F16, + F16, + ck::Tuple<>, + F16, + PassThrough, + PassThrough, + DynamicUnaryOp, + F16, + F16>>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_dynamic_op_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + ck::Tuple<>, + BF16, + PassThrough, + PassThrough, + DynamicUnaryOp, + BF16, + BF16>>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector, + NDHWGK, + F16, + F16, + ck::Tuple<>, + F16, + PassThrough, + PassThrough, + DynamicUnaryOp, + F16, + F16>>>& instances); +#endif +#endif // CK_USE_WMMA template > op_ptrs; + +#ifdef CK_USE_XDL + // layout NDHWGC/GKZYXC/NDHWGK if constexpr(NumDimSpatial == 3 && is_same_v && is_same_v && is_same_v && DLayouts::Size() == 0) @@ -271,6 +349,53 @@ struct DeviceOperationInstanceFactory && + is_same_v && is_same_v && + DLayouts::Size() == 0) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_wmma_cshufflev3_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && is_same_v) + { + 
add_device_grouped_conv3d_fwd_wmma_cshufflev3_dynamic_op_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); + } +#endif + } + else if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v && + DLayouts::Size() == 0) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_wmma_cshufflev3_dynamic_op_nhwgc_gkyxc_nhwgk_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_wmma_cshufflev3_dynamic_op_nhwgc_gkyxc_nhwgk_bf16_instances( + op_ptrs); + } +#endif + } +#endif // CK_USE_WMMA return op_ptrs; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_wmma_cshufflev3.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_wmma_cshufflev3.inc new file mode 100644 index 00000000000..319015b0eeb --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_wmma_cshufflev3.inc @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_wmma_cshufflev3.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_wmma_cshufflev3.inc new file mode 100644 index 00000000000..cf4105d9e10 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_wmma_cshufflev3.inc @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp index 6c5cedc1a62..879816b7d24 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp @@ -21,6 +21,7 @@ namespace instance { using PassThrough = ck::tensor_operation::element_wise::PassThrough; using Scale = ck::tensor_operation::element_wise::Scale; +#ifdef CK_USE_XDL #ifdef CK_ENABLE_BF16 // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instances( @@ -103,6 +104,42 @@ void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_int8_instances PassThrough, Scale>>>& instances); #endif +#endif + +#ifdef CK_USE_WMMA +#ifdef CK_ENABLE_BF16 +// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + ck::Tuple<>, + BF16, + PassThrough, + PassThrough, + Scale>>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector, + NDHWGK, + F16, + F16, + ck::Tuple<>, + F16, + PassThrough, + PassThrough, + Scale>>>& instances); +#endif +#endif template && is_same_v && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); +#endif +#ifdef CK_USE_WMMA + add_device_grouped_conv3d_fwd_wmma_cshufflev3_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances( + op_ptrs); +#endif } #endif #ifdef CK_ENABLE_BF16 if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instances( op_ptrs); +#endif +#ifdef CK_USE_WMMA + add_device_grouped_conv3d_fwd_wmma_cshufflev3_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); +#endif } #endif #ifdef CK_ENABLE_INT8 if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_int8_instances( op_ptrs); +#endif } #endif } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp index ae182e14339..c651aab2c9d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp @@ -21,6 +21,7 @@ namespace instance { using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; +#ifdef CK_USE_XDL #ifdef CK_ENABLE_BF16 // grouped conv3d forward multi AB scaleadd, NDHWGC/GKZYXC/NDHWGK void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances( @@ -85,6 +86,42 @@ void 
add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_int8_ins ScaleAdd, PassThrough>>>& instances); #endif +#endif // CK_USE_XDL + +#ifdef CK_USE_WMMA +#ifdef CK_ENABLE_BF16 +// grouped conv3d forward multi AB scaleadd, NDHWGC/GKZYXC/NDHWGK +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + ScaleAdd, + ScaleAdd, + PassThrough>>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector, + NDHWGK, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + F16, + ScaleAdd, + ScaleAdd, + PassThrough>>>& instances); +#endif +#endif // CK_USE_WMMA template > op_ptrs; +#ifdef CK_USE_XDL if constexpr(NumDimSpatial == 3 && is_same_v && is_same_v && is_same_v) { @@ -169,6 +207,32 @@ struct DeviceOperationInstanceFactory && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v> && + is_same_v> && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v> && + is_same_v> && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); + } +#endif + } +#endif // CK_USE_WMMA return op_ptrs; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp index ee4e7ebc239..8b60be1d0e3 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp @@ -21,6 +21,7 @@ namespace instance { using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ScaleAddScaleAddRelu = ck::tensor_operation::element_wise::ScaleAddScaleAddRelu; +#ifdef CK_USE_XDL #ifdef CK_ENABLE_BF16 // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instances( @@ -85,6 +86,42 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw PassThrough, ScaleAddScaleAddRelu>>>& instances); #endif +#endif + +#ifdef CK_USE_WMMA +#ifdef CK_ENABLE_BF16 +// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + ScaleAddScaleAddRelu>>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector, + NDHWGK, + F16, + F16, + ck::Tuple, + F16, + PassThrough, + PassThrough, + ScaleAddScaleAddRelu>>>& instances); +#endif +#endif template && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instances( op_ptrs); +#endif } #endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instances( op_ptrs); +#endif +#ifdef 
CK_USE_WMMA + add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instances( + op_ptrs); +#endif } #endif #ifdef CK_ENABLE_BF16 if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instances( op_ptrs); +#endif +#ifdef CK_USE_WMMA + add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); +#endif } #endif #ifdef CK_ENABLE_INT8 if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instances( op_ptrs); +#endif } #endif } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_wmma_cshufflev3.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_wmma_cshufflev3.inc new file mode 100644 index 00000000000..7064de30fa3 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_wmma_cshufflev3.inc @@ -0,0 +1,80 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// grouped conv2d forward, NHWGC/GKYXC/NHWGK +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); +#endif + +// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index 575e14d5bbe..cca07bb4532 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -98,6 +98,11 @@ function(add_instance_library INSTANCE_NAME) message(DEBUG "removing gemm_wmma_universal_f8 instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() + # Do not build WMMA grouped conv 3d fwd fp8 / bf8 for any targets except gfx12+ + if(NOT INST_TARGETS MATCHES "gfx12" AND source_name MATCHES "grouped_conv3d_fwd_wmma" AND (source_name MATCHES "_fp8_" OR source_name MATCHES "_bf8_")) + message(DEBUG "removing grouped_conv3d_fwd_wmma fp8/bf8 instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() # Do not build gemm_universal_preshuffle_f8 for any targets except gfx94, gfx95 and gfx12 if(NOT (INST_TARGETS MATCHES "gfx94|gfx95|gfx12") AND (source_name MATCHES "gemm_universal_preshuffle") AND source_name MATCHES "_f8_" ) message(DEBUG "removing gemm_universal_preshuffle_f8 instance ${source} ") diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index 704079f181c..55ce9f72c96 100644 --- 
a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -89,10 +89,11 @@ set(GROUPED_CONV2D_FWD # GNHWC, GKYXC, GNHWK dl/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp dl/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp # NHWGC, GKYXC, NHWGK dl/device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp dl/device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp - # WMMA + + # WMMA_OLD # GNHWC, GKYXC, GNHWK wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp @@ -111,6 +112,11 @@ set(GROUPED_CONV2D_FWD wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instance.cpp wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp + + # WMMA CSHUFFLE V3 + # NHWGC, GKYXC, NHWGK + wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp ) # Add generated files for sharded instantiations. include(ShardInstantiation) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 00000000000..7100bedc708 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,65 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + BF16, + BF16, + Empty_Tuple, + BF16, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 00000000000..65a7af51aae --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,65 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + F16, + F16, + Empty_Tuple, + F16, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt index 59312849c36..b300c5929d9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt @@ -1,7 +1,7 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_grouped_conv2d_fwd_bias_clamp_instance xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp @@ -39,4 +39,8 @@ add_instance_library(device_grouped_conv2d_fwd_bias_clamp_instance xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_inter_instance.cpp xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_comp_instance.cpp xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp + + # WMMA CSHUFFLE V3 + wmma/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + wmma/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_fp16_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/wmma/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/wmma/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 00000000000..e5a09f8b6be --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/wmma/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,62 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddClamp>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + AddClamp>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1P0, + Tuple, + AddClamp>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1S1P0, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/wmma/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_fp16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/wmma/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_fp16_instance.cpp new file mode 100644 index 00000000000..8eac80e4086 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/wmma/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_fp16_instance.cpp @@ -0,0 +1,62 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector, + NHWGK, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + AddClamp>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + AddClamp>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1P0, + Tuple, + AddClamp>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1S1P0, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt index 69823c3246a..ad803ea0cb1 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt @@ -1,7 +1,7 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
# SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_grouped_conv2d_fwd_clamp_instance xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp @@ -39,4 +39,8 @@ add_instance_library(device_grouped_conv2d_fwd_clamp_instance xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_inter_instance.cpp xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_comp_instance.cpp xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp + + # WMMA CSHUFFLE V3 + wmma/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + wmma/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_fp16_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/wmma/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/wmma/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 00000000000..8dec66f3358 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/wmma/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,62 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1P0, + Tuple<>, + Clamp>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1S1P0, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/wmma/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_fp16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/wmma/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_fp16_instance.cpp new file mode 100644 index 00000000000..224fa0f5c7a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/wmma/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_fp16_instance.cpp @@ -0,0 +1,62 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector, + NHWGK, + F16, + F16, + Tuple<>, + F16, + PassThrough, + PassThrough, + Clamp>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1P0, + Tuple<>, + Clamp>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1S1P0, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_dynamic_op/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_dynamic_op/CMakeLists.txt index dd5c69a7c2a..e84f9d906f4 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_dynamic_op/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_dynamic_op/CMakeLists.txt @@ -1,11 +1,13 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS set(GROUPED_CONV2D_FWD_DYNAMIC_OP xdl/device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_bf16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_f16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_f32_instance.cpp - xdl/device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_int8_instance.cpp) + xdl/device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_int8_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_dynamic_op_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_dynamic_op_nhwgc_gkyxc_nhwgk_f16_instance.cpp) add_instance_library(device_grouped_conv2d_fwd_dynamic_op_instance ${GROUPED_CONV2D_FWD_DYNAMIC_OP}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_dynamic_op/wmma/device_grouped_conv2d_fwd_wmma_dynamic_op_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_dynamic_op/wmma/device_grouped_conv2d_fwd_wmma_dynamic_op_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 00000000000..d8e05942bd4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_dynamic_op/wmma/device_grouped_conv2d_fwd_wmma_dynamic_op_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,55 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_fwd_wmma_cshufflev3_dynamic_op_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector, + NHWGK, + BF16, + BF16, + ck::Tuple<>, + BF16, + PassThrough, + PassThrough, + DynamicUnaryOp>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_bf16_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_bf16_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_bf16_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_dynamic_op/wmma/device_grouped_conv2d_fwd_wmma_dynamic_op_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_dynamic_op/wmma/device_grouped_conv2d_fwd_wmma_dynamic_op_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 00000000000..93126dde032 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_dynamic_op/wmma/device_grouped_conv2d_fwd_wmma_dynamic_op_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,55 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_fwd_wmma_cshufflev3_dynamic_op_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector, + NHWGK, + F16, + F16, + ck::Tuple<>, + F16, + PassThrough, + PassThrough, + DynamicUnaryOp>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_f16_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_f16_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_f16_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index 7b5138ad9ed..a25f66665d3 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -52,22 +52,27 @@ set(GROUPED_CONV3D_FWD xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_part2_instance.cpp xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_part2_instance.cpp - wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp - wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp - wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp - wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp - wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instance.cpp - wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instance.cpp - wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instance.cpp - wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instance.cpp - wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp - wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp - wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp - wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp - wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp - wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp - wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp - wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp + # WMMA_OLD + wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp + wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp + wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp + wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instance.cpp + wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instance.cpp + 
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instance.cpp + wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instance.cpp + wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp + wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp + wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp + wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp + wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp + wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp + wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp + wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp + + # WMMA CSHUFFLE V3 + wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp ) # Add generated files for sharded instantiations. include(ShardInstantiation) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 00000000000..2660da6ad92 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,55 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp new file mode 100644 index 00000000000..1c08b80f750 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -0,0 +1,55 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+void add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
+                                                                NDHWGC,
+                                                                GKZYXC,
+                                                                Empty_Tuple,
+                                                                NDHWGK,
+                                                                F16,
+                                                                F16,
+                                                                Empty_Tuple,
+                                                                F16,
+                                                                PassThrough,
+                                                                PassThrough,
+                                                                PassThrough>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<3,
+                                                              NDHWGC,
+                                                              GKZYXC,
+                                                              Empty_Tuple,
+                                                              NDHWGK,
+                                                              ConvFwdDefault>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<3,
+                                                              NDHWGC,
+                                                              GKZYXC,
+                                                              Empty_Tuple,
+                                                              NDHWGK,
+                                                              ConvFwd1x1P0>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<3,
+                                                              NDHWGC,
+                                                              GKZYXC,
+                                                              Empty_Tuple,
+                                                              NDHWGK,
+                                                              ConvFwd1x1S1P0>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
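Each instance file in this patch registers the same three forward-convolution specializations: `ConvFwdDefault`, `ConvFwd1x1P0` (1x1 filter, zero padding) and `ConvFwd1x1S1P0` (1x1 filter, stride 1, zero padding). A hypothetical helper, not part of CK, showing the conditions those names encode:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    enum class ConvSpec { Default, Filter1x1Pad0, Filter1x1Stride1Pad0 };

    // Hypothetical helper (not part of CK): return the most specialized
    // variant a problem admits, mirroring the three instance sets above.
    inline ConvSpec pick_conv_spec(const std::vector<std::int64_t>& filter_lengths,
                                   const std::vector<std::int64_t>& strides,
                                   const std::vector<std::int64_t>& left_pads,
                                   const std::vector<std::int64_t>& right_pads)
    {
        bool filter_1x1 = true, stride_1 = true, pad_0 = true;
        for(std::size_t d = 0; d < filter_lengths.size(); ++d)
        {
            filter_1x1 = filter_1x1 && filter_lengths[d] == 1;
            stride_1   = stride_1 && strides[d] == 1;
            pad_0      = pad_0 && left_pads[d] == 0 && right_pads[d] == 0;
        }
        if(filter_1x1 && pad_0 && stride_1)
            return ConvSpec::Filter1x1Stride1Pad0;
        if(filter_1x1 && pad_0)
            return ConvSpec::Filter1x1Pad0;
        return ConvSpec::Default;
    }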
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt
index 11dac1620a7..a84f3b753b9 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt
@@ -1,7 +1,7 @@
 # Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
 # SPDX-License-Identifier: MIT
 
-# ONLY XDL_KERNELS
+# ONLY XDL_AND_WMMA_KERNELS
 set(GROUPED_CONV3D_FWD
     xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
     xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp
@@ -34,6 +34,10 @@ set(GROUPED_CONV3D_FWD
     xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp
     xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp
-    )
+
+    # WMMA CSHUFFLE V3
+    wmma/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
+    wmma/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_fp16_instance.cpp
+)
 
 add_instance_library(device_grouped_conv3d_fwd_bias_clamp_instance ${GROUPED_CONV3D_FWD})
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/wmma/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/wmma/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
new file mode 100644
index 00000000000..47f3802ad18
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/wmma/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
@@ -0,0 +1,61 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
+                                                                NDHWGC,
+                                                                GKZYXC,
+                                                                Tuple<NDHWGK>,
+                                                                NDHWGK,
+                                                                BF16,
+                                                                BF16,
+                                                                Tuple<BF16>,
+                                                                BF16,
+                                                                PassThrough,
+                                                                PassThrough,
+                                                                AddClamp>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<3,
+                                                               NDHWGC,
+                                                               GKZYXC,
+                                                               Tuple<NDHWGK>,
+                                                               NDHWGK,
+                                                               ConvFwdDefault,
+                                                               Tuple<BF16>,
+                                                               AddClamp>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<3,
+                                                               NDHWGC,
+                                                               GKZYXC,
+                                                               Tuple<NDHWGK>,
+                                                               NDHWGK,
+                                                               ConvFwd1x1P0,
+                                                               Tuple<BF16>,
+                                                               AddClamp>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<3,
+                                                               NDHWGC,
+                                                               GKZYXC,
+                                                               Tuple<NDHWGK>,
+                                                               NDHWGK,
+                                                               ConvFwd1x1S1P0,
+                                                               Tuple<BF16>,
+                                                               AddClamp>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/wmma/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_fp16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/wmma/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_fp16_instance.cpp
new file mode 100644
index 00000000000..02643803814
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/wmma/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_fp16_instance.cpp
@@ -0,0 +1,61 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
+                                                                NDHWGC,
+                                                                GKZYXC,
+                                                                Tuple<NDHWGK>,
+                                                                NDHWGK,
+                                                                F16,
+                                                                F16,
+                                                                Tuple<F16>,
+                                                                F16,
+                                                                PassThrough,
+                                                                PassThrough,
+                                                                AddClamp>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<3,
+                                                              NDHWGC,
+                                                              GKZYXC,
+                                                              Tuple<NDHWGK>,
+                                                              NDHWGK,
+                                                              ConvFwdDefault,
+                                                              Tuple<F16>,
+                                                              AddClamp>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<3,
+                                                              NDHWGC,
+                                                              GKZYXC,
+                                                              Tuple<NDHWGK>,
+                                                              NDHWGK,
+                                                              ConvFwd1x1P0,
+                                                              Tuple<F16>,
+                                                              AddClamp>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<3,
+                                                              NDHWGC,
+                                                              GKZYXC,
+                                                              Tuple<NDHWGK>,
+                                                              NDHWGK,
+                                                              ConvFwd1x1S1P0,
+                                                              Tuple<F16>,
+                                                              AddClamp>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
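The bias_clamp instances differ from the plain ones only in the extra D tensor (the bias, with the output layout) and in `AddClamp` as the CDE epilogue. Assuming `AddClamp` composes an add with a clamp, the epilogue behaves roughly as below; the bound values shown are placeholders, and CK's functor may carry its own configurable floor and ceiling.

    #include <algorithm>

    // Hedged sketch of the AddClamp epilogue: e = clamp(c + bias, lo, hi).
    struct AddClampSketch
    {
        float floor_ = 0.f; // assumed lower bound
        float ceil_  = 6.f; // assumed upper bound

        void operator()(float& e, const float& c, const float& bias) const
        {
            e = std::min(std::max(c + bias, floor_), ceil_);
        }
    };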
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt
index bd143bc0b94..d54fc6d2022 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt
@@ -1,12 +1,14 @@
 # Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
 # SPDX-License-Identifier: MIT
 
-# ONLY XDL_KERNELS
+# ONLY XDL_AND_WMMA_KERNELS
 set(GROUPED_CONV3D_FWD_BILINEAR
     xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
     xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
     xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
     xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp
-    xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp)
+    xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp
+    wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
+    wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp)
 
 add_instance_library(device_grouped_conv3d_fwd_bilinear_instance ${GROUPED_CONV3D_FWD_BILINEAR})
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
new file mode 100644
index 00000000000..2abe9b485c0
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
@@ -0,0 +1,55 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_bilinear_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+void add_device_grouped_conv3d_fwd_wmma_cshufflev3_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
+                                                                NDHWGC,
+                                                                GKZYXC,
+                                                                ck::Tuple<NDHWGK>,
+                                                                NDHWGK,
+                                                                BF16,
+                                                                BF16,
+                                                                ck::Tuple<BF16>,
+                                                                BF16,
+                                                                PassThrough,
+                                                                PassThrough,
+                                                                Bilinear>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_bilinear_bf16_instances<3,
+                                                                        NDHWGC,
+                                                                        GKZYXC,
+                                                                        Tuple<NDHWGK>,
+                                                                        NDHWGK,
+                                                                        ConvFwdDefault>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_bilinear_bf16_instances<3,
+                                                                        NDHWGC,
+                                                                        GKZYXC,
+                                                                        Tuple<NDHWGK>,
+                                                                        NDHWGK,
+                                                                        ConvFwd1x1P0>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_bilinear_bf16_instances<3,
+                                                                        NDHWGC,
+                                                                        GKZYXC,
+                                                                        Tuple<NDHWGK>,
+                                                                        NDHWGK,
+                                                                        ConvFwd1x1S1P0>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
new file mode 100644
index 00000000000..d8247cac3e0
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
@@ -0,0 +1,55 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_bilinear_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+void add_device_grouped_conv3d_fwd_wmma_cshufflev3_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
+                                                                NDHWGC,
+                                                                GKZYXC,
+                                                                ck::Tuple<NDHWGK>,
+                                                                NDHWGK,
+                                                                F16,
+                                                                F16,
+                                                                ck::Tuple<F16>,
+                                                                F16,
+                                                                PassThrough,
+                                                                PassThrough,
+                                                                Bilinear>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_bilinear_f16_instances<3,
+                                                                       NDHWGC,
+                                                                       GKZYXC,
+                                                                       Tuple<NDHWGK>,
+                                                                       NDHWGK,
+                                                                       ConvFwdDefault>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_bilinear_f16_instances<3,
+                                                                       NDHWGC,
+                                                                       GKZYXC,
+                                                                       Tuple<NDHWGK>,
+                                                                       NDHWGK,
+                                                                       ConvFwd1x1P0>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_bilinear_f16_instances<3,
+                                                                       NDHWGC,
+                                                                       GKZYXC,
+                                                                       Tuple<NDHWGK>,
+                                                                       NDHWGK,
+                                                                       ConvFwd1x1S1P0>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt
index 84b59b4849a..49dfac34343 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt
@@ -1,7 +1,7 @@
 # Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
 # SPDX-License-Identifier: MIT
 
-# ONLY XDL_KERNELS
+# ONLY XDL_AND_WMMA_KERNELS
 set(GROUPED_CONV3D_FWD
     xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
     xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp
@@ -34,6 +34,10 @@ set(GROUPED_CONV3D_FWD
     xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp
     xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp
-    )
+
+    # WMMA CSHUFFLE V3
+    wmma/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
+    wmma/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_fp16_instance.cpp
+)
 
 add_instance_library(device_grouped_conv3d_fwd_clamp_instance ${GROUPED_CONV3D_FWD})
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/wmma/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/wmma/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
new file mode 100644
index 00000000000..c66daec97ce
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/wmma/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
@@ -0,0 +1,61 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
+                                                                NDHWGC,
+                                                                GKZYXC,
+                                                                Tuple<>,
+                                                                NDHWGK,
+                                                                BF16,
+                                                                BF16,
+                                                                Tuple<>,
+                                                                BF16,
+                                                                PassThrough,
+                                                                PassThrough,
+                                                                Clamp>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<3,
+                                                               NDHWGC,
+                                                               GKZYXC,
+                                                               Tuple<>,
+                                                               NDHWGK,
+                                                               ConvFwdDefault,
+                                                               Tuple<>,
+                                                               Clamp>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<3,
+                                                               NDHWGC,
+                                                               GKZYXC,
+                                                               Tuple<>,
+                                                               NDHWGK,
+                                                               ConvFwd1x1P0,
+                                                               Tuple<>,
+                                                               Clamp>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<3,
+                                                               NDHWGC,
+                                                               GKZYXC,
+                                                               Tuple<>,
+                                                               NDHWGK,
+                                                               ConvFwd1x1S1P0,
+                                                               Tuple<>,
+                                                               Clamp>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/wmma/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_fp16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/wmma/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_fp16_instance.cpp
new file mode 100644
index 00000000000..dfd8fc7d156
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/wmma/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_fp16_instance.cpp
@@ -0,0 +1,61 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
+                                                                NDHWGC,
+                                                                GKZYXC,
+                                                                Tuple<>,
+                                                                NDHWGK,
+                                                                F16,
+                                                                F16,
+                                                                Tuple<>,
+                                                                F16,
+                                                                PassThrough,
+                                                                PassThrough,
+                                                                Clamp>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<3,
+                                                              NDHWGC,
+                                                              GKZYXC,
+                                                              Tuple<>,
+                                                              NDHWGK,
+                                                              ConvFwdDefault,
+                                                              Tuple<>,
+                                                              Clamp>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<3,
+                                                              NDHWGC,
+                                                              GKZYXC,
+                                                              Tuple<>,
+                                                              NDHWGK,
+                                                              ConvFwd1x1P0,
+                                                              Tuple<>,
+                                                              Clamp>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<3,
+                                                              NDHWGC,
+                                                              GKZYXC,
+                                                              Tuple<>,
+                                                              NDHWGK,
+                                                              ConvFwd1x1S1P0,
+                                                              Tuple<>,
+                                                              Clamp>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convinvscale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convinvscale/CMakeLists.txt
index 6b284512bed..234a4894fc2 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convinvscale/CMakeLists.txt
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convinvscale/CMakeLists.txt
@@ -1,8 +1,9 @@
 # Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
 # SPDX-License-Identifier: MIT
 
-# ONLY XDL_KERNELS
+# ONLY XDL_AND_WMMA_KERNELS
 set(GROUPED_CONV3D_FWD_CONVINVSCALE
-    xdl/device_grouped_conv3d_fwd_xdl_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp)
+    xdl/device_grouped_conv3d_fwd_xdl_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
+    wmma/device_grouped_conv3d_fwd_wmma_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp)
 
 add_instance_library(device_grouped_conv3d_fwd_convinvscale_instance ${GROUPED_CONV3D_FWD_CONVINVSCALE})
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convinvscale/wmma/device_grouped_conv3d_fwd_wmma_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convinvscale/wmma/device_grouped_conv3d_fwd_wmma_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
new file mode 100644
index 00000000000..90d0d47085c
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convinvscale/wmma/device_grouped_conv3d_fwd_wmma_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
@@ -0,0 +1,63 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_outelementop_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F8          = ck::f8_t;
+using ConvInvscale = ck::tensor_operation::element_wise::ConvInvscale;
+
+void add_device_grouped_conv3d_fwd_wmma_cshufflev3_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
+                                                                NDHWGC,
+                                                                GKZYXC,
+                                                                ck::Tuple<>,
+                                                                NDHWGK,
+                                                                F8,
+                                                                F8,
+                                                                ck::Tuple<>,
+                                                                F8,
+                                                                PassThrough,
+                                                                PassThrough,
+                                                                ConvInvscale,
+                                                                F8,
+                                                                F8>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_instances<3,
+                                                                          NDHWGC,
+                                                                          GKZYXC,
+                                                                          ck::Tuple<>,
+                                                                          NDHWGK,
+                                                                          ConvFwdDefault,
+                                                                          ConvInvscale>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_instances<3,
+                                                                          NDHWGC,
+                                                                          GKZYXC,
+                                                                          ck::Tuple<>,
+                                                                          NDHWGK,
+                                                                          ConvFwd1x1P0,
+                                                                          ConvInvscale>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_instances<3,
+                                                                          NDHWGC,
+                                                                          GKZYXC,
+                                                                          ck::Tuple<>,
+                                                                          NDHWGK,
+                                                                          ConvFwd1x1S1P0,
+                                                                          ConvInvscale>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt
index 90ddaacbcaf..06d9c89a9ec 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt
@@ -1,12 +1,17 @@
 # Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
 # SPDX-License-Identifier: MIT
 
-# ONLY XDL_KERNELS
+# ONLY XDL_AND_WMMA_KERNELS
 set(GROUPED_CONV3D_FWD_CONVSCALE
     xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
     xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp
     xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instance.cpp
     xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instance.cpp
-    xdl/device_grouped_conv3d_fwd_xdl_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp)
+    xdl/device_grouped_conv3d_fwd_xdl_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp
+    wmma/device_grouped_conv3d_fwd_wmma_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
+    wmma/device_grouped_conv3d_fwd_wmma_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp
+    wmma/device_grouped_conv3d_fwd_wmma_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instance.cpp
+    wmma/device_grouped_conv3d_fwd_wmma_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instance.cpp
+    wmma/device_grouped_conv3d_fwd_wmma_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp)
 
 add_instance_library(device_grouped_conv3d_fwd_convscale_instance ${GROUPED_CONV3D_FWD_CONVSCALE})
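The convscale and convinvscale families target FP8: the elementwise epilogue folds the per-tensor quantization scales into the accumulator before the down-conversion to `f8_t`/`bf8_t`. Assuming `ConvScale` multiplies by the combined scales and `ConvInvscale` divides by them (as the names suggest), the two epilogues behave like the sketch below; `float` stands in for the final FP8 cast.

    // Hedged sketch of the two FP8 scaling epilogues; not CK's functors.
    struct ConvScaleSketch
    {
        float scale_in, scale_wei, scale_out;
        void operator()(float& e, const float& c) const
        {
            e = c * scale_in * scale_wei * scale_out; // downcast to f8 follows
        }
    };

    struct ConvInvscaleSketch
    {
        float scale_in, scale_wei, scale_out;
        void operator()(float& e, const float& c) const
        {
            e = c / (scale_in * scale_wei * scale_out);
        }
    };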
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/wmma/device_grouped_conv3d_fwd_wmma_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/wmma/device_grouped_conv3d_fwd_wmma_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp
new file mode 100644
index 00000000000..877c9998ac4
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/wmma/device_grouped_conv3d_fwd_wmma_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp
@@ -0,0 +1,61 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_outelementop_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+void add_device_grouped_conv3d_fwd_wmma_cshufflev3_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
+                                                                NDHWGC,
+                                                                GKZYXC,
+                                                                ck::Tuple<>,
+                                                                NDHWGK,
+                                                                F8,
+                                                                F8,
+                                                                ck::Tuple<>,
+                                                                F32,
+                                                                PassThrough,
+                                                                PassThrough,
+                                                                CombConvScale,
+                                                                F8,
+                                                                F8>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_f8_f32_instances<3,
+                                                                                 NDHWGC,
+                                                                                 GKZYXC,
+                                                                                 ck::Tuple<>,
+                                                                                 NDHWGK,
+                                                                                 ConvFwdDefault,
+                                                                                 CombConvScale>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_f8_f32_instances<3,
+                                                                                 NDHWGC,
+                                                                                 GKZYXC,
+                                                                                 ck::Tuple<>,
+                                                                                 NDHWGK,
+                                                                                 ConvFwd1x1P0,
+                                                                                 CombConvScale>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_f8_f32_instances<3,
+                                                                                 NDHWGC,
+                                                                                 GKZYXC,
+                                                                                 ck::Tuple<>,
+                                                                                 NDHWGK,
+                                                                                 ConvFwd1x1S1P0,
+                                                                                 CombConvScale>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/wmma/device_grouped_conv3d_fwd_wmma_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/wmma/device_grouped_conv3d_fwd_wmma_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instance.cpp
new file mode 100644
index 00000000000..0e20ae6d99a
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/wmma/device_grouped_conv3d_fwd_wmma_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instance.cpp
@@ -0,0 +1,62 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_outelementop_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using ConvScale = ck::tensor_operation::element_wise::ConvScale;
+
+void add_device_grouped_conv3d_fwd_wmma_cshufflev3_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
+                                                                NDHWGC,
+                                                                GKZYXC,
+                                                                ck::Tuple<>,
+                                                                NDHWGK,
+                                                                BF8,
+                                                                F8,
+                                                                ck::Tuple<>,
+                                                                F8,
+                                                                PassThrough,
+                                                                PassThrough,
+                                                                ConvScale,
+                                                                BF8,
+                                                                F8>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_outelementop_bf8_f8_instances<3,
+                                                                              NDHWGC,
+                                                                              GKZYXC,
+                                                                              ck::Tuple<>,
+                                                                              NDHWGK,
+                                                                              ConvFwdDefault,
+                                                                              ConvScale>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_outelementop_bf8_f8_instances<3,
+                                                                              NDHWGC,
+                                                                              GKZYXC,
+                                                                              ck::Tuple<>,
+                                                                              NDHWGK,
+                                                                              ConvFwd1x1P0,
+                                                                              ConvScale>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_outelementop_bf8_f8_instances<3,
+                                                                              NDHWGC,
+                                                                              GKZYXC,
+                                                                              ck::Tuple<>,
+                                                                              NDHWGK,
+                                                                              ConvFwd1x1S1P0,
+                                                                              ConvScale>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/wmma/device_grouped_conv3d_fwd_wmma_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/wmma/device_grouped_conv3d_fwd_wmma_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp
new file mode 100644
index 00000000000..a61b026ae14
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/wmma/device_grouped_conv3d_fwd_wmma_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp
@@ -0,0 +1,62 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_outelementop_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using ConvScale = ck::tensor_operation::element_wise::ConvScale;
+
+void add_device_grouped_conv3d_fwd_wmma_cshufflev3_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
+                                                                NDHWGC,
+                                                                GKZYXC,
+                                                                ck::Tuple<>,
+                                                                NDHWGK,
+                                                                BF8,
+                                                                BF8,
+                                                                ck::Tuple<>,
+                                                                F8,
+                                                                PassThrough,
+                                                                PassThrough,
+                                                                ConvScale,
+                                                                BF8,
+                                                                BF8>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_outelementop_bf8_instances<3,
+                                                                           NDHWGC,
+                                                                           GKZYXC,
+                                                                           ck::Tuple<>,
+                                                                           NDHWGK,
+                                                                           ConvFwdDefault,
+                                                                           ConvScale>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_outelementop_bf8_instances<3,
+                                                                           NDHWGC,
+                                                                           GKZYXC,
+                                                                           ck::Tuple<>,
+                                                                           NDHWGK,
+                                                                           ConvFwd1x1P0,
+                                                                           ConvScale>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_outelementop_bf8_instances<3,
+                                                                           NDHWGC,
+                                                                           GKZYXC,
+                                                                           ck::Tuple<>,
+                                                                           NDHWGK,
+                                                                           ConvFwd1x1S1P0,
+                                                                           ConvScale>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/wmma/device_grouped_conv3d_fwd_wmma_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/wmma/device_grouped_conv3d_fwd_wmma_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instance.cpp
new file mode 100644
index 00000000000..4cd84a6e19a
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/wmma/device_grouped_conv3d_fwd_wmma_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instance.cpp
@@ -0,0 +1,62 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_outelementop_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using ConvScale = ck::tensor_operation::element_wise::ConvScale;
+
+void add_device_grouped_conv3d_fwd_wmma_cshufflev3_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
+                                                                NDHWGC,
+                                                                GKZYXC,
+                                                                ck::Tuple<>,
+                                                                NDHWGK,
+                                                                F8,
+                                                                BF8,
+                                                                ck::Tuple<>,
+                                                                F8,
+                                                                PassThrough,
+                                                                PassThrough,
+                                                                ConvScale,
+                                                                F8,
+                                                                BF8>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_bf8_instances<3,
+                                                                              NDHWGC,
+                                                                              GKZYXC,
+                                                                              ck::Tuple<>,
+                                                                              NDHWGK,
+                                                                              ConvFwdDefault,
+                                                                              ConvScale>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_bf8_instances<3,
+                                                                              NDHWGC,
+                                                                              GKZYXC,
+                                                                              ck::Tuple<>,
+                                                                              NDHWGK,
+                                                                              ConvFwd1x1P0,
+                                                                              ConvScale>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_bf8_instances<3,
+                                                                              NDHWGC,
+                                                                              GKZYXC,
+                                                                              ck::Tuple<>,
+                                                                              NDHWGK,
+                                                                              ConvFwd1x1S1P0,
+                                                                              ConvScale>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/wmma/device_grouped_conv3d_fwd_wmma_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/wmma/device_grouped_conv3d_fwd_wmma_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
new file mode 100644
index 00000000000..59306f1c6fa
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/wmma/device_grouped_conv3d_fwd_wmma_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
@@ -0,0 +1,63 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_outelementop_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F8        = ck::f8_t;
+using ConvScale = ck::tensor_operation::element_wise::ConvScale;
+
+void add_device_grouped_conv3d_fwd_wmma_cshufflev3_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
+                                                                NDHWGC,
+                                                                GKZYXC,
+                                                                ck::Tuple<>,
+                                                                NDHWGK,
+                                                                F8,
+                                                                F8,
+                                                                ck::Tuple<>,
+                                                                F8,
+                                                                PassThrough,
+                                                                PassThrough,
+                                                                ConvScale,
+                                                                F8,
+                                                                F8>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_instances<3,
+                                                                          NDHWGC,
+                                                                          GKZYXC,
+                                                                          ck::Tuple<>,
+                                                                          NDHWGK,
+                                                                          ConvFwdDefault,
+                                                                          ConvScale>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_instances<3,
+                                                                          NDHWGC,
+                                                                          GKZYXC,
+                                                                          ck::Tuple<>,
+                                                                          NDHWGK,
+                                                                          ConvFwd1x1P0,
+                                                                          ConvScale>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_instances<3,
+                                                                          NDHWGC,
+                                                                          GKZYXC,
+                                                                          ck::Tuple<>,
+                                                                          NDHWGK,
+                                                                          ConvFwd1x1S1P0,
+                                                                          ConvScale>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_add/CMakeLists.txt
index e148b19839f..272d53d7e04 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_add/CMakeLists.txt
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_add/CMakeLists.txt
@@ -1,8 +1,9 @@
 # Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
 # SPDX-License-Identifier: MIT
 
-# ONLY XDL_KERNELS
+# ONLY XDL_AND_WMMA_KERNELS
 set(GROUPED_CONV3D_FWD_CONVSCALE_ADD
-    xdl/device_grouped_conv3d_fwd_xdl_convscale_add_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp)
+    xdl/device_grouped_conv3d_fwd_xdl_convscale_add_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
+    wmma/device_grouped_conv3d_fwd_wmma_convscale_add_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp)
 
 add_instance_library(device_grouped_conv3d_fwd_convscale_add_instance ${GROUPED_CONV3D_FWD_CONVSCALE_ADD})
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_add/wmma/device_grouped_conv3d_fwd_wmma_convscale_add_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_add/wmma/device_grouped_conv3d_fwd_wmma_convscale_add_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
new file mode 100644
index 00000000000..30918ec2d4a
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_add/wmma/device_grouped_conv3d_fwd_wmma_convscale_add_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
@@ -0,0 +1,65 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_binary_outelementop_instance.hpp"
+#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F8           = ck::f8_t;
+using F32          = float;
+using ConvScaleAdd = ck::tensor_operation::element_wise::ConvScaleAdd;
+
+void add_device_grouped_conv3d_fwd_wmma_cshufflev3_convscale_add_ndhwgc_gkzyxc_ndhwgk_f8_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
+                                                                NDHWGC,
+                                                                GKZYXC,
+                                                                ck::Tuple<NDHWGK>,
+                                                                NDHWGK,
+                                                                F8,
+                                                                F8,
+                                                                ck::Tuple<F32>,
+                                                                F8,
+                                                                PassThrough,
+                                                                PassThrough,
+                                                                ConvScaleAdd,
+                                                                F8,
+                                                                F8>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_binary_outelementop_f8_instances<3,
+                                                                                 NDHWGC,
+                                                                                 GKZYXC,
+                                                                                 ck::Tuple<NDHWGK>,
+                                                                                 NDHWGK,
+                                                                                 ConvFwdDefault,
+                                                                                 ConvScaleAdd>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_binary_outelementop_f8_instances<3,
+                                                                                 NDHWGC,
+                                                                                 GKZYXC,
+                                                                                 ck::Tuple<NDHWGK>,
+                                                                                 NDHWGK,
+                                                                                 ConvFwd1x1P0,
+                                                                                 ConvScaleAdd>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_binary_outelementop_f8_instances<3,
+                                                                                 NDHWGC,
+                                                                                 GKZYXC,
+                                                                                 ck::Tuple<NDHWGK>,
+                                                                                 NDHWGK,
+                                                                                 ConvFwd1x1S1P0,
+                                                                                 ConvScaleAdd>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/CMakeLists.txt
index e79da12b1a0..a52b131214f 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/CMakeLists.txt
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/CMakeLists.txt
@@ -1,9 +1,11 @@
 # Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
 # SPDX-License-Identifier: MIT
 
-# ONLY XDL_KERNELS
+# ONLY XDL_AND_WMMA_KERNELS
 set(GROUPED_CONV3D_FWD_CONVSCALE_RELU
     xdl/device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
-    xdl/device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp)
+    xdl/device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp
+    wmma/device_grouped_conv3d_fwd_wmma_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
+    wmma/device_grouped_conv3d_fwd_wmma_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp)
 
 add_instance_library(device_grouped_conv3d_fwd_convscale_relu_instance ${GROUPED_CONV3D_FWD_CONVSCALE_RELU})
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/wmma/device_grouped_conv3d_fwd_wmma_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/wmma/device_grouped_conv3d_fwd_wmma_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp
new file mode 100644
index 00000000000..9701dfe2cc3
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/wmma/device_grouped_conv3d_fwd_wmma_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp
@@ -0,0 +1,67 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_outelementop_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F8  = ck::f8_t;
+using F32 = float;
+
+void add_device_grouped_conv3d_fwd_wmma_cshufflev3_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
+                                                                NDHWGC,
+                                                                GKZYXC,
+                                                                ck::Tuple<>,
+                                                                NDHWGK,
+                                                                F8,
+                                                                F8,
+                                                                ck::Tuple<>,
+                                                                F32,
+                                                                PassThrough,
+                                                                PassThrough,
+                                                                CombConvScaleRelu,
+                                                                F8,
+                                                                F8>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_f8_f32_instances<
+            3,
+            NDHWGC,
+            GKZYXC,
+            ck::Tuple<>,
+            NDHWGK,
+            ConvFwdDefault,
+            CombConvScaleRelu>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_f8_f32_instances<
+            3,
+            NDHWGC,
+            GKZYXC,
+            ck::Tuple<>,
+            NDHWGK,
+            ConvFwd1x1P0,
+            CombConvScaleRelu>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_f8_f32_instances<
+            3,
+            NDHWGC,
+            GKZYXC,
+            ck::Tuple<>,
+            NDHWGK,
+            ConvFwd1x1S1P0,
+            CombConvScaleRelu>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/wmma/device_grouped_conv3d_fwd_wmma_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/wmma/device_grouped_conv3d_fwd_wmma_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
new file mode 100644
index 00000000000..e42479d0de7
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/wmma/device_grouped_conv3d_fwd_wmma_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
@@ -0,0 +1,64 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_outelementop_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F8            = ck::f8_t;
+using ConvScaleRelu = ck::tensor_operation::element_wise::ConvScaleRelu;
+
+void add_device_grouped_conv3d_fwd_wmma_cshufflev3_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
+                                                                NDHWGC,
+                                                                GKZYXC,
+                                                                ck::Tuple<>,
+                                                                NDHWGK,
+                                                                F8,
+                                                                F8,
+                                                                ck::Tuple<>,
+                                                                F8,
+                                                                PassThrough,
+                                                                PassThrough,
+                                                                ConvScaleRelu,
+                                                                F8,
+                                                                F8>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_instances<3,
+                                                                          NDHWGC,
+                                                                          GKZYXC,
+                                                                          ck::Tuple<>,
+                                                                          NDHWGK,
+                                                                          ConvFwdDefault,
+                                                                          ConvScaleRelu>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_instances<3,
+                                                                          NDHWGC,
+                                                                          GKZYXC,
+                                                                          ck::Tuple<>,
+                                                                          NDHWGK,
+                                                                          ConvFwd1x1P0,
+                                                                          ConvScaleRelu>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_instances<3,
+                                                                          NDHWGC,
+                                                                          GKZYXC,
+                                                                          ck::Tuple<>,
+                                                                          NDHWGK,
+                                                                          ConvFwd1x1S1P0,
+                                                                          ConvScaleRelu>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_dynamic_op/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_dynamic_op/CMakeLists.txt
index 715ce6630ad..f67221aa77f 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_dynamic_op/CMakeLists.txt
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_dynamic_op/CMakeLists.txt
@@ -1,11 +1,13 @@
 # Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
 # SPDX-License-Identifier: MIT
 
-# ONLY XDL_KERNELS
+# ONLY XDL_AND_WMMA_KERNELS
 set(GROUPED_CONV3D_FWD_DYNAMIC_OP
     xdl/device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
     xdl/device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
    xdl/device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
-    xdl/device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp)
+    xdl/device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp
+    wmma/device_grouped_conv3d_fwd_wmma_dynamic_op_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
+    wmma/device_grouped_conv3d_fwd_wmma_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp)
 
 add_instance_library(device_grouped_conv3d_fwd_dynamic_op_instance ${GROUPED_CONV3D_FWD_DYNAMIC_OP})
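The dynamic_op instances take `DynamicUnaryOp` as the CDE operation: the unary epilogue is selected at run time rather than baked into the kernel type, so one instance set serves several activations. A rough model of the idea, explicitly not CK's implementation, using `std::variant`:

    #include <algorithm>
    #include <cmath>
    #include <variant>

    // Sketch of a runtime-selected unary epilogue in the spirit of
    // DynamicUnaryOp; CK's real type differs.
    struct Relu
    {
        float operator()(float x) const { return std::max(x, 0.f); }
    };
    struct Sigmoid
    {
        float operator()(float x) const { return 1.f / (1.f + std::exp(-x)); }
    };

    using DynamicUnaryOpSketch = std::variant<Relu, Sigmoid>;

    inline float apply(const DynamicUnaryOpSketch& op, float x)
    {
        return std::visit([x](const auto& f) { return f(x); }, op);
    }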
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_dynamic_op/wmma/device_grouped_conv3d_fwd_wmma_dynamic_op_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_dynamic_op/wmma/device_grouped_conv3d_fwd_wmma_dynamic_op_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
new file mode 100644
index 00000000000..f14f2da8f6d
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_dynamic_op/wmma/device_grouped_conv3d_fwd_wmma_dynamic_op_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
@@ -0,0 +1,55 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+void add_device_grouped_conv3d_fwd_wmma_cshufflev3_dynamic_op_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
+                                                                NDHWGC,
+                                                                GKZYXC,
+                                                                ck::Tuple<>,
+                                                                NDHWGK,
+                                                                BF16,
+                                                                BF16,
+                                                                ck::Tuple<>,
+                                                                BF16,
+                                                                PassThrough,
+                                                                PassThrough,
+                                                                DynamicUnaryOp>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_bf16_instances<3,
+                                                                          NDHWGC,
+                                                                          GKZYXC,
+                                                                          Tuple<>,
+                                                                          NDHWGK,
+                                                                          ConvFwdDefault>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_bf16_instances<3,
+                                                                          NDHWGC,
+                                                                          GKZYXC,
+                                                                          Tuple<>,
+                                                                          NDHWGK,
+                                                                          ConvFwd1x1P0>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_bf16_instances<3,
+                                                                          NDHWGC,
+                                                                          GKZYXC,
+                                                                          Tuple<>,
+                                                                          NDHWGK,
+                                                                          ConvFwd1x1S1P0>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_dynamic_op/wmma/device_grouped_conv3d_fwd_wmma_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_dynamic_op/wmma/device_grouped_conv3d_fwd_wmma_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
new file mode 100644
index 00000000000..e71da052049
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_dynamic_op/wmma/device_grouped_conv3d_fwd_wmma_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
@@ -0,0 +1,55 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+void add_device_grouped_conv3d_fwd_wmma_cshufflev3_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
+                                                                NDHWGC,
+                                                                GKZYXC,
+                                                                ck::Tuple<>,
+                                                                NDHWGK,
+                                                                F16,
+                                                                F16,
+                                                                ck::Tuple<>,
+                                                                F16,
+                                                                PassThrough,
+                                                                PassThrough,
+                                                                DynamicUnaryOp>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_f16_instances<3,
+                                                                         NDHWGC,
+                                                                         GKZYXC,
+                                                                         Tuple<>,
+                                                                         NDHWGK,
+                                                                         ConvFwdDefault>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_f16_instances<3,
+                                                                         NDHWGC,
+                                                                         GKZYXC,
+                                                                         Tuple<>,
+                                                                         NDHWGK,
+                                                                         ConvFwd1x1P0>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_f16_instances<3,
+                                                                         NDHWGC,
+                                                                         GKZYXC,
+                                                                         Tuple<>,
+                                                                         NDHWGK,
+                                                                         ConvFwd1x1S1P0>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt
index 0622b121b53..893f490ec45 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt
@@ -1,12 +1,14 @@
 # Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
 # SPDX-License-Identifier: MIT
 
-# ONLY XDL_KERNELS
-set(GROUPED_CONV3D_FWD_BILINEAR
+# ONLY XDL_AND_WMMA_KERNELS
+set(GROUPED_CONV3D_FWD_SCALE
     xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
     xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
     xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
     xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp
-    xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp)
+    xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp
+    wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
+    wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp)
 
-add_instance_library(device_grouped_conv3d_fwd_scale_instance ${GROUPED_CONV3D_FWD_BILINEAR})
+add_instance_library(device_grouped_conv3d_fwd_scale_instance ${GROUPED_CONV3D_FWD_SCALE})
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
new file mode 100644
index 00000000000..8f072c58a19
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
@@ -0,0 +1,55 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scale_instance.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
+                                                                NDHWGC,
+                                                                GKZYXC,
+                                                                ck::Tuple<>,
+                                                                NDHWGK,
+                                                                BF16,
+                                                                BF16,
+                                                                ck::Tuple<>,
+                                                                BF16,
+                                                                PassThrough,
+                                                                PassThrough,
+                                                                Scale>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_scale_bf16_instances<3,
+                                                                     NDHWGC,
+                                                                     GKZYXC,
+                                                                     Tuple<>,
+                                                                     NDHWGK,
+                                                                     ConvFwdDefault>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_scale_bf16_instances<3,
+                                                                     NDHWGC,
+                                                                     GKZYXC,
+                                                                     Tuple<>,
+                                                                     NDHWGK,
+                                                                     ConvFwd1x1P0>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_scale_bf16_instances<3,
+                                                                     NDHWGC,
+                                                                     GKZYXC,
+                                                                     Tuple<>,
+                                                                     NDHWGK,
+                                                                     ConvFwd1x1S1P0>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
new file mode 100644
index 00000000000..b7d3c3cf96f
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
@@ -0,0 +1,55 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scale_instance.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
+                                                                NDHWGC,
+                                                                GKZYXC,
+                                                                ck::Tuple<>,
+                                                                NDHWGK,
+                                                                F16,
+                                                                F16,
+                                                                ck::Tuple<>,
+                                                                F16,
+                                                                PassThrough,
+                                                                PassThrough,
+                                                                Scale>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_scale_f16_instances<3,
+                                                                    NDHWGC,
+                                                                    GKZYXC,
+                                                                    Tuple<>,
+                                                                    NDHWGK,
+                                                                    ConvFwdDefault>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_scale_f16_instances<3,
+                                                                    NDHWGC,
+                                                                    GKZYXC,
+                                                                    Tuple<>,
+                                                                    NDHWGK,
+                                                                    ConvFwd1x1P0>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_scale_f16_instances<3,
+                                                                    NDHWGC,
+                                                                    GKZYXC,
+                                                                    Tuple<>,
+                                                                    NDHWGK,
+                                                                    ConvFwd1x1S1P0>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt
index 21ec8ecc6e4..aa3dd0af12d 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt
@@ -1,11 +1,16 @@
 # Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
 # SPDX-License-Identifier: MIT
 
-# ONLY XDL_KERNELS
+# ONLY XDL_AND_WMMA_KERNELS
 set(GROUPED_CONV3D_FWD_SCALEADD_AB
     xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
     xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
     xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
-    xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp)
+    xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp
+
+    # WMMA CSHUFFLE V3
+    wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
+    wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
+    )
 
 add_instance_library(device_grouped_conv3d_fwd_scaleadd_ab_instance ${GROUPED_CONV3D_FWD_SCALEADD_AB})
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
new file mode 100644
index 00000000000..f1a711de5ed
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
@@ -0,0 +1,52 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_instance.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
+                                                                NDHWGC,
+                                                                GKZYXC,
+                                                                ck::Tuple<>,
+                                                                NDHWGK,
+                                                                ck::Tuple<BF16, BF16>,
+                                                                ck::Tuple<BF16, BF16>,
+                                                                ck::Tuple<>,
+                                                                BF16,
+                                                                ScaleAdd,
+                                                                ScaleAdd,
+                                                                PassThrough>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances<3,
+                                                                           NDHWGC,
+                                                                           GKZYXC,
+                                                                           NDHWGK,
+                                                                           ConvFwdDefault>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances<3,
+                                                                           NDHWGC,
+                                                                           GKZYXC,
+                                                                           NDHWGK,
+                                                                           ConvFwd1x1P0>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances<3,
+                                                                           NDHWGC,
+                                                                           GKZYXC,
+                                                                           NDHWGK,
+                                                                           ConvFwd1x1S1P0>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
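scaleadd_ab is the one family here whose extra tensors enter on the input side: A and B are each a pair of tensors (`ck::Tuple<BF16, BF16>`) combined by `ScaleAdd` before the GEMM, while the output epilogue stays `PassThrough`. Assuming `ScaleAdd` computes `y = scale * x0 + x1`, the input preprocessing is:

    // Hedged sketch of the ScaleAdd input op (applied per element of the
    // A pair and of the B pair); CK's functor may differ in detail.
    struct ScaleAddSketch
    {
        float scale;
        void operator()(float& y, const float& x0, const float& x1) const
        {
            y = scale * x0 + x1;
        }
    };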
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
new file mode 100644
index 00000000000..5aa527d8295
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
@@ -0,0 +1,52 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_instance.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
+                                                                NDHWGC,
+                                                                GKZYXC,
+                                                                ck::Tuple<>,
+                                                                NDHWGK,
+                                                                ck::Tuple<F16, F16>,
+                                                                ck::Tuple<F16, F16>,
+                                                                ck::Tuple<>,
+                                                                F16,
+                                                                ScaleAdd,
+                                                                ScaleAdd,
+                                                                PassThrough>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances<3,
+                                                                          NDHWGC,
+                                                                          GKZYXC,
+                                                                          NDHWGK,
+                                                                          ConvFwdDefault>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances<3,
+                                                                          NDHWGC,
+                                                                          GKZYXC,
+                                                                          NDHWGK,
+                                                                          ConvFwd1x1P0>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances<3,
+                                                                          NDHWGC,
+                                                                          GKZYXC,
+                                                                          NDHWGK,
+                                                                          ConvFwd1x1S1P0>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt
index 3495d88637f..659877c21b0 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt
@@ -1,11 +1,13 @@
 # Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
 # SPDX-License-Identifier: MIT
 
-# ONLY XDL_KERNELS
+# ONLY XDL_AND_WMMA_KERNELS
 set(GROUPED_CONV3D_FWD_scaleadd_scaleadd_RELU
     xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
     xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
     xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
-    xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp)
+    xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp
+    wmma/device_grouped_conv3d_fwd_wmma_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
+    wmma/device_grouped_conv3d_fwd_wmma_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp)
 
 add_instance_library(device_grouped_conv3d_fwd_scaleadd_scaleadd_relu_instance ${GROUPED_CONV3D_FWD_scaleadd_scaleadd_RELU})
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/wmma/device_grouped_conv3d_fwd_wmma_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/wmma/device_grouped_conv3d_fwd_wmma_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
new file mode 100644
index 00000000000..e19ca2841a7
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/wmma/device_grouped_conv3d_fwd_wmma_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
@@ -0,0 +1,58 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt
index 3495d88637f..659877c21b0 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt
@@ -1,11 +1,13 @@
 # Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
 # SPDX-License-Identifier: MIT
 
-# ONLY XDL_KERNELS
+# ONLY XDL_AND_WMMA_KERNELS
 set(GROUPED_CONV3D_FWD_scaleadd_scaleadd_RELU
     xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
     xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
     xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
-    xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp)
+    xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp
+    wmma/device_grouped_conv3d_fwd_wmma_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
+    wmma/device_grouped_conv3d_fwd_wmma_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp)
 
 add_instance_library(device_grouped_conv3d_fwd_scaleadd_scaleadd_relu_instance ${GROUPED_CONV3D_FWD_scaleadd_scaleadd_RELU})
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/wmma/device_grouped_conv3d_fwd_wmma_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/wmma/device_grouped_conv3d_fwd_wmma_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
new file mode 100644
index 00000000000..e19ca2841a7
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/wmma/device_grouped_conv3d_fwd_wmma_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
@@ -0,0 +1,58 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_scaleadd_relu_instance.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
+                                                                NDHWGC,
+                                                                GKZYXC,
+                                                                ck::Tuple<NDHWGK, G_K>,
+                                                                NDHWGK,
+                                                                BF16,
+                                                                BF16,
+                                                                ck::Tuple<BF16, BF16>,
+                                                                BF16,
+                                                                PassThrough,
+                                                                PassThrough,
+                                                                ScaleAddScaleAddRelu>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_scaleadd_relu_bf16_instances<
+            3,
+            NDHWGC,
+            GKZYXC,
+            ck::Tuple<NDHWGK, G_K>,
+            NDHWGK,
+            ConvFwdDefault>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_scaleadd_relu_bf16_instances<
+            3,
+            NDHWGC,
+            GKZYXC,
+            ck::Tuple<NDHWGK, G_K>,
+            NDHWGK,
+            ConvFwd1x1P0>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_scaleadd_relu_bf16_instances<
+            3,
+            NDHWGC,
+            GKZYXC,
+            ck::Tuple<NDHWGK, G_K>,
+            NDHWGK,
+            ConvFwd1x1S1P0>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/wmma/device_grouped_conv3d_fwd_wmma_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/wmma/device_grouped_conv3d_fwd_wmma_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
new file mode 100644
index 00000000000..f89e349ff1e
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/wmma/device_grouped_conv3d_fwd_wmma_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
@@ -0,0 +1,58 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_scaleadd_relu_instance.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
+                                                                NDHWGC,
+                                                                GKZYXC,
+                                                                ck::Tuple<NDHWGK, G_K>,
+                                                                NDHWGK,
+                                                                F16,
+                                                                F16,
+                                                                ck::Tuple<F16, F16>,
+                                                                F16,
+                                                                PassThrough,
+                                                                PassThrough,
+                                                                ScaleAddScaleAddRelu>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_scaleadd_relu_f16_instances<
+            3,
+            NDHWGC,
+            GKZYXC,
+            ck::Tuple<NDHWGK, G_K>,
+            NDHWGK,
+            ConvFwdDefault>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_scaleadd_relu_f16_instances<
+            3,
+            NDHWGC,
+            GKZYXC,
+            ck::Tuple<NDHWGK, G_K>,
+            NDHWGK,
+            ConvFwd1x1P0>{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_scaleadd_relu_f16_instances<
+            3,
+            NDHWGC,
+            GKZYXC,
+            ck::Tuple<NDHWGK, G_K>,
+            NDHWGK,
+            ConvFwd1x1S1P0>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp
index 2a2f516727c..50cd58eec37 100644
--- a/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp
+++ b/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp
@@ -196,6 +196,7 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
     float best_tflops     = 0;
     float best_gb_per_sec = 0;
     int num_kernel        = 0;
+    int valids            = 0;
 
     // profile device op instances
     bool pass = true;
@@ -217,6 +218,8 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
         // re-init output to zero before profiling next kernel
         out_device_buf.SetZero();
 
+        valids++;
+
         std::string op_name = op_ptr->GetTypeString();
 
         auto invoker_ptr = op_ptr->MakeInvokerPointer();
@@ -322,6 +325,8 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
         run_impl(op_ptr, argument_ptr);
     }
 
+    printf("\033[36mvalids: %d\033[0m\n", valids);
+
     std::cout << "Best configuration parameters:"
               << "\nname: " << best_op_name << "\navg_time: " << best_avg_time
               << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl;
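Before the bilinear profiler implementation lands below, a sketch of how it is meant to be driven, mirroring the fp16 branch that `profile_grouped_conv_fwd_bilinear.cpp` registers at the end of this patch (layouts NDHWGC/GKZYXC/NDHWGK, with the D tensor sharing the output layout and data type). `parse_conv_param` and the template-argument order are taken from the code itself; the wrapper function and its argument checking are illustrative only.

```cpp
// Sketch only: drive the bilinear grouped-conv profiler for 3D fp16.
#include "profiler/profile_grouped_conv_fwd_bilinear_impl.hpp"

int run_bilinear_fp16_example(int argc, char* argv[])
{
    namespace ctc = ck::tensor_layout::convolution;
    using F16     = ck::half_t;

    // control args + G/N/K/C + 6 values per spatial dimension (3 dims here)
    if(argc < 10 + 4 + 6 * 3)
        return 1;

    const auto param = ck::utils::conv::parse_conv_param(3, 10, argv);

    const bool ok = ck::profiler::profile_grouped_conv_fwd_bilinear_impl<
        3,           // NDimSpatial
        ctc::NDHWGC, // input layout
        ctc::GKZYXC, // weight layout
        ctc::NDHWGK, // D layout (same as output)
        ctc::NDHWGK, // output layout
        F16, F16,    // input / weight data types
        F16, F16,    // D / output data types
        F16, F16,    // A / B compute types
        ck::index_t>(
        /*do_verification=*/1,
        /*init_method=*/2,
        /*do_log=*/false,
        /*time_kernel=*/true,
        param,
        ck::tensor_operation::element_wise::Bilinear{}); // default scales (assumed 1.0f each)

    return ok ? 0 : 1;
}
```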
diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_bilinear_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_bilinear_impl.hpp
new file mode 100644
index 00000000000..0b5d62dfd44
--- /dev/null
+++ b/profiler/include/profiler/profile_grouped_conv_fwd_bilinear_impl.hpp
@@ -0,0 +1,324 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <iomanip>
+#include <iostream>
+#include <type_traits>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "profiler/common.hpp"
+
+#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp"
+
+#include "ck/library/utility/algorithm.hpp"
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/convolution_parameter.hpp"
+#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
+
+namespace ck {
+namespace profiler {
+
+template <ck::index_t NDimSpatial,
+          typename InLayout,
+          typename WeiLayout,
+          typename DLayout,
+          typename OutLayout,
+          typename InDataType,
+          typename WeiDataType,
+          typename DDataType,
+          typename OutDataType,
+          typename AComputeType,
+          typename BComputeType,
+          typename IndexType>
+bool profile_grouped_conv_fwd_bilinear_impl(
+    int do_verification,
+    int init_method,
+    bool do_log,
+    bool time_kernel,
+    const ck::utils::conv::ConvParam& conv_param,
+    const ck::tensor_operation::element_wise::Bilinear& bilinear_op =
+        ck::tensor_operation::element_wise::Bilinear{})
+{
+    using InElementOp  = ck::tensor_operation::element_wise::PassThrough;
+    using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
+    using OutElementOp = ck::tensor_operation::element_wise::Bilinear;
+
+    using CShuffleDataType = float;
+
+    bool pass = true;
+
+    auto f_host_tensor_descriptor =
+        ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
+
+    auto f_host_tensor_descriptor_packed =
+        ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
+
+    auto e_host_tensor_descriptor =
+        ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
+
+    auto d_host_tensor_descriptor =
+        ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<DLayout>(conv_param);
+
+    std::array<IndexType, NDimSpatial + 3> a_g_n_c_wis_lengths{};
+    std::array<IndexType, NDimSpatial + 3> a_g_n_c_wis_strides{};
+    std::array<IndexType, NDimSpatial + 3> b_g_k_c_xs_lengths{};
+    std::array<IndexType, NDimSpatial + 3> b_g_k_c_xs_strides{};
+    std::array<IndexType, NDimSpatial + 3> d_g_n_k_wos_lengths{};
+    std::array<IndexType, NDimSpatial + 3> d_g_n_k_wos_strides{};
+    std::array<IndexType, NDimSpatial + 3> e_g_n_k_wos_lengths{};
+    std::array<IndexType, NDimSpatial + 3> e_g_n_k_wos_strides{};
+    std::array<IndexType, NDimSpatial> conv_filter_strides{};
+    std::array<IndexType, NDimSpatial> conv_filter_dilations{};
+    std::array<IndexType, NDimSpatial> input_left_pads{};
+    std::array<IndexType, NDimSpatial> input_right_pads{};
+
+    auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
+
+    copy(f_host_tensor_descriptor.GetLengths(), a_g_n_c_wis_lengths);
+    copy(f_host_tensor_descriptor.GetStrides(), a_g_n_c_wis_strides);
+    copy(f_host_tensor_descriptor_packed.GetLengths(), b_g_k_c_xs_lengths);
+    copy(f_host_tensor_descriptor_packed.GetStrides(), b_g_k_c_xs_strides);
+    copy(d_host_tensor_descriptor.GetLengths(), d_g_n_k_wos_lengths);
+    copy(d_host_tensor_descriptor.GetStrides(), d_g_n_k_wos_strides);
+    copy(e_host_tensor_descriptor.GetLengths(), e_g_n_k_wos_lengths);
+    copy(e_host_tensor_descriptor.GetStrides(), e_g_n_k_wos_strides);
+    copy(conv_param.conv_filter_strides_, conv_filter_strides);
+    copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
+    copy(conv_param.input_left_pads_, input_left_pads);
+    copy(conv_param.input_right_pads_, input_right_pads);
std::cout << "weight: " << weight.mDesc << std::endl; + std::cout << "d_tensor: " << d_tensor.mDesc << std::endl; + std::cout << "output: " << host_output.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weight.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_tensor.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + weight.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_tensor.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_tensor.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weight.mData.data()); + d_device_buf.ToDevice(d_tensor.mData.data()); + + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd< + NDimSpatial, + InDataType, + WeiDataType, + CShuffleDataType, + InElementOp, + WeiElementOp, + ck::tensor_operation::element_wise::PassThrough>{}; + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = + ref_conv.MakeArgument(input, + weight, + c, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + ck::tensor_operation::element_wise::PassThrough{}); + + c.SetZero(); + ref_invoker.Run(ref_argument); + + host_output.ForEach([&](auto&, auto idx) { + const auto conv_shuffle = ck::type_convert(c(idx)); + + if constexpr(std::is_same_v) + { + const auto conv_val = ck::type_convert(conv_shuffle); + bilinear_op(host_output(idx), conv_val, d_tensor(idx)); + } + else + { + const auto conv_val = conv_shuffle; + const auto d_val = ck::type_convert(d_tensor(idx)); + + CShuffleDataType out_val{}; + bilinear_op(out_val, conv_val, d_val); + host_output(idx) = ck::type_convert(out_val); + } + }); + } + + std::string best_op_name; + float best_avg_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + int valids = 0; + + using DeviceOp = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + AComputeType, + BComputeType>; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + for(std::size_t i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + std::array{static_cast(d_device_buf.GetDeviceBuffer())}, + static_cast(out_device_buf.GetDeviceBuffer()), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 1>{d_g_n_k_wos_lengths}, + std::array, 1>{d_g_n_k_wos_strides}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + bilinear_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + 
+
+    std::string best_op_name;
+    float best_avg_time   = 0;
+    float best_tflops     = 0;
+    float best_gb_per_sec = 0;
+    int valids            = 0;
+
+    using DeviceOp =
+        ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<NDimSpatial,
+                                                                      InLayout,
+                                                                      WeiLayout,
+                                                                      ck::Tuple<DLayout>,
+                                                                      OutLayout,
+                                                                      InDataType,
+                                                                      WeiDataType,
+                                                                      ck::Tuple<DDataType>,
+                                                                      OutDataType,
+                                                                      InElementOp,
+                                                                      WeiElementOp,
+                                                                      OutElementOp,
+                                                                      AComputeType,
+                                                                      BComputeType>;
+
+    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
+        DeviceOp>::GetInstances();
+
+    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
+
+    for(std::size_t i = 0; i < op_ptrs.size(); ++i)
+    {
+        auto& op_ptr = op_ptrs[i];
+
+        auto argument_ptr = op_ptr->MakeArgumentPointer(
+            static_cast<const void*>(in_device_buf.GetDeviceBuffer()),
+            static_cast<const void*>(wei_device_buf.GetDeviceBuffer()),
+            std::array<const void*, 1>{static_cast<const void*>(d_device_buf.GetDeviceBuffer())},
+            static_cast<void*>(out_device_buf.GetDeviceBuffer()),
+            a_g_n_c_wis_lengths,
+            a_g_n_c_wis_strides,
+            b_g_k_c_xs_lengths,
+            b_g_k_c_xs_strides,
+            std::array<std::array<IndexType, NDimSpatial + 3>, 1>{d_g_n_k_wos_lengths},
+            std::array<std::array<IndexType, NDimSpatial + 3>, 1>{d_g_n_k_wos_strides},
+            e_g_n_k_wos_lengths,
+            e_g_n_k_wos_strides,
+            conv_filter_strides,
+            conv_filter_dilations,
+            input_left_pads,
+            input_right_pads,
+            InElementOp{},
+            WeiElementOp{},
+            bilinear_op);
+
+        auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+        {
+            ++valids;
+
+            std::string op_name = op_ptr->GetTypeString();
+
+            if(do_log)
+            {
+                std::cout << "Evaluating [" << i << "] " << op_name << std::endl;
+            }
+
+            out_device_buf.SetZero();
+            auto ave_time =
+                invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
+
+            auto flop = conv_param.GetFlops();
+
+            // D is an output-shaped tensor: G * N * K * prod(output spatial dims)
+            std::size_t d_elements =
+                static_cast<std::size_t>(conv_param.G_) * conv_param.N_ * conv_param.K_;
+            for(std::size_t j = 0; j < conv_param.output_spatial_lengths_.size(); ++j)
+            {
+                d_elements *= conv_param.output_spatial_lengths_[j];
+            }
+            auto num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>() +
+                             sizeof(DDataType) * d_elements;
+
+            float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
+            float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+            if(do_log)
+            {
+                std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
+                          << " TFlops, " << gb_per_sec << " GB/s, " << op_name << std::endl;
+            }
+
+            if(tflops > best_tflops)
+            {
+                best_op_name    = op_name;
+                best_tflops     = tflops;
+                best_avg_time   = ave_time;
+                best_gb_per_sec = gb_per_sec;
+            }
+
+            if(do_verification)
+            {
+                out_device_buf.FromDevice(device_output.mData.data());
+
+                bool is_valid = ck::utils::check_err(device_output,
+                                                     host_output,
+                                                     "Error: Device and Host results do not match!",
+                                                     get_rtol<OutDataType>(),
+                                                     get_atol<OutDataType>());
+
+                if(!is_valid)
+                {
+                    pass = false;
+                }
+
+                if(do_log)
+                {
+                    LogRangeAsType<float>(std::cout << "input : ", input.mData, ",") << std::endl;
+                    LogRangeAsType<float>(std::cout << "weight: ", weight.mData, ",") << std::endl;
+                    LogRangeAsType<float>(std::cout << "d_tensor: ", d_tensor.mData, ",")
+                        << std::endl;
+                    LogRangeAsType<float>(std::cout << "host_output : ", host_output.mData, ",")
+                        << std::endl;
+                    LogRangeAsType<float>(std::cout << "device_output: ", device_output.mData, ",")
+                        << std::endl;
+                }
+            }
+        }
+        else
+        {
+            if(do_log)
+            {
+                std::cout << op_ptr->GetTypeString() << " does not support this problem"
+                          << std::endl;
+            }
+        }
+    }
+
+    printf("\033[36mvalids: %d\033[0m\n", valids);
+
+    if(valids > 0)
+    {
+        std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops
+                  << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
+    }
+
+    return pass;
+}
+
+} // namespace profiler
+} // namespace ck
diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_convscale_add_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_convscale_add_impl.hpp
new file mode 100644
index 00000000000..aea97a89039
--- /dev/null
+++ b/profiler/include/profiler/profile_grouped_conv_fwd_convscale_add_impl.hpp
@@ -0,0 +1,314 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_add.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_grouped_conv_fwd_convscale_add_impl( + int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const ck::tensor_operation::element_wise::ConvScaleAdd& convscaleadd_op = + ck::tensor_operation::element_wise::ConvScaleAdd{}) +{ + using InElementOp = ck::tensor_operation::element_wise::PassThrough; + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using OutElementOp = ck::tensor_operation::element_wise::ConvScaleAdd; + + bool pass = true; + + auto f_host_tensor_descriptor = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + + auto f_host_tensor_descriptor_packed = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + + auto e_host_tensor_descriptor = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + auto d_host_tensor_descriptor = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array d_g_n_k_wos_lengths{}; + std::array d_g_n_k_wos_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(f_host_tensor_descriptor.GetLengths(), a_g_n_c_wis_lengths); + copy(f_host_tensor_descriptor.GetStrides(), a_g_n_c_wis_strides); + copy(f_host_tensor_descriptor_packed.GetLengths(), b_g_k_c_xs_lengths); + copy(f_host_tensor_descriptor_packed.GetStrides(), b_g_k_c_xs_strides); + copy(d_host_tensor_descriptor.GetLengths(), d_g_n_k_wos_lengths); + copy(d_host_tensor_descriptor.GetStrides(), d_g_n_k_wos_strides); + copy(e_host_tensor_descriptor.GetLengths(), e_g_n_k_wos_lengths); + copy(e_host_tensor_descriptor.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + Tensor input(f_host_tensor_descriptor); + Tensor weight(f_host_tensor_descriptor_packed); + Tensor d_tensor(d_host_tensor_descriptor); + Tensor host_output(e_host_tensor_descriptor); + Tensor device_output(e_host_tensor_descriptor); + + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weight: " << weight.mDesc << std::endl; + std::cout << 
"d_tensor: " << d_tensor.mDesc << std::endl; + std::cout << "output: " << host_output.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weight.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_tensor.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + weight.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_tensor.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_tensor.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weight.mData.data()); + d_device_buf.ToDevice(d_tensor.mData.data()); + + if(do_verification) + { + // findme + using tmpType = float; // OutDataType; + // using tmpType = OutDataType; // OutDataType; + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd< + NDimSpatial, + InDataType, + WeiDataType, + tmpType, + InElementOp, + WeiElementOp, + ck::tensor_operation::element_wise::PassThrough>{}; + + Tensor c_tensor(e_host_tensor_descriptor); + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument_c = + ref_conv.MakeArgument(input, + weight, + c_tensor, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + ck::tensor_operation::element_wise::PassThrough{}); + + c_tensor.SetZero(); + ref_invoker.Run(ref_argument_c); + + host_output.ForEach([&](auto&, auto idx) { + convscaleadd_op(host_output(idx), c_tensor(idx), d_tensor(idx)); + }); + } + + std::string best_op_name; + float best_avg_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + int valids = 0; + + using DeviceOp = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + AComputeType, + BComputeType>; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + for(std::size_t i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + std::array{static_cast(d_device_buf.GetDeviceBuffer())}, + static_cast(out_device_buf.GetDeviceBuffer()), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 1>{d_g_n_k_wos_lengths}, + std::array, 1>{d_g_n_k_wos_strides}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + convscaleadd_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + ++valids; + + std::string op_name = op_ptr->GetTypeString(); + + if(do_log) + { + std::cout << "Evaluating [" << i << "] " << op_name << std::endl; + } + + out_device_buf.SetZero(); + auto ave_time = + invoker_ptr->Run(argument_ptr.get(), 
StreamConfig{nullptr, time_kernel}); + + auto flop = conv_param.GetFlops(); + auto num_btype = conv_param.GetByte() + + sizeof(DDataType) * (conv_param.G_ * conv_param.N_ * conv_param.K_); + + for(std::size_t j = 0; j < conv_param.filter_spatial_lengths_.size(); ++j) + { + num_btype += sizeof(DDataType) * conv_param.output_spatial_lengths_[j]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + if(do_log) + { + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops + << " TFlops, " << gb_per_sec << " GB/s, " << op_name << std::endl; + } + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_avg_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + out_device_buf.FromDevice(device_output.mData.data()); + + double rtol = 1e-3, atol = 1e-3; + if(std::is_same::value) + { + rtol = 1e-1; + atol = 16.1; + } + + bool is_valid = ck::utils::check_err( + device_output, host_output, "incorrect results", rtol, atol); + + if(!is_valid) + { + pass = false; + } + + if(do_log) + { + LogRangeAsType(std::cout << "input : ", input.mData, ",") << std::endl; + LogRangeAsType(std::cout << "weight: ", weight.mData, ",") << std::endl; + LogRangeAsType(std::cout << "d_tensor: ", d_tensor.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "host_output : ", host_output.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "device_output: ", device_output.mData, ",") + << std::endl; + } + } + } + else + { + if(do_log) + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" + << std::endl; + } + } + } + + printf("\033[36mvalids: %d\033[0m\n", valids); + + if(valids > 0) + { + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + } + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp index 427d2b14dff..641215aff57 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp @@ -13,6 +13,8 @@ #include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_dynamic_op.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convinvscale.hpp" #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" @@ -146,6 +148,8 @@ bool profile_grouped_conv_fwd_impl(int do_verification, float best_tflops = 0; float best_gb_per_sec = 0; index_t num_kernel = 0; + int valids = 0; + // profile device op instances bool pass = true; @@ -165,6 +169,9 @@ bool profile_grouped_conv_fwd_impl(int do_verification, } std::string op_name = op_ptr->GetTypeString(); + valids++; + + out_device_buf.SetZero(); auto invoker_ptr = op_ptr->MakeInvokerPointer(); @@ -258,6 +265,8 @@ bool profile_grouped_conv_fwd_impl(int do_verification, run_impl(op_ptr, argument_ptr); } + printf("\033[36mvalids: %d\033[0m\n", valids); + std::cout << "Best configuration parameters:" << "\nname: " << best_op_name << "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; diff 
--git a/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp index 50b97c3baee..8cf24a62150 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp @@ -5,11 +5,25 @@ #include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convinvscale.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "profiler/common.hpp" +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + namespace ck { namespace profiler { @@ -29,7 +43,7 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, bool time_kernel, const ck::utils::conv::ConvParam& conv_param) { - auto pass = true; // return status + auto pass = true; using CShuffleDataType = float; @@ -110,8 +124,27 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, auto scale_out = type_convert( type_convert(2.0f * float(RAND_MAX / 2 - std::rand()) / float(RAND_MAX))); - // initialize out_element_op for each iteration - const auto out_element_op = OutElementOp{scale_in, scale_wei, scale_out}; + OutElementOp out_element_op; + if constexpr(std::is_same_v) + { + using Scale = ck::tensor_operation::element_wise::Scale; + out_element_op = OutElementOp{Scale{scale_in}, Scale{scale_wei}, PassThrough{}}; + } + else if constexpr(std::is_same_v) + { + using Scale = ck::tensor_operation::element_wise::Scale; + using Relu = ck::tensor_operation::element_wise::Relu; + out_element_op = OutElementOp{Scale{scale_in}, Scale{scale_wei}, Relu{}}; + } + else if constexpr(std::is_same_v) + { + out_element_op = OutElementOp{scale_out}; + } + else + { + out_element_op = OutElementOp{scale_in, scale_wei, scale_out}; + } std::cout << "scale_in: " << scale_in << std::endl; std::cout << "scale_wei: " << scale_wei << std::endl; @@ -148,13 +181,32 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, c.SetZero(); ref_invoker.Run(ref_argument); - host_output.ForEach([&](auto&, auto idx) { out_element_op(host_output(idx), c(idx)); }); + host_output.ForEach([&](auto&, auto idx) { + if constexpr(std::is_same_v) + { + const auto conv_shuffle = ck::type_convert(c(idx)); + if constexpr(std::is_same_v) + { + const auto conv_val = ck::type_convert(conv_shuffle); + out_element_op(host_output(idx), conv_val); + } + else + { + out_element_op(host_output(idx), conv_shuffle); + } + } + else + { + out_element_op(host_output(idx), c(idx)); + } + }); } std::string best_op_name; float best_avg_time = 0; float 
best_tflops = 0; float best_gb_per_sec = 0; + int valids = 0; auto run_impl = [&](auto& op_ptr, auto& argument_ptr) { if(op_ptr->IsSupportedArgument(argument_ptr.get())) @@ -163,6 +215,7 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, out_device_buf.SetZero(); std::string op_name = op_ptr->GetTypeString(); + valids++; auto invoker_ptr = op_ptr->MakeInvokerPointer(); @@ -199,15 +252,11 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, if(do_log) { - LogRangeAsType(std::cout << "input : ", input.mData, ",") - << std::endl; - LogRangeAsType(std::cout << "weight: ", weight.mData, ",") - << std::endl; - LogRangeAsType( - std::cout << "host_output : ", host_output.mData, ",") + LogRangeAsType(std::cout << "input : ", input.mData, ",") << std::endl; + LogRangeAsType(std::cout << "weight: ", weight.mData, ",") << std::endl; + LogRangeAsType(std::cout << "host_output : ", host_output.mData, ",") << std::endl; - LogRangeAsType( - std::cout << "device_output: ", device_output.mData, ",") + LogRangeAsType(std::cout << "device_output: ", device_output.mData, ",") << std::endl; } } @@ -264,9 +313,12 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, run_impl(op_ptr, argument_ptr); } + printf("\033[36mvalids: %d\033[0m\n", valids); + std::cout << "Best configuration parameters:" << "\nname: " << best_op_name << "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; + return pass; } diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_scaleadd_scaleadd_relu_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_scaleadd_scaleadd_relu_impl.hpp new file mode 100644 index 00000000000..177d0fcde98 --- /dev/null +++ b/profiler/include/profiler/profile_grouped_conv_fwd_scaleadd_scaleadd_relu_impl.hpp @@ -0,0 +1,391 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace profiler { + +template +inline constexpr double get_rtol_scaleadd() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else + { + return 1e-3; + } +} + +template +inline constexpr double get_atol_scaleadd() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else + { + return 1e-3; + } +} + +template +bool profile_grouped_conv_fwd_scaleadd_scaleadd_relu_impl( + int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param) +{ + auto pass = true; + + using CShuffleDataType = float; + + using BiasDataType = std::conditional_t, float, InDataType>; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using InElementOp = PassThrough; + using WeiElementOp = PassThrough; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + + const auto out_element_op = OutElementOp{1.0f, 2.0f}; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + const index_t G = conv_param.G_; + const index_t K = conv_param.K_; + + auto bias1_ndhwgk_desc = out_g_n_k_wos_desc; + auto bias2_g_k_desc = HostTensorDescriptor({G, K}); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array bias1_ndhwgk_lengths{}; + std::array bias1_ndhwgk_strides{}; + std::array bias2_g_n_k_wos_lengths{}; + std::array bias2_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(out_g_n_k_wos_desc.GetLengths(), bias1_ndhwgk_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), bias1_ndhwgk_strides); + 
copy(out_g_n_k_wos_desc.GetLengths(), bias2_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), bias2_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + constexpr ck::index_t spatial_offset = 3; + bias2_g_n_k_wos_strides[1] = 0; + for(int i = 0; i < NDimSpatial; i++) + { + bias2_g_n_k_wos_strides[i + spatial_offset] = 0; + } + + Tensor input(in_g_n_c_wis_desc); + Tensor weight(wei_g_k_c_xs_desc); + Tensor c(out_g_n_k_wos_desc); + Tensor host_output(out_g_n_k_wos_desc); + Tensor device_output(out_g_n_k_wos_desc); + Tensor bias1(bias1_ndhwgk_desc); + Tensor bias2(bias2_g_k_desc); + + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weight: " << weight.mDesc << std::endl; + std::cout << "output: " << host_output.mDesc << std::endl; + std::cout << "bias1 (NDHWGK): " << bias1.mDesc << std::endl; + std::cout << "bias2 (G_K): " << bias2.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weight.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + bias1.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + bias2.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}); + weight.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + bias1.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + bias2.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize()); + DeviceMem bias1_device_buf(sizeof(BiasDataType) * bias1.mDesc.GetElementSpaceSize()); + DeviceMem bias2_device_buf(sizeof(BiasDataType) * bias2.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weight.mData.data()); + bias1_device_buf.ToDevice(bias1.mData.data()); + bias2_device_buf.ToDevice(bias2.mData.data()); + + // run reference op + if(do_verification) + { + std::cout << "\nVerifying algorithm against reference convolution..." 
<< std::endl; + std::cout << "\tUsing (rel_tol,abs_tol) = (" << std::setprecision(7) + << get_rtol_scaleadd() << ", " << get_atol_scaleadd() + << ")" << std::endl; + + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd{}; + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(input, + weight, + c, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + PassThrough{}); + + c.SetZero(); + ref_invoker.Run(ref_argument); + + host_output.ForEach([&](auto&, auto idx) { + const auto g_idx = idx[0]; + const auto k_idx = idx[2]; + + const auto conv_shuffle = ck::type_convert(c(idx)); + + if constexpr(std::is_same_v) + { + const auto conv_val = ck::type_convert(conv_shuffle); + + const auto bias1_val = bias1(idx); + const auto bias2_val = bias2(g_idx, k_idx); + + OutDataType out_val{}; + out_element_op(out_val, conv_val, bias1_val, bias2_val); + + host_output(idx) = ck::type_convert(out_val); + } + else + { + const auto conv_val = conv_shuffle; + + const auto bias1_val = ck::type_convert(bias1(idx)); + const auto bias2_val = ck::type_convert(bias2(g_idx, k_idx)); + + CShuffleDataType out_val{}; + out_element_op(out_val, conv_val, bias1_val, bias2_val); + + host_output(idx) = ck::type_convert(out_val); + } + }); + } + + std::string best_op_name; + float best_avg_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + auto run_impl = [&](auto& op_ptr, auto& argument_ptr) { + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + out_device_buf.SetZero(); + + std::string op_name = op_ptr->GetTypeString(); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + float avg_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + out_device_buf.FromDevice(device_output.mData.data()); + + pass = pass & ck::utils::check_err(device_output, + host_output, + "Error: Device and Host results do not match!", + get_rtol_scaleadd(), + get_atol_scaleadd()); + + if(do_log) + { + LogRangeAsType(std::cout << "input : ", input.mData, ",") << std::endl; + LogRangeAsType(std::cout << "weight: ", weight.mData, ",") << std::endl; + LogRangeAsType(std::cout << "bias1: ", bias1.mData, ",") << std::endl; + LogRangeAsType(std::cout << "bias2: ", bias2.mData, ",") << std::endl; + LogRangeAsType(std::cout << "host_output : ", host_output.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "device_output: ", device_output.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + }; + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + AComputeType, + BComputeType>; + + // get device op instances + const auto op_ptrs = 
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "ckProfiler found " << op_ptrs.size() << " instances" << std::endl; + + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + {bias1_device_buf.GetDeviceBuffer(), bias2_device_buf.GetDeviceBuffer()}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + {bias1_ndhwgk_lengths, bias2_g_n_k_wos_lengths}, + {bias1_ndhwgk_strides, bias2_g_n_k_wos_strides}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + run_impl(op_ptr, argument_ptr); + } + + std::cout << "Best configuration parameters:" << "\nname: " << best_op_name + << "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops + << "\nGB/s: " << best_gb_per_sec << std::endl; + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 71f16376534..65911927ac9 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -79,7 +79,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]") list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu_add.cpp) list(APPEND PROFILER_OPS profile_conv_bwd_data.cpp) list(APPEND PROFILER_OPS profile_conv_fwd.cpp) - list(APPEND PROFILER_OPS profile_grouped_conv_fwd_outelementop.cpp) endif() if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)) OR @@ -99,9 +98,14 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]") list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_clamp.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp) + list(APPEND PROFILER_OPS profile_grouped_conv_fwd_convscale_add.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp) + list(APPEND PROFILER_OPS profile_grouped_conv_fwd_dynamic_op.cpp) + list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bilinear.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp) + list(APPEND PROFILER_OPS profile_grouped_conv_fwd_outelementop.cpp) list(APPEND PROFILER_OPS profile_gemm_multi_abd.cpp) + list(APPEND PROFILER_OPS profile_grouped_conv_fwd_scaleadd_scaleadd_relu.cpp) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp) list(APPEND PROFILER_OPS profile_gemm_multiply_add.cpp) @@ -205,19 +209,12 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]") list(APPEND DEVICE_INSTANCES device_conv2d_fwd_instance) list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_instance) list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_add_instance) - list(APPEND DEVICE_INSTANCES device_grouped_conv1d_fwd_instance) list(APPEND DEVICE_INSTANCES device_conv1d_bwd_data_instance) list(APPEND DEVICE_INSTANCES device_conv3d_bwd_data_instance) list(APPEND DEVICE_INSTANCES device_conv2d_bwd_data_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance) list(APPEND DEVICE_INSTANCES device_grouped_convnd_bwd_weight_instance) - list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convscale_instance) - list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convinvscale_instance) - 
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_clamp_instance) - list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_clamp_instance) - list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_clamp_instance) - list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_clamp_instance) endif() if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)) OR @@ -234,10 +231,24 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]") list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance) list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance) list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance) - list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_data_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv1d_fwd_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convinvscale_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convscale_add_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convscale_relu_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convscale_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_clamp_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_clamp_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_scale_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_scaleadd_scaleadd_relu_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_clamp_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_clamp_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_dynamic_op_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_dynamic_op_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bilinear_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance) list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance) list(APPEND DEVICE_INSTANCES device_gemm_multi_abd_instance) diff --git a/profiler/src/profile_grouped_conv_fwd_bilinear.cpp b/profiler/src/profile_grouped_conv_fwd_bilinear.cpp new file mode 100644 index 00000000000..d4490abe7ee --- /dev/null +++ b/profiler/src/profile_grouped_conv_fwd_bilinear.cpp @@ -0,0 +1,186 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "profiler/profile_grouped_conv_fwd_bilinear_impl.hpp" + +#include "ck/utility/data_type.hpp" +#include "ck/utility/ignore.hpp" +#include "profiler_operation_registry.hpp" + +#include + +enum struct ConvLayout +{ + GNHWC_GKYXC_GNHWK, // 0 + NHWGC_GKYXC_NHWGK, // 1 + NGCHW_GKYXC_NGKHW, // 2 + NGCHW_GKCYX_NGKHW, // 3 +}; + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 + F8_F8_F8, // 4 + BF8_BF8_F8, // 5 + F8_BF8_F8, // 6 + BF8_F8_F8, // 7 +}; + +enum struct IndexType +{ + INDEX_T, // 0 + LONG_INDEX_T, // 1 +}; + +#define OP_NAME "grouped_conv_fwd_bilinear" +#define OP_DESC "Grouped Convolution Forward+Bilinear" + +static void print_helper_msg() +{ + std::cout + // clang-format off + << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n" + << " 1: Input fp16, Weight fp16, Output fp16\n" + << " 2: Input bf16, Weight bf16, Output bf16\n" + << " 3: Input int8, Weight int8, Output int8\n" + << " 4: Input fp8, Weight fp8, Output fp8\n" + << " 5: Input bf8, Weight bf8, Output fp8\n" + << " 6: Input fp8, Weight bf8, Output fp8\n" + << " 7: Input bf8, Weight fp8, Output fp8)\n" + << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" + << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]\n" + << " 2: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, " + "G, K, Ho, Wo]\n" + << " 3: Input[N, G, C, Hi, Wi], Weight[G, K, C, Y, X], Output[N, " + "G, K, Ho, Wo])\n" + << "arg4: indexing data type (0: 32-bit, 1: 64-bit)\n" + << "arg5: verification (0: no, 1: yes)\n" + << "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n" + << "arg7: print tensor value (0: no; 1: yes)\n" + << "arg8: time kernel (0: no, 1: yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; + // clang-format on +} + +int grouped_conv_fwd_bilinear(int argc, char* argv[]) +{ + // 8 for control, 1 for num_dim_spatial + if(argc < 10) + { + print_helper_msg(); + return 1; + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const auto index_type = static_cast(std::stoi(argv[4])); + const bool do_verification = std::stoi(argv[5]); + const int init_method = std::stoi(argv[6]); + const bool do_log = std::stoi(argv[7]); + const bool time_kernel = std::stoi(argv[8]); + const int num_dim_spatial = std::stoi(argv[9]); + + // 9 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial + if(argc != 9 + 1 + 4 + 6 * num_dim_spatial) + { + print_helper_msg(); + return 1; + } + + const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 10, argv); + + if(index_type != IndexType::INDEX_T) + { + std::cout << "this indexing data type is not implemented" << std::endl; + return 1; + } + + using F32 = float; + using BF16 = ck::bhalf_t; + using F16 = ck::half_t; + + using GKZYXC = ck::tensor_layout::convolution::GKZYXC; + using NDHWGC = ck::tensor_layout::convolution::NDHWGC; + using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + + using GKZYXC = ck::tensor_layout::convolution::GKZYXC; + using NDHWGC = ck::tensor_layout::convolution::NDHWGC; + using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + + constexpr auto I3 = ck::Number<3>{}; + + auto profile = [&](auto num_dim_spatial_tmp, + auto in_layout, 
+ auto wei_layout, + auto out_layout, + auto in_type, + auto wei_type, + auto out_type, + auto a_compute_type, + auto b_compute_type) { + constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using OutLayout = decltype(out_layout); + + using InDataType = decltype(in_type); + using WeiDataType = decltype(wei_type); + using OutDataType = decltype(out_type); + + using AComputeType = decltype(a_compute_type); + using BComputeType = decltype(b_compute_type); + + // Create a Bilinear operation (requires two input tensors) + const auto bilinear_op = ck::tensor_operation::element_wise::Bilinear{}; + + bool pass = ck::profiler::profile_grouped_conv_fwd_bilinear_impl< + NDimSpatial, + InLayout, + WeiLayout, + OutLayout, // D layout same as output + OutLayout, + InDataType, + WeiDataType, + OutDataType, // D data type same as output + OutDataType, + AComputeType, + BComputeType, + ck::index_t>(do_verification, init_method, do_log, time_kernel, params, bilinear_op); + + return pass ? 0 : 1; + }; + + if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::INT8_INT8_INT8) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}); + } + } + + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, grouped_conv_fwd_bilinear); diff --git a/profiler/src/profile_grouped_conv_fwd_convscale_add.cpp b/profiler/src/profile_grouped_conv_fwd_convscale_add.cpp new file mode 100644 index 00000000000..99a2c67b4f0 --- /dev/null +++ b/profiler/src/profile_grouped_conv_fwd_convscale_add.cpp @@ -0,0 +1,165 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include +#include +#include + +#include "profiler/profile_grouped_conv_fwd_convscale_add_impl.hpp" +#include "profiler_operation_registry.hpp" + +using F8 = ck::f8_t; +using F32 = float; + +namespace { + +enum struct ConvLayout +{ + NDHWGC_GKZYXC_NDHWGK, // 0 + // NDHWGK_GKZYXC_NDHWGK, // 1 + // NHWGC_GKYXC_NHWGK, // 2 + // NHWGK_GKYXC_NHWGK, // 3 + // NWGC_GKXC_NWGK, // 4 + // NWGK_GKXC_NWGK, // 5 +}; + +enum struct ConvDataType +{ + F8_F8_F8, // 0 +}; + +enum struct IndexType +{ + INDEX_T, // 0 + LONG_INDEX_T, // 1 +}; + +#define OP_NAME "grouped_conv_fwd_convscale_add" +#define OP_DESC "Grouped Convolution Forward ConvScaleAdd" + +static void print_helper_msg() +{ + std::cout + // clang-format off + << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: Input f8, Weight f8, Output f8\n" + << "arg3: tensor layout (0: Input[N, Di, Hi, Wi, G, C], Weight[G, K, Z, Y, X, C], Output[N, Do, Ho, Wo, G, K])\n" + << "arg4: index type (0: INDEX_T, 1: LONG_INDEX_T)\n" + << "arg5: verification (0: no, 1: yes)\n" + << "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n" + << "arg7: print tensor value (0: no; 1: yes)\n" + << "arg8: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; + // clang-format on +} + +int profile_grouped_conv_fwd_convscale_add(int argc, char* argv[]) +{ + // 8 for control, 1 for num_dim_spatial + if(argc < 10) + { + print_helper_msg(); + return 1; + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const auto index_type = static_cast(std::stoi(argv[4])); + const bool do_verification = std::stoi(argv[5]); + const int init_method = std::stoi(argv[6]); + const bool do_log = std::stoi(argv[7]); + const bool time_kernel = std::stoi(argv[8]); + const int num_dim_spatial = std::stoi(argv[9]); + + // 9 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial + if(argc != 9 + 1 + 4 + 6 * num_dim_spatial) + { + print_helper_msg(); + return 1; + } + + const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 10, argv); + + if(index_type != IndexType::INDEX_T) + { + std::cout << "this indexing data type is not implemented" << std::endl; + return 1; + } + + using F32 = float; + + using GKZYXC = ck::tensor_layout::convolution::GKZYXC; + using NDHWGC = ck::tensor_layout::convolution::NDHWGC; + using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + + using GKZYXC = ck::tensor_layout::convolution::GKZYXC; + using NDHWGC = ck::tensor_layout::convolution::NDHWGC; + using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + + constexpr auto I3 = ck::Number<3>{}; + + auto profile = [&](auto num_dim_spatial_tmp, + auto in_layout, + auto wei_layout, + auto d_layout, + auto out_layout, + auto in_type, + auto wei_type, + auto d_type, + auto out_type, + auto a_compute_type, + auto b_compute_type) { + constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using DLayout = decltype(d_layout); + using OutLayout = decltype(out_layout); + + using InDataType = decltype(in_type); + using WeiDataType = decltype(wei_type); + using DDataType = decltype(d_type); + using OutDataType = decltype(out_type); + + using AComputeType = decltype(a_compute_type); + using BComputeType = decltype(b_compute_type); + + const auto convscaleadd_op = 
ck::tensor_operation::element_wise::ConvScaleAdd{};
+
+        bool pass = ck::profiler::profile_grouped_conv_fwd_convscale_add_impl<NDimSpatial,
+                                                                              InLayout,
+                                                                              WeiLayout,
+                                                                              DLayout,
+                                                                              OutLayout,
+                                                                              InDataType,
+                                                                              WeiDataType,
+                                                                              DDataType,
+                                                                              OutDataType,
+                                                                              AComputeType,
+                                                                              BComputeType>(
+            do_verification, init_method, do_log, time_kernel, params, convscaleadd_op);
+
+        return pass ? 0 : 1;
+    };
+
+    if(num_dim_spatial == 3 && layout == ConvLayout::NDHWGC_GKZYXC_NDHWGK)
+    {
+        if(data_type == ConvDataType::F8_F8_F8)
+        {
+            return profile(
+                I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, NDHWGK{}, F8{}, F8{}, F32{}, F8{}, F8{}, F8{});
+        }
+    }
+
+    std::cout << "this data_type & layout is not implemented" << std::endl;
+
+    return 1;
+}
+
+} // namespace
+
+REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_grouped_conv_fwd_convscale_add);
diff --git a/profiler/src/profile_grouped_conv_fwd_dynamic_op.cpp b/profiler/src/profile_grouped_conv_fwd_dynamic_op.cpp
new file mode 100644
index 00000000000..583cd5b6d5f
--- /dev/null
+++ b/profiler/src/profile_grouped_conv_fwd_dynamic_op.cpp
@@ -0,0 +1,207 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "profiler/profile_grouped_conv_fwd_impl.hpp"
+
+#include "ck/utility/data_type.hpp"
+#include "ck/utility/ignore.hpp"
+#include "profiler_operation_registry.hpp"
+
+#include <iostream>
+
+enum struct ConvLayout
+{
+    GNHWC_GKYXC_GNHWK, // 0
+    NHWGC_GKYXC_NHWGK, // 1
+    NGCHW_GKYXC_NGKHW, // 2
+    NGCHW_GKCYX_NGKHW, // 3
+};
+
+enum struct ConvDataType
+{
+    F32_F32_F32,    // 0
+    F16_F16_F16,    // 1
+    BF16_BF16_BF16, // 2
+    INT8_INT8_INT8, // 3
+    F8_F8_F8,       // 4
+    BF8_BF8_F8,     // 5
+    F8_BF8_F8,      // 6
+    BF8_F8_F8,      // 7
+};
+
+enum struct IndexType
+{
+    INDEX_T,      // 0
+    LONG_INDEX_T, // 1
+};
+
+#define OP_NAME "grouped_conv_fwd_dynamic_op"
+#define OP_DESC "Grouped Convolution Forward+DynamicUnaryOp"
+
+static void print_helper_msg()
+{
+    std::cout
+        // clang-format off
+        << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
+        << "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n"
+        << "                 1: Input fp16, Weight fp16, Output fp16\n"
+        << "                 2: Input bf16, Weight bf16, Output bf16\n"
+        << "                 3: Input int8, Weight int8, Output int8\n"
+        << "                 4: Input fp8, Weight fp8, Output fp8\n"
+        << "                 5: Input bf8, Weight bf8, Output fp8\n"
+        << "                 6: Input fp8, Weight bf8, Output fp8\n"
+        << "                 7: Input bf8, Weight fp8, Output fp8)\n"
+        << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n"
+        << "                     1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]\n"
+        << "                     2: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, "
+           "G, K, Ho, Wo]\n"
+        << "                     3: Input[N, G, C, Hi, Wi], Weight[G, K, C, Y, X], Output[N, "
+           "G, K, Ho, Wo])\n"
+        << "arg4: indexing data type (0: 32-bit, 1: 64-bit)\n"
+        << "arg5: verification (0: no, 1: yes)\n"
+        << "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n"
+        << "arg7: print tensor value (0: no; 1: yes)\n"
+        << "arg8: time kernel (0: no, 1: yes)\n"
+        << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
+    // clang-format on
+}
+
+int grouped_conv_fwd_dynamic_op(int argc, char* argv[])
+{
+    // 8 for control, 1 for num_dim_spatial
+    if(argc < 10)
+    {
+        print_helper_msg();
+        return 1;
+    }
+
+    const auto data_type       = static_cast<ConvDataType>(std::stoi(argv[2]));
+    const auto layout          = static_cast<ConvLayout>(std::stoi(argv[3]));
+    const auto index_type      = static_cast<IndexType>(std::stoi(argv[4]));
+    const bool do_verification = std::stoi(argv[5]);
+    const int
init_method = std::stoi(argv[6]); + const bool do_log = std::stoi(argv[7]); + const bool time_kernel = std::stoi(argv[8]); + const int num_dim_spatial = std::stoi(argv[9]); + + // 9 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial + if(argc != 9 + 1 + 4 + 6 * num_dim_spatial) + { + print_helper_msg(); + return 1; + } + + const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 10, argv); + + if(index_type != IndexType::INDEX_T) + { + std::cout << "this indexing data type is not implemented" << std::endl; + return 1; + } + + using F32 = float; + using BF16 = ck::bhalf_t; + using F16 = ck::half_t; + + using GKZYXC = ck::tensor_layout::convolution::GKZYXC; + using NDHWGC = ck::tensor_layout::convolution::NDHWGC; + using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + + using GKYXC = ck::tensor_layout::convolution::GKYXC; + using NHWGC = ck::tensor_layout::convolution::NHWGC; + using NHWGK = ck::tensor_layout::convolution::NHWGK; + + constexpr auto I2 = ck::Number<2>{}; + constexpr auto I3 = ck::Number<3>{}; + + auto profile = [&](auto num_dim_spatial_tmp, + auto in_layout, + auto wei_layout, + auto out_layout, + auto in_type, + auto wei_type, + auto out_type, + auto a_compute_type, + auto b_compute_type) { + constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using OutLayout = decltype(out_layout); + + using InDataType = decltype(in_type); + using WeiDataType = decltype(wei_type); + using OutDataType = decltype(out_type); + + using AComputeType = decltype(a_compute_type); + using BComputeType = decltype(b_compute_type); + + const auto dynamic_op = ck::tensor_operation::element_wise::DynamicUnaryOp{ + ck::tensor_operation::element_wise::PassThrough{}}; + + bool pass = ck::profiler::profile_grouped_conv_fwd_impl< + NDimSpatial, + InLayout, + WeiLayout, + OutLayout, + InDataType, + WeiDataType, + OutDataType, + AComputeType, + BComputeType, + ck::index_t, + ck::tensor_operation::element_wise::DynamicUnaryOp>( + do_verification, init_method, do_log, time_kernel, params, dynamic_op); + + return pass ? 
0 : 1; + }; + + if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::INT8_INT8_INT8) + { + return profile( + I2, NHWGC{}, GKYXC{}, NHWGK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}); + } + } + else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::INT8_INT8_INT8) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}); + } + } + + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, grouped_conv_fwd_dynamic_op); diff --git a/profiler/src/profile_grouped_conv_fwd_outelementop.cpp b/profiler/src/profile_grouped_conv_fwd_outelementop.cpp index 00b4bd8f133..f0985088ed1 100644 --- a/profiler/src/profile_grouped_conv_fwd_outelementop.cpp +++ b/profiler/src/profile_grouped_conv_fwd_outelementop.cpp @@ -17,16 +17,24 @@ enum struct ConvLayout enum struct OutElementOp { - ConvScale = 0, - ConvInvScale = 1 + ConvScale = 0, + ConvInvScale = 1, + CombConvScale = 2, + ConvScaleRelu = 3, + Scale = 4, + CombConvScaleRelu = 5 }; enum struct ConvDataType { - F8_F8_F8 = 0, - BF8_BF8_F8 = 1, - F8_BF8_F8 = 2, - BF8_F8_F8 = 3 + F8_F8_F8 = 0, + BF8_BF8_F8 = 1, + F8_BF8_F8 = 2, + BF8_F8_F8 = 3, + F16_F16_F16 = 4, + BF16_BF16_BF16 = 5, + I8_I8_I8 = 6, + F8_F8_F32 = 7 }; #define OP_NAME "grouped_conv_fwd_outelementop" @@ -41,8 +49,16 @@ static void print_helper_msg() << " 1: Input bf8, Weight bf8, Output fp8\n" << " 2: Input fp8, Weight bf8, Output fp8\n" << " 3: Input bf8, Weight fp8, Output fp8)\n" + << " 4: Input f16, Weight f16, Output f16)\n" + << " 5: Input bf16, Weight bf16, Output bf16)\n" + << " 6: Input i8, Weight i8, Output i8)\n" + << " 7: Input f8, Weight f8, Output f32)\n" << "arg3: element-wise operation (0: ConvScale\n" - << " 1: ConvInvScale)\n" + << " 1: ConvInvScale\n" + << " 2: CombConvScale\n" + << " 3: ConvScaleRelu\n" + << " 4: Scale\n" + << " 5: CombConvScaleRelu)\n" << "arg4: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n" << "arg5: verification (0: no, 1: yes)\n" @@ -81,15 +97,23 @@ int grouped_conv_fwd_outelementop(int argc, char* argv[]) const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 10, argv); - using F8 = ck::f8_t; - using BF8 = ck::bf8_t; + using F8 = ck::f8_t; + using F16 = ck::half_t; + using F32 = float; + using BF8 = ck::bf8_t; + using BF16 = ck::bhalf_t; + using I8 = int8_t; using GKZYXC = 
ck::tensor_layout::convolution::GKZYXC; using NDHWGC = ck::tensor_layout::convolution::NDHWGC; using NDHWGK = ck::tensor_layout::convolution::NDHWGK; - using ConvScale = ck::tensor_operation::element_wise::ConvScale; - using ConvInvScale = ck::tensor_operation::element_wise::ConvInvscale; + using ConvScale = ck::tensor_operation::element_wise::ConvScale; + using ConvInvScale = ck::tensor_operation::element_wise::ConvInvscale; + using CombConvScale = ck::tensor_operation::element_wise::ScaleScalePass; + using ConvScaleRelu = ck::tensor_operation::element_wise::ConvScaleRelu; + using Scale = ck::tensor_operation::element_wise::Scale; + using CombConvScaleRelu = ck::tensor_operation::element_wise::ScaleScaleRelu; constexpr auto I3 = ck::Number<3>{}; @@ -173,6 +197,22 @@ int grouped_conv_fwd_outelementop(int argc, char* argv[]) return profile( I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, F8{}, F8{}, ConvInvScale{}, F8{}, F8{}); } + } + else if(op == OutElementOp::CombConvScale) + { + if(data_type == ConvDataType::F8_F8_F8) + { + return profile(I3, + NDHWGC{}, + GKZYXC{}, + NDHWGK{}, + F8{}, + F8{}, + F8{}, + CombConvScale{}, + F8{}, + F8{}); + } else if(data_type == ConvDataType::BF8_BF8_F8) { return profile(I3, @@ -182,7 +222,7 @@ int grouped_conv_fwd_outelementop(int argc, char* argv[]) BF8{}, BF8{}, F8{}, - ConvInvScale{}, + CombConvScale{}, BF8{}, BF8{}); } @@ -195,7 +235,7 @@ int grouped_conv_fwd_outelementop(int argc, char* argv[]) F8{}, BF8{}, F8{}, - ConvInvScale{}, + CombConvScale{}, F8{}, BF8{}); } @@ -208,11 +248,69 @@ int grouped_conv_fwd_outelementop(int argc, char* argv[]) BF8{}, F8{}, F8{}, - ConvInvScale{}, + CombConvScale{}, BF8{}, F8{}); } } + else if(op == OutElementOp::ConvScaleRelu) + { + if(data_type == ConvDataType::F8_F8_F8) + { + return profile(I3, + NDHWGC{}, + GKZYXC{}, + NDHWGK{}, + F8{}, + F8{}, + F8{}, + ConvScaleRelu{}, + F8{}, + F8{}); + } + } + else if(op == OutElementOp::Scale) + { + if(data_type == ConvDataType::F16_F16_F16) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, Scale{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I3, + NDHWGC{}, + GKZYXC{}, + NDHWGK{}, + BF16{}, + BF16{}, + BF16{}, + Scale{}, + BF16{}, + BF16{}); + } + else if(data_type == ConvDataType::I8_I8_I8) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, I8{}, I8{}, I8{}, Scale{}, I8{}, I8{}); + } + } + else if(op == OutElementOp::CombConvScaleRelu) + { + if(data_type == ConvDataType::F8_F8_F32) + { + return profile(I3, + NDHWGC{}, + GKZYXC{}, + NDHWGK{}, + F8{}, + F8{}, + F32{}, + CombConvScaleRelu{}, + F8{}, + F8{}); + } + } } std::cout << "this data_type & layout is not implemented" << std::endl; diff --git a/profiler/src/profile_grouped_conv_fwd_scaleadd_scaleadd_relu.cpp b/profiler/src/profile_grouped_conv_fwd_scaleadd_scaleadd_relu.cpp new file mode 100644 index 00000000000..26c871ccf24 --- /dev/null +++ b/profiler/src/profile_grouped_conv_fwd_scaleadd_scaleadd_relu.cpp @@ -0,0 +1,183 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "profiler/profile_grouped_conv_fwd_scaleadd_scaleadd_relu_impl.hpp" + +#include "ck/utility/data_type.hpp" +#include "profiler_operation_registry.hpp" + +#include + +enum struct ConvLayout +{ + GNHWC_GKYXC_GNHWK = 0, + NHWGC_GKYXC_NHWGK = 1 +}; + +enum struct OutElementOp +{ + ScaleAddScaleAddRelu = 0 +}; + +enum struct ConvDataType +{ + I8_I8_I8 = 1, + F16_F16_F16 = 2, + BF16_BF16_BF16 = 3 +}; + +#define OP_NAME "grouped_conv_fwd_scaleadd_scaleadd_relu_wmma" +#define OP_DESC "Grouped Convolution Forward+ScaleAddScaleAddRelu Operation (WMMA)" + +static void print_helper_msg() +{ + // clang-format off + std::cout + << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (1: Input i8, Weight i8, Output i8\n" + << " 2: Input f16, Weight f16, Output f16\n" + << " 3: Input bf16, Weight bf16, Output bf16)\n" + << "arg3: element-wise operation (0: ScaleAddScaleAddRelu)\n" + << "arg4: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" + << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n" + << "arg5: verification (0: no, 1: yes)\n" + << "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n" + << "arg7: print tensor value (0: no; 1: yes)\n" + << "arg8: time kernel (0: no, 1: yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; + // clang-format on +} + +int grouped_conv_fwd_scaleadd_scaleadd_relu_wmma(int argc, char* argv[]) +{ + // 9 total, 1 for num_dim_spatial + if(argc < 10) + { + print_helper_msg(); + return 1; + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto op = static_cast(std::stoi(argv[3])); + const auto layout = static_cast(std::stoi(argv[4])); + const bool do_verification = std::stoi(argv[5]); + const int init_method = std::stoi(argv[6]); + const bool do_log = std::stoi(argv[7]); + const bool time_kernel = std::stoi(argv[8]); + const int num_dim_spatial = std::stoi(argv[9]); + + // 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial + 1 for argv[0] + if(argc != 8 + 1 + 4 + 6 * num_dim_spatial + 1) + { + print_helper_msg(); + return 1; + } + + const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 10, argv); + + using I8 = int8_t; + using F16 = ck::half_t; + using BF16 = ck::bhalf_t; + + using GKZYXC = ck::tensor_layout::convolution::GKZYXC; + using NDHWGC = ck::tensor_layout::convolution::NDHWGC; + using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + + using ScaleAddScaleAddRelu = ck::tensor_operation::element_wise::ScaleAddScaleAddRelu; + + constexpr auto I3 = ck::Number<3>{}; + + auto profile = [&](auto num_dim_spatial_tmp, + auto in_layout, + auto wei_layout, + auto out_layout, + auto in_type, + auto wei_type, + auto out_type, + auto out_element_op, + auto a_compute_type, + auto b_compute_type) { + constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using OutLayout = decltype(out_layout); + + using InDataType = decltype(in_type); + using WeiDataType = decltype(wei_type); + using OutDataType = decltype(out_type); + + using OutElementOp = decltype(out_element_op); + + using AComputeType = decltype(a_compute_type); + using BComputeType = decltype(b_compute_type); + + bool pass = + ck::profiler::profile_grouped_conv_fwd_scaleadd_scaleadd_relu_impl( + do_verification, init_method, 
do_log, time_kernel, params); + + return pass ? 0 : 1; + }; + + if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(op == OutElementOp::ScaleAddScaleAddRelu) + { + if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I3, + NDHWGC{}, + GKZYXC{}, + NDHWGK{}, + F16{}, + F16{}, + F16{}, + ScaleAddScaleAddRelu{}, + F16{}, + F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I3, + NDHWGC{}, + GKZYXC{}, + NDHWGK{}, + BF16{}, + BF16{}, + BF16{}, + ScaleAddScaleAddRelu{}, + BF16{}, + BF16{}); + } + else if(data_type == ConvDataType::I8_I8_I8) + { + return profile(I3, + NDHWGC{}, + GKZYXC{}, + NDHWGK{}, + I8{}, + I8{}, + I8{}, + ScaleAddScaleAddRelu{}, + I8{}, + I8{}); + } + } + } + + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, grouped_conv_fwd_scaleadd_scaleadd_relu_wmma); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 802f29024c6..a802b87b05b 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -33,6 +33,7 @@ set(REGRESSION_TESTS test_convnd_fwd test_convnd_bwd_data test_grouped_convnd_fwd + test_grouped_convnd_fwd_scaleadd_ab test_grouped_convnd_bwd_weight test_softmax_rank3 test_softmax_rank4 diff --git a/test/grouped_convnd_fwd/CMakeLists.txt b/test/grouped_convnd_fwd/CMakeLists.txt index ab52d12bb09..1e736f04790 100644 --- a/test/grouped_convnd_fwd/CMakeLists.txt +++ b/test/grouped_convnd_fwd/CMakeLists.txt @@ -4,6 +4,15 @@ if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_grouped_convnd_fwd test_grouped_convnd_fwd.cpp) target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance) + + add_gtest_executable(test_grouped_convnd_fwd_dynamic_op test_grouped_convnd_fwd_dynamic_op.cpp) + target_link_libraries(test_grouped_convnd_fwd_dynamic_op PRIVATE utility device_grouped_conv2d_fwd_dynamic_op_instance device_grouped_conv3d_fwd_dynamic_op_instance) + + add_gtest_executable(test_grouped_convnd_fwd_bilinear test_grouped_convnd_fwd_bilinear.cpp) + target_link_libraries(test_grouped_convnd_fwd_bilinear PRIVATE utility device_grouped_conv3d_fwd_bilinear_instance) + + add_gtest_executable(test_grouped_convnd_fwd_scaleadd_ab test_grouped_convnd_fwd_scaleadd_ab.cpp) + target_link_libraries(test_grouped_convnd_fwd_scaleadd_ab PRIVATE utility device_grouped_conv3d_fwd_scaleadd_ab_instance) add_executable(test_grouped_convnd_fwd_large_cases_xdl test_grouped_convnd_fwd_large_cases_xdl.cpp) target_compile_options(test_grouped_convnd_fwd_large_cases_xdl PRIVATE -Wno-global-constructors -Wno-undef) diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp index b5c5248df5d..82a66cda9ca 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp @@ -7,20 +7,32 @@ #include #include +#include "ck/utility/common_header.hpp" +#include "ck/host_utility/device_prop.hpp" #include "profiler/profile_grouped_conv_fwd_impl.hpp" static ck::index_t param_mask = 0xffff; static ck::index_t instance_index = -1; +using I8 = int8_t; +using F8 = ck::f8_t; +using BF8 = ck::bf8_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; template class TestGroupedConvndFwd : public ::testing::Test { protected: - using DataType = std::tuple_element_t<0, Tuple>; - using InLayout = 
std::tuple_element_t<1, Tuple>; - using WeiLayout = std::tuple_element_t<2, Tuple>; - using OutLayout = std::tuple_element_t<3, Tuple>; - using IndexType = ck::index_t; + using InDataType = std::tuple_element_t<0, Tuple>; + using WeiDataType = std::tuple_element_t<1, Tuple>; + using OutDataType = std::tuple_element_t<2, Tuple>; + using AComputeType = std::tuple_element_t<3, Tuple>; + using BComputeType = std::tuple_element_t<4, Tuple>; + using InLayout = std::tuple_element_t<5, Tuple>; + using WeiLayout = std::tuple_element_t<6, Tuple>; + using OutLayout = std::tuple_element_t<7, Tuple>; + using IndexType = ck::index_t; std::vector conv_params; @@ -35,16 +47,25 @@ class TestGroupedConvndFwd : public ::testing::Test { continue; } + // FP8 workaround for CDNA1/2 is currently broken, do not test. + if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a") + { + if(std::is_same::value || std::is_same::value) + { + printf("Skipping FP8 / BF8 tests on CDNA1/2.\n"); + continue; + } + } auto& param = conv_params[i]; pass = pass && ck::profiler::profile_grouped_conv_fwd_impl( true, // do_verification 1, // init_method: integer value @@ -60,36 +81,43 @@ class TestGroupedConvndFwd : public ::testing::Test using namespace ck::tensor_layout::convolution; -using KernelTypes1d = ::testing::Types, - std::tuple, - std::tuple, - std::tuple>; - -using KernelTypes2d = ::testing::Types, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple>; - -using KernelTypes3d = ::testing::Types, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple>; +using KernelTypes1d = ::testing::Types, + std::tuple, + std::tuple, + std::tuple>; + +using KernelTypes2d = + ::testing::Types, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple>; + +using KernelTypes3d = + ::testing::Types, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple>; template class TestGroupedConvndFwd1d : public TestGroupedConvndFwd @@ -125,15 +153,32 @@ TYPED_TEST(TestGroupedConvndFwd1d, Test1D) TYPED_TEST(TestGroupedConvndFwd2d, Test2D) { this->conv_params.clear(); + this->conv_params.push_back( + {2, 3, 5, 96, 200, {1, 1}, {73, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 1, 1, 32, 32, {1, 1}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 1, 1, 32, 32, {2, 2}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 1, 1, 32, 32, {3, 3}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 1, 1, 32, 32, {5, 5}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 1, 1, 32, 32, {9, 9}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( {2, 2, 32, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back( {2, 2, 32, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back( {2, 2, 32, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); 
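+    // Reading aid for the positional initializers in this test (order assumed
+    // from ck::utils::conv::ConvParam and the parser helper message, not
+    // something the test itself depends on):
+    //   {ndim, G, N, K, C, {filter lengths}, {input lengths},
+    //    {strides}, {dilations}, {left pads}, {right pads}}
+    // e.g. the preceding case is a 2-D conv with G=1, N=1, K=1, C=32, a 3x3
+    // filter over a 32x32 input, unit stride/dilation, and 1-pixel padding.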
this->conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + + this->conv_params.push_back( + {2, 96, 1, 1, 1, {1, 1}, {120, 160}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->conv_params.push_back( {2, 96, 1, 1, 1, {3, 3}, {120, 160}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->template Run<2>(); @@ -142,20 +187,41 @@ TYPED_TEST(TestGroupedConvndFwd2d, Test2D) TYPED_TEST(TestGroupedConvndFwd3d, Test3D) { this->conv_params.clear(); + + this->conv_params.push_back( + {3, 3, 5, 96, 200, {1, 1, 1}, {37, 37, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {2, 2, 2}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {5, 5, 5}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {9, 9, 9}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( {3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); this->conv_params.push_back( {3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {16, 16, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->conv_params.push_back( {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->conv_params.push_back( {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->conv_params.push_back( + {3, 96, 1, 1, 1, {1, 1, 1}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->conv_params.push_back( - {3, 96, 1, 1, 1, {3, 3, 3}, {4, 30, 160}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + {3, 96, 1, 1, 1, {3, 3, 3}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->template Run<3>(); } diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_bilinear.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_bilinear.cpp new file mode 100644 index 00000000000..1b37f5eb4ee --- /dev/null +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_bilinear.cpp @@ -0,0 +1,134 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "profiler/profile_grouped_conv_fwd_bilinear_impl.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" + +using I8 = int8_t; +using F8 = ck::f8_t; +using BF8 = ck::bf8_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +template +class TestGroupedConvndFwdBilinear : public ::testing::Test +{ + protected: + using InDataType = std::tuple_element_t<0, Tuple>; + using WeiDataType = std::tuple_element_t<1, Tuple>; + using OutDataType = std::tuple_element_t<2, Tuple>; + using AComputeType = std::tuple_element_t<3, Tuple>; + using BComputeType = std::tuple_element_t<4, Tuple>; + using InLayout = std::tuple_element_t<5, Tuple>; + using WeiLayout = std::tuple_element_t<6, Tuple>; + using OutLayout = std::tuple_element_t<7, Tuple>; + using IndexType = ck::index_t; + + std::vector conv_params; + + template + void Run() + { + EXPECT_FALSE(conv_params.empty()); + bool pass = true; + + // Create a Bilinear operation (binary element-wise operation) + const auto bilinear_op = ck::tensor_operation::element_wise::Bilinear{}; + + for(auto& param : conv_params) + { + if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a") + { + if(std::is_same::value || std::is_same::value) + { + printf("Skipping FP8 / BF8 tests on CDNA1/2.\n"); + continue; + } + } + pass = pass && ck::profiler::profile_grouped_conv_fwd_bilinear_impl< + NDimSpatial, + InLayout, + WeiLayout, + OutLayout, // D layout same as output + OutLayout, + InDataType, + WeiDataType, + OutDataType, // D data type same as output + OutDataType, + AComputeType, + BComputeType, + IndexType>(true, // do_verification + 1, // init_method: integer value + false, // do_log + true, // time_kernel + param, + bilinear_op); + } + EXPECT_TRUE(pass); + } +}; + +using namespace ck::tensor_layout::convolution; + +using KernelTypes3d = + ::testing::Types, + std::tuple, + std::tuple, + std::tuple>; + +template +class TestGroupedConvndFwdBilinear3d : public TestGroupedConvndFwdBilinear +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndFwdBilinear3d, KernelTypes3d); + +TYPED_TEST(TestGroupedConvndFwdBilinear3d, Test3D) +{ + this->conv_params.clear(); + + this->conv_params.push_back( + {3, 3, 5, 96, 200, {1, 1, 1}, {37, 37, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {2, 2, 2}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {5, 5, 5}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {9, 9, 9}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {16, 16, 16}, {1, 1, 1}, {1, 1, 
1}, {0, 0, 0}, {0, 0, 0}}); + + this->conv_params.push_back( + {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->conv_params.push_back( + {3, 96, 1, 1, 1, {1, 1, 1}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 96, 1, 1, 1, {3, 3, 3}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->template Run<3>(); +} diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_dynamic_op.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_dynamic_op.cpp new file mode 100644 index 00000000000..43485f8171e --- /dev/null +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_dynamic_op.cpp @@ -0,0 +1,180 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "profiler/profile_grouped_conv_fwd_impl.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +using I8 = int8_t; +using F8 = ck::f8_t; +using BF8 = ck::bf8_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +template +class TestGroupedConvndFwdDynamicOp : public ::testing::Test +{ + protected: + using InDataType = std::tuple_element_t<0, Tuple>; + using WeiDataType = std::tuple_element_t<1, Tuple>; + using OutDataType = std::tuple_element_t<2, Tuple>; + using AComputeType = std::tuple_element_t<3, Tuple>; + using BComputeType = std::tuple_element_t<4, Tuple>; + using InLayout = std::tuple_element_t<5, Tuple>; + using WeiLayout = std::tuple_element_t<6, Tuple>; + using OutLayout = std::tuple_element_t<7, Tuple>; + using IndexType = ck::index_t; + + std::vector conv_params; + + template + void Run() + { + EXPECT_FALSE(conv_params.empty()); + bool pass = true; + + const auto dynamic_op = ck::tensor_operation::element_wise::DynamicUnaryOp{ + ck::tensor_operation::element_wise::PassThrough{}}; + + for(auto& param : conv_params) + { + if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a") + { + if(std::is_same::value || std::is_same::value) + { + printf("Skipping FP8 / BF8 tests on CDNA1/2.\n"); + continue; + } + } + pass = pass && ck::profiler::profile_grouped_conv_fwd_impl< + NDimSpatial, + InLayout, + WeiLayout, + OutLayout, + InDataType, + WeiDataType, + OutDataType, + AComputeType, + BComputeType, + IndexType, + ck::tensor_operation::element_wise::DynamicUnaryOp>( + true, // do_verification + 1, // init_method: integer value + false, // do_log + true, // time_kernel + param, + dynamic_op); + } + EXPECT_TRUE(pass); + } +}; + +using namespace ck::tensor_layout::convolution; + +using KernelTypes2d = + ::testing::Types, + std::tuple, + std::tuple, + std::tuple>; + +using KernelTypes3d = + ::testing::Types, + std::tuple, + std::tuple, + std::tuple>; + +template +class TestGroupedConvndFwdDynamicOp2d : public TestGroupedConvndFwdDynamicOp +{ +}; + +template +class TestGroupedConvndFwdDynamicOp3d : public TestGroupedConvndFwdDynamicOp +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndFwdDynamicOp2d, KernelTypes2d); +TYPED_TEST_SUITE(TestGroupedConvndFwdDynamicOp3d, KernelTypes3d); + +TYPED_TEST(TestGroupedConvndFwdDynamicOp2d, Test2D) +{ + 
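+    // DynamicUnaryOp type-erases the elementwise post-op, so one set of
+    // instances can be profiled against activations chosen at run time.
+    // Hypothetical sketch of swapping the op (assuming Relu is among the
+    // unary ops DynamicUnaryOp's constructor accepts):
+    //
+    //   const auto relu_op = ck::tensor_operation::element_wise::DynamicUnaryOp{
+    //       ck::tensor_operation::element_wise::Relu{}};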
this->conv_params.clear(); + this->conv_params.push_back( + {2, 3, 5, 96, 200, {1, 1}, {73, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 1, 1, 32, 32, {1, 1}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 1, 1, 32, 32, {2, 2}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 1, 1, 32, 32, {3, 3}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 1, 1, 32, 32, {5, 5}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 1, 1, 32, 32, {9, 9}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + + this->conv_params.push_back( + {2, 2, 32, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 2, 32, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + + this->conv_params.push_back( + {2, 2, 32, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + + this->conv_params.push_back( + {2, 96, 1, 1, 1, {1, 1}, {120, 160}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back( + {2, 96, 1, 1, 1, {3, 3}, {120, 160}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->template Run<2>(); +} + +TYPED_TEST(TestGroupedConvndFwdDynamicOp3d, Test3D) +{ + this->conv_params.clear(); + + this->conv_params.push_back( + {3, 3, 5, 96, 200, {1, 1, 1}, {37, 37, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {2, 2, 2}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {5, 5, 5}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {9, 9, 9}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {16, 16, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + + this->conv_params.push_back( + {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->conv_params.push_back( + {3, 96, 1, 1, 1, {1, 1, 1}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 96, 1, 1, 1, {3, 3, 3}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->template Run<3>(); +} diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_scaleadd_ab.cpp 
b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_scaleadd_ab.cpp new file mode 100644 index 00000000000..ab7a28a388f --- /dev/null +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_scaleadd_ab.cpp @@ -0,0 +1,411 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +using I8 = int8_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +// This is pretty much a fully functional profiler function, but I only implemented it here to add a +// proper gtest test for the scaleadd_ab flavor. At some point we may want to move this and add it +// to the ckProfiler. +template +bool profile_grouped_conv_fwd_scaleadd_ab_impl(int do_verification, + int init_method, + bool do_log, + [[maybe_unused]] bool time_kernel, + const ck::utils::conv::ConvParam& conv_param) +{ + constexpr ck::index_t NumAs = 2; + constexpr ck::index_t NumBs = 2; + using InElementOp = ck::tensor_operation::element_wise::ScaleAdd; + using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd; + using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + + constexpr float scale = 1.5f; + + const auto in_element_op = InElementOp{scale}; + const auto wei_element_op = WeiElementOp{scale}; + const auto out_element_op = OutElementOp{}; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, 
input_right_pads); + + ck::Tensor input(in_g_n_c_wis_desc); + ck::Tensor input_bias(in_g_n_c_wis_desc); + ck::Tensor weight(wei_g_k_c_xs_desc); + ck::Tensor weight_bias(wei_g_k_c_xs_desc); + ck::Tensor host_output(out_g_n_k_wos_desc); + ck::Tensor device_output(out_g_n_k_wos_desc); + + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weight: " << weight.mDesc << std::endl; + std::cout << "output: " << host_output.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + input.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + input_bias.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + weight.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + weight_bias.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + input_bias.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + weight.GenerateTensorValue(GeneratorTensor_3{-0.05, 0.05}); + weight_bias.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + } + + ck::DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpaceSize()); + ck::DeviceMem in_bias_device_buf(sizeof(InDataType) * input_bias.mDesc.GetElementSpaceSize()); + ck::DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpaceSize()); + ck::DeviceMem wei_bias_device_buf(sizeof(WeiDataType) * + weight_bias.mDesc.GetElementSpaceSize()); + ck::DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(input.mData.data()); + in_bias_device_buf.ToDevice(input_bias.mData.data()); + wei_device_buf.ToDevice(weight.mData.data()); + wei_bias_device_buf.ToDevice(weight_bias.mData.data()); + + // Run reference op + if(do_verification) + { + const std::array, NumAs - 1> elementwise_a_tensors = {input_bias}; + const std::array, NumBs - 1> elementwise_b_tensors = {weight_bias}; + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(input, + weight, + host_output, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op, + elementwise_a_tensors, + elementwise_b_tensors); + + // init host output to zero + host_output.SetZero(); + + ref_invoker.Run(ref_argument); + } + + std::string best_op_name; + float best_avg_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + int valids = 0; + + // profile device op instances + bool pass = true; + + auto run_impl = [&](auto& op_ptr, auto& argument_ptr) { + // workspace_sz will be equal to 0 for other layout than NGCHW + // TODO: Is workspace even necessary? 
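+        // Standard ckProfiler sequence: query the per-instance workspace size,
+        // allocate that much device memory (possibly zero bytes), and attach it
+        // to the argument before checking IsSupportedArgument and launching.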
+ const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + ck::DeviceMem workspace_dev(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init output to zero before profiling next kernel + out_device_buf.SetZero(); + + valids++; + + std::string op_name = op_ptr->GetTypeString(); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + float avg_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops() + + 2 * conv_param.GetOutputByte() / sizeof(InDataType) + + 2 * conv_param.GetOutputByte() / sizeof(WeiDataType); + std::size_t num_btype = conv_param.GetByte() + + conv_param.GetInputByte() + + conv_param.GetWeightByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + out_device_buf.FromDevice(device_output.mData.data()); + + pass = pass & ck::utils::check_err(device_output, host_output); + + if(do_log) + { + printf("log\n"); + // LogRangeAsType(std::cout << "input : ", input.mData, ",") << + // std::endl; LogRangeAsType(std::cout << "input_bias: ", + // input_bias.mData, ",") + // << std::endl; + // LogRangeAsType(std::cout << "weight: ", weight.mData, ",") << + // std::endl; LogRangeAsType(std::cout << "weight_bias: ", + // weight_bias.mData, ",") + // << std::endl; + // LogRangeAsType(std::cout << "host_output : ", host_output.mData, + // ",") + // << std::endl; + // LogRangeAsType(std::cout << "device_output: ", + // device_output.mData, ",") + // << std::endl; + } + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + }; + + // InDataType and WeiDataType must be tuple, inLayout and weiLayout are single. 
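+    // With NumAs == NumBs == 2, each GEMM side is fed by two tensors (data +
+    // bias) that share a single layout, hence tuple data types but scalar
+    // layouts. ScaleAdd is applied per element ahead of the convolution,
+    // presumably as y = scale * x0 + x1 with x1 being the bias tensor.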
+ using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "ckProfiler found " << op_ptrs.size() << " instances" << std::endl; + + std::array as{in_device_buf.GetDeviceBuffer(), + in_bias_device_buf.GetDeviceBuffer()}; + std::array bs{wei_device_buf.GetDeviceBuffer(), + wei_bias_device_buf.GetDeviceBuffer()}; + std::array ds{}; + + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer(as, + bs, + ds, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + {}, + {}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + run_impl(op_ptr, argument_ptr); + } + + printf("\033[36mvalids: %d\n\033[0m", valids); + + std::cout << "Best configuration parameters:" << "\nname: " << best_op_name + << "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops + << "\nGB/s: " << best_gb_per_sec << std::endl; + + return pass; +} + +template +class TestGroupedConvndFwdScaleaddAB : public ::testing::Test +{ + protected: + using InDataType = std::tuple_element_t<0, Tuple>; + using WeiDataType = std::tuple_element_t<1, Tuple>; + using OutDataType = std::tuple_element_t<2, Tuple>; + using InLayout = std::tuple_element_t<3, Tuple>; + using WeiLayout = std::tuple_element_t<4, Tuple>; + using OutLayout = std::tuple_element_t<5, Tuple>; + + std::vector conv_params; + + template + void Run() + { + EXPECT_FALSE(conv_params.empty()); + bool pass = true; + for(auto& param : conv_params) + { + pass = pass && profile_grouped_conv_fwd_scaleadd_ab_impl( + true, // do_verification + 1, // init_method: integer value + false, // do_log + false, // time_kernel + param); + } + EXPECT_TRUE(pass); + } +}; + +using namespace ck::tensor_layout::convolution; + +// TODO: Not all possible layouts exist in the instance factory, (GNDHWC, GKZYXC, GNDHWK) only +// exists in example 62. +using KernelTypes3d = ::testing::Types, + std::tuple, + std::tuple, + std::tuple>; + +template +class TestGroupedConvndFwdScaleaddAB3d : public TestGroupedConvndFwdScaleaddAB +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndFwdScaleaddAB3d, KernelTypes3d); + +TYPED_TEST(TestGroupedConvndFwdScaleaddAB3d, Test3D) +{ + this->conv_params.clear(); + + // Client example 24. This one takes quite long. + this->conv_params.push_back( + {3, 32, 64, 32, 64, {3, 3, 3}, {14, 14, 14}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + + // Generic problems, same set as for vanilla, clamp, and (gk) bias clamp tests. 
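+    // The set spans 1x1x1 through 9x9x9 filters, a strided 1x1x1 case, padded
+    // 3x3x3 cases, and degenerate shapes (N=1, K=1, C=1, G=96) that tend to
+    // stress tail and edge handling in the instances.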
+ this->conv_params.push_back( + {3, 3, 5, 96, 200, {1, 1, 1}, {37, 37, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {2, 2, 2}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {5, 5, 5}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {9, 9, 9}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {16, 16, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + + this->conv_params.push_back( + {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->conv_params.push_back( + {3, 96, 1, 1, 1, {1, 1, 1}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 96, 1, 1, 1, {3, 3, 3}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->template Run<3>(); +} diff --git a/test/grouped_convnd_fwd_activation/CMakeLists.txt b/test/grouped_convnd_fwd_activation/CMakeLists.txt index 18b91143001..9b4c1922abb 100644 --- a/test/grouped_convnd_fwd_activation/CMakeLists.txt +++ b/test/grouped_convnd_fwd_activation/CMakeLists.txt @@ -19,7 +19,32 @@ if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_grouped_convnd_fwd_clamp test_grouped_convnd_fwd_clamp.cpp) target_link_libraries(test_grouped_convnd_fwd_clamp PRIVATE utility device_grouped_conv2d_fwd_clamp_instance device_grouped_conv3d_fwd_clamp_instance) + add_executable(test_grouped_convnd_fwd_bias_clamp_large_cases test_grouped_convnd_fwd_bias_clamp_large_cases.cpp) target_compile_options(test_grouped_convnd_fwd_bias_clamp_large_cases PRIVATE -Wno-global-constructors -Wno-undef) target_link_libraries(test_grouped_convnd_fwd_bias_clamp_large_cases PRIVATE gtest_main getopt::getopt utility device_grouped_conv2d_fwd_bias_clamp_instance device_grouped_conv3d_fwd_bias_clamp_instance) + + add_gtest_executable(test_grouped_convnd_fwd_scale test_grouped_convnd_fwd_scale.cpp) + target_link_libraries(test_grouped_convnd_fwd_scale PRIVATE utility device_grouped_conv3d_fwd_scale_instance) + + add_gtest_executable(test_grouped_convnd_fwd_scaleadd_scaleadd_relu test_grouped_convnd_fwd_scaleadd_scaleadd_relu.cpp) + target_link_libraries(test_grouped_convnd_fwd_scaleadd_scaleadd_relu PRIVATE utility device_grouped_conv3d_fwd_scaleadd_scaleadd_relu_instance) + + add_gtest_executable(test_grouped_convnd_fwd_convinvscale test_grouped_convnd_fwd_convinvscale.cpp) + target_link_libraries(test_grouped_convnd_fwd_convinvscale PRIVATE utility 
device_grouped_conv3d_fwd_convinvscale_instance) + + add_gtest_executable(test_grouped_convnd_fwd_convscaleadd test_grouped_convnd_fwd_convscaleadd.cpp) + target_link_libraries(test_grouped_convnd_fwd_convscaleadd PRIVATE utility device_grouped_conv3d_fwd_convscale_add_instance) + + add_gtest_executable(test_grouped_convnd_fwd_convscalerelu test_grouped_convnd_fwd_convscalerelu.cpp) + target_link_libraries(test_grouped_convnd_fwd_convscalerelu PRIVATE utility device_grouped_conv3d_fwd_convscale_relu_instance) + + add_gtest_executable(test_grouped_convnd_fwd_convscale test_grouped_convnd_fwd_convscale.cpp) + target_link_libraries(test_grouped_convnd_fwd_convscale PRIVATE utility device_grouped_conv3d_fwd_convscale_instance) + + add_gtest_executable(test_grouped_convnd_fwd_combconvscale test_grouped_convnd_fwd_combconvscale.cpp) + target_link_libraries(test_grouped_convnd_fwd_combconvscale PRIVATE utility device_grouped_conv3d_fwd_convscale_instance) + + add_gtest_executable(test_grouped_convnd_fwd_combconvscalerelu test_grouped_convnd_fwd_combconvscalerelu.cpp) + target_link_libraries(test_grouped_convnd_fwd_combconvscalerelu PRIVATE utility device_grouped_conv3d_fwd_convscale_relu_instance) endif() diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp.cpp index 955ff724130..d1706d4cec6 100644 --- a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp.cpp +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp.cpp @@ -86,20 +86,76 @@ TYPED_TEST_SUITE(TestGroupedConvndFwd3d, KernelTypes3d); TYPED_TEST(TestGroupedConvndFwd2d, Test2D) { this->conv_params.clear(); + this->conv_params.push_back( + {2, 3, 5, 96, 200, {1, 1}, {73, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 1, 1, 32, 32, {1, 1}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 1, 1, 32, 32, {2, 2}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 1, 1, 32, 32, {3, 3}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 1, 1, 32, 32, {5, 5}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 1, 1, 32, 32, {9, 9}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( {2, 2, 32, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back( {2, 2, 32, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + + this->conv_params.push_back( + {2, 2, 32, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + + this->conv_params.push_back( + {2, 96, 1, 1, 1, {1, 1}, {120, 160}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back( + {2, 96, 1, 1, 1, {3, 3}, {120, 160}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->template Run<2>(); } TYPED_TEST(TestGroupedConvndFwd3d, Test3D) { this->conv_params.clear(); + + this->conv_params.push_back( + {3, 3, 5, 96, 200, {1, 1, 1}, {37, 37, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + 
{3, 1, 1, 32, 32, {2, 2, 2}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {5, 5, 5}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {9, 9, 9}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( {3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); this->conv_params.push_back( {3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {16, 16, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + + this->conv_params.push_back( + {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->conv_params.push_back( + {3, 96, 1, 1, 1, {1, 1, 1}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 96, 1, 1, 1, {3, 3, 3}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->template Run<3>(); } int main(int argc, char** argv) diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_clamp.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_clamp.cpp index dc4b0f45032..fef485a9506 100644 --- a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_clamp.cpp +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_clamp.cpp @@ -88,20 +88,75 @@ TYPED_TEST_SUITE(TestGroupedConvndFwd3d, KernelTypes3d); TYPED_TEST(TestGroupedConvndFwd2d, Test2D) { this->conv_params.clear(); + this->conv_params.push_back( + {2, 3, 5, 96, 200, {1, 1}, {73, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 1, 1, 32, 32, {1, 1}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 1, 1, 32, 32, {2, 2}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 1, 1, 32, 32, {3, 3}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 1, 1, 32, 32, {5, 5}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 1, 1, 32, 32, {9, 9}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( {2, 2, 32, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back( {2, 2, 32, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + + this->conv_params.push_back( + {2, 2, 32, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + + this->conv_params.push_back( + {2, 96, 1, 1, 1, {1, 1}, {120, 160}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back( + {2, 96, 1, 1, 1, {3, 3}, {120, 160}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); 
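+    // G=96 with N=K=C=1 isolates the grouped dimension (depthwise-like
+    // shapes) over a large 120x160 spatial extent.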
this->template Run<2>(); } TYPED_TEST(TestGroupedConvndFwd3d, Test3D) { this->conv_params.clear(); + + this->conv_params.push_back( + {3, 3, 5, 96, 200, {1, 1, 1}, {37, 37, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {2, 2, 2}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {5, 5, 5}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {9, 9, 9}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( {3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); this->conv_params.push_back( {3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {16, 16, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + + this->conv_params.push_back( + {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->conv_params.push_back( + {3, 96, 1, 1, 1, {1, 1, 1}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 96, 1, 1, 1, {3, 3, 3}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->template Run<3>(); } int main(int argc, char** argv) diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_combconvscale.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_combconvscale.cpp new file mode 100644 index 00000000000..b4cfdb3f81d --- /dev/null +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_combconvscale.cpp @@ -0,0 +1,120 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "profiler/profile_grouped_conv_fwd_outelementop_impl.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" + +using CombConvScale = ck::tensor_operation::element_wise::ScaleScalePass; + +template +class TestGroupedConvndFwdCombConvScale : public ::testing::Test +{ + protected: + using InDataType = std::tuple_element_t<0, Tuple>; + using WeiDataType = std::tuple_element_t<1, Tuple>; + using OutDataType = std::tuple_element_t<2, Tuple>; + using InLayout = std::tuple_element_t<3, Tuple>; + using WeiLayout = std::tuple_element_t<4, Tuple>; + using OutLayout = std::tuple_element_t<5, Tuple>; + using IndexType = ck::index_t; + + std::vector conv_params; + + template + void Run() + { + EXPECT_FALSE(conv_params.empty()); + bool pass = true; + for(auto& param : conv_params) + { + if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a") + { + if(std::is_same::value || + std::is_same::value) + { + printf("Skipping FP8 / BF8 tests on CDNA1/2.\n"); + continue; + } + } + pass = pass && ck::profiler::profile_grouped_conv_fwd_outelementop_impl( + true, // do_verification + 1, // init_method: integer value + false, // do_log + true, // time_kernel + param); + } + EXPECT_TRUE(pass); + } +}; + +using namespace ck::tensor_layout::convolution; +using CombConvScaleKernelTypes3d = + ::testing::Types>; + +template +class TestGroupedConvndFwdCombConvScale3d : public TestGroupedConvndFwdCombConvScale +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndFwdCombConvScale3d, CombConvScaleKernelTypes3d); + +TYPED_TEST(TestGroupedConvndFwdCombConvScale3d, Test3D) +{ + this->conv_params.clear(); + + this->conv_params.push_back( + {3, 3, 5, 96, 200, {1, 1, 1}, {37, 37, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {2, 2, 2}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {5, 5, 5}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {9, 9, 9}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {16, 16, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + + this->conv_params.push_back( + {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->conv_params.push_back( + {3, 96, 1, 1, 1, {1, 1, 1}, {120, 40, 
20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 96, 1, 1, 1, {3, 3, 3}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->template Run<3>(); +} diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_combconvscalerelu.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_combconvscalerelu.cpp new file mode 100644 index 00000000000..38e425f0b00 --- /dev/null +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_combconvscalerelu.cpp @@ -0,0 +1,121 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "profiler/profile_grouped_conv_fwd_outelementop_impl.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" + +using CombConvScaleRelu = ck::tensor_operation::element_wise::ScaleScaleRelu; + +template +class TestGroupedConvndFwdCombConvScaleRelu : public ::testing::Test +{ + protected: + using InDataType = std::tuple_element_t<0, Tuple>; + using WeiDataType = std::tuple_element_t<1, Tuple>; + using OutDataType = std::tuple_element_t<2, Tuple>; + using InLayout = std::tuple_element_t<3, Tuple>; + using WeiLayout = std::tuple_element_t<4, Tuple>; + using OutLayout = std::tuple_element_t<5, Tuple>; + using IndexType = ck::index_t; + + std::vector conv_params; + + template + void Run() + { + EXPECT_FALSE(conv_params.empty()); + bool pass = true; + for(auto& param : conv_params) + { + if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a") + { + if(std::is_same::value || + std::is_same::value) + { + printf("Skipping FP8 / BF8 tests on CDNA1/2.\n"); + continue; + } + } + pass = + pass && ck::profiler::profile_grouped_conv_fwd_outelementop_impl( + true, // do_verification + 1, // init_method: integer value + false, // do_log + true, // time_kernel + param); + } + EXPECT_TRUE(pass); + } +}; + +using namespace ck::tensor_layout::convolution; +using CombConvScaleReluKernelTypes3d = + ::testing::Types>; + +template +class TestGroupedConvndFwdCombConvScaleRelu3d : public TestGroupedConvndFwdCombConvScaleRelu +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndFwdCombConvScaleRelu3d, CombConvScaleReluKernelTypes3d); + +TYPED_TEST(TestGroupedConvndFwdCombConvScaleRelu3d, Test3D) +{ + this->conv_params.clear(); + + this->conv_params.push_back( + {3, 3, 5, 96, 200, {1, 1, 1}, {37, 37, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {2, 2, 2}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {5, 5, 5}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {9, 9, 9}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + 
this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {16, 16, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + + this->conv_params.push_back( + {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->conv_params.push_back( + {3, 96, 1, 1, 1, {1, 1, 1}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 96, 1, 1, 1, {3, 3, 3}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->template Run<3>(); +} diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_convinvscale.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_convinvscale.cpp new file mode 100644 index 00000000000..7acfb3836c4 --- /dev/null +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_convinvscale.cpp @@ -0,0 +1,114 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "profiler/profile_grouped_conv_fwd_outelementop_impl.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using ConvInvscale = ck::tensor_operation::element_wise::ConvInvscale; + +template +class TestGroupedConvndFwdConvInvscale : public ::testing::Test +{ + protected: + using DataType = std::tuple_element_t<0, Tuple>; + using InLayout = std::tuple_element_t<1, Tuple>; + using WeiLayout = std::tuple_element_t<2, Tuple>; + using OutLayout = std::tuple_element_t<3, Tuple>; + using IndexType = ck::index_t; + + std::vector conv_params; + + template + void Run() + { + EXPECT_FALSE(conv_params.empty()); + bool pass = true; + for(auto& param : conv_params) + { + if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a") + { + if(std::is_same::value || + std::is_same::value) + { + printf("Skipping FP8 / BF8 tests on CDNA1/2.\n"); + continue; + } + } + pass = pass && ck::profiler::profile_grouped_conv_fwd_outelementop_impl( + true, // do_verification + 1, // init_method: integer value + false, // do_log + true, // time_kernel + param); + } + EXPECT_TRUE(pass); + } +}; + +using namespace ck::tensor_layout::convolution; + +using KernelTypes3d = ::testing::Types>; + +template +class TestGroupedConvndFwdConvInvscale3d : public TestGroupedConvndFwdConvInvscale +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndFwdConvInvscale3d, KernelTypes3d); + +TYPED_TEST(TestGroupedConvndFwdConvInvscale3d, Test3D) +{ + this->conv_params.clear(); + + this->conv_params.push_back( + {3, 3, 5, 96, 200, {1, 1, 1}, {37, 37, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {2, 2, 2}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {5, 5, 5}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 
0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {9, 9, 9}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {16, 16, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + + this->conv_params.push_back( + {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->conv_params.push_back( + {3, 96, 1, 1, 1, {1, 1, 1}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 96, 1, 1, 1, {3, 3, 3}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->template Run<3>(); +} diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_convscale.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_convscale.cpp new file mode 100644 index 00000000000..85dd36dee27 --- /dev/null +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_convscale.cpp @@ -0,0 +1,122 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "profiler/profile_grouped_conv_fwd_outelementop_impl.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using ConvScale = ck::tensor_operation::element_wise::ConvScale; + +template +class TestGroupedConvndFwdConvScale : public ::testing::Test +{ + protected: + using InDataType = std::tuple_element_t<0, Tuple>; + using WeiDataType = std::tuple_element_t<1, Tuple>; + using OutDataType = std::tuple_element_t<2, Tuple>; + using AComputeType = std::tuple_element_t<3, Tuple>; + using BComputeType = std::tuple_element_t<4, Tuple>; + using InLayout = std::tuple_element_t<5, Tuple>; + using WeiLayout = std::tuple_element_t<6, Tuple>; + using OutLayout = std::tuple_element_t<7, Tuple>; + using IndexType = ck::index_t; + + std::vector conv_params; + + template + void Run() + { + EXPECT_FALSE(conv_params.empty()); + bool pass = true; + for(auto& param : conv_params) + { + if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a") + { + if(std::is_same::value || + std::is_same::value) + { + printf("Skipping FP8 / BF8 tests on CDNA1/2.\n"); + continue; + } + } + pass = pass && ck::profiler::profile_grouped_conv_fwd_outelementop_impl( + true, // do_verification + 1, // init_method: integer value + false, // do_log + true, // time_kernel + param); + } + EXPECT_TRUE(pass); + } +}; + +using namespace ck::tensor_layout::convolution; +using KernelTypes3d = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple>; +template +class TestGroupedConvndFwdConvScale3d : public TestGroupedConvndFwdConvScale +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndFwdConvScale3d, KernelTypes3d); + +TYPED_TEST(TestGroupedConvndFwdConvScale3d, Test3D) +{ + 
this->conv_params.clear(); + + this->conv_params.push_back( + {3, 3, 5, 96, 200, {1, 1, 1}, {37, 37, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {2, 2, 2}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {5, 5, 5}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {9, 9, 9}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {16, 16, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + + this->conv_params.push_back( + {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->conv_params.push_back( + {3, 96, 1, 1, 1, {1, 1, 1}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 96, 1, 1, 1, {3, 3, 3}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->template Run<3>(); +} diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_convscaleadd.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_convscaleadd.cpp new file mode 100644 index 00000000000..ec83ef7f1c6 --- /dev/null +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_convscaleadd.cpp @@ -0,0 +1,116 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "profiler/profile_grouped_conv_fwd_convscale_add_impl.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using ConvScaleAdd = ck::tensor_operation::element_wise::ConvScaleAdd; + +template +class TestGroupedConvndFwdConvScaleAdd : public ::testing::Test +{ + protected: + using DataType = std::tuple_element_t<0, Tuple>; + using InLayout = std::tuple_element_t<1, Tuple>; + using WeiLayout = std::tuple_element_t<2, Tuple>; + using BiasLayout = std::tuple_element_t<3, Tuple>; + using OutLayout = std::tuple_element_t<4, Tuple>; + using IndexType = ck::index_t; + + std::vector conv_params; + + template + void Run() + { + EXPECT_FALSE(conv_params.empty()); + bool pass = true; + for(auto& param : conv_params) + { + if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a") + { + if(std::is_same::value || + std::is_same::value) + { + printf("Skipping FP8 / BF8 tests on CDNA1/2.\n"); + continue; + } + } + pass = pass && ck::profiler::profile_grouped_conv_fwd_convscale_add_impl( + true, // do_verification + 1, // init_method: integer value + false, // do_log + true, // time_kernel + param); + } + EXPECT_TRUE(pass); + } +}; + +using namespace ck::tensor_layout::convolution; + +using KernelTypes3d = ::testing::Types>; + +template +class TestGroupedConvndFwdConvScaleAdd3d : public TestGroupedConvndFwdConvScaleAdd +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndFwdConvScaleAdd3d, KernelTypes3d); + +TYPED_TEST(TestGroupedConvndFwdConvScaleAdd3d, Test3D) +{ + this->conv_params.clear(); + + this->conv_params.push_back( + {3, 3, 5, 96, 200, {1, 1, 1}, {37, 37, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {2, 2, 2}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {5, 5, 5}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {9, 9, 9}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {16, 16, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + + this->conv_params.push_back( + {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->conv_params.push_back( + {3, 96, 1, 1, 1, {1, 1, 1}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 96, 1, 1, 1, {3, 3, 3}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 
1}}); + this->template Run<3>(); +} diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_convscalerelu.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_convscalerelu.cpp new file mode 100644 index 00000000000..d667a9becfb --- /dev/null +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_convscalerelu.cpp @@ -0,0 +1,114 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "profiler/profile_grouped_conv_fwd_outelementop_impl.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using ConvScaleRelu = ck::tensor_operation::element_wise::ConvScaleRelu; + +template +class TestGroupedConvndFwdConvScaleRelu : public ::testing::Test +{ + protected: + using DataType = std::tuple_element_t<0, Tuple>; + using InLayout = std::tuple_element_t<1, Tuple>; + using WeiLayout = std::tuple_element_t<2, Tuple>; + using OutLayout = std::tuple_element_t<3, Tuple>; + using IndexType = ck::index_t; + + std::vector conv_params; + + template + void Run() + { + EXPECT_FALSE(conv_params.empty()); + bool pass = true; + for(auto& param : conv_params) + { + if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a") + { + if(std::is_same::value || + std::is_same::value) + { + printf("Skipping FP8 / BF8 tests on CDNA1/2.\n"); + continue; + } + } + pass = pass && ck::profiler::profile_grouped_conv_fwd_outelementop_impl( + true, // do_verification + 1, // init_method: integer value + false, // do_log + true, // time_kernel + param); + } + EXPECT_TRUE(pass); + } +}; + +using namespace ck::tensor_layout::convolution; + +using KernelTypes3d = ::testing::Types>; + +template +class TestGroupedConvndFwdConvScaleRelu3d : public TestGroupedConvndFwdConvScaleRelu +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndFwdConvScaleRelu3d, KernelTypes3d); + +TYPED_TEST(TestGroupedConvndFwdConvScaleRelu3d, Test3D) +{ + this->conv_params.clear(); + + this->conv_params.push_back( + {3, 3, 5, 96, 200, {1, 1, 1}, {37, 37, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {2, 2, 2}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {5, 5, 5}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {9, 9, 9}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {16, 16, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + + this->conv_params.push_back( + {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 
1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->conv_params.push_back( + {3, 96, 1, 1, 1, {1, 1, 1}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 96, 1, 1, 1, {3, 3, 3}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->template Run<3>(); +} diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_clamp.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_clamp.cpp index 726ce30a1bc..a78a17cbf49 100644 --- a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_clamp.cpp +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_clamp.cpp @@ -81,20 +81,75 @@ TYPED_TEST_SUITE(TestGroupedConvndFwd3d, KernelTypes3d); TYPED_TEST(TestGroupedConvndFwd2d, Test2D) { this->conv_params.clear(); + this->conv_params.push_back( + {2, 3, 5, 96, 200, {1, 1}, {73, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 1, 1, 32, 32, {1, 1}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 1, 1, 32, 32, {2, 2}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 1, 1, 32, 32, {3, 3}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 1, 1, 32, 32, {5, 5}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 1, 1, 32, 32, {9, 9}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( {2, 2, 32, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back( {2, 2, 32, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + + this->conv_params.push_back( + {2, 2, 32, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + + this->conv_params.push_back( + {2, 96, 1, 1, 1, {1, 1}, {120, 160}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back( + {2, 96, 1, 1, 1, {3, 3}, {120, 160}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->template Run<2>(); } TYPED_TEST(TestGroupedConvndFwd3d, Test3D) { this->conv_params.clear(); + + this->conv_params.push_back( + {3, 3, 5, 96, 200, {1, 1, 1}, {37, 37, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {2, 2, 2}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {5, 5, 5}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {9, 9, 9}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( {3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); this->conv_params.push_back( {3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, 
{1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {16, 16, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + + this->conv_params.push_back( + {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->conv_params.push_back( + {3, 96, 1, 1, 1, {1, 1, 1}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 96, 1, 1, 1, {3, 3, 3}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->template Run<3>(); } int main(int argc, char** argv) diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_scale.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_scale.cpp new file mode 100644 index 00000000000..b4179cae627 --- /dev/null +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_scale.cpp @@ -0,0 +1,124 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "profiler/profile_grouped_conv_fwd_outelementop_impl.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" + +using I8 = int8_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; + +template +class TestGroupedConvndFwdScale : public ::testing::Test +{ + protected: + using InDataType = std::tuple_element_t<0, Tuple>; + using WeiDataType = std::tuple_element_t<1, Tuple>; + using OutDataType = std::tuple_element_t<2, Tuple>; + using InLayout = std::tuple_element_t<3, Tuple>; + using WeiLayout = std::tuple_element_t<4, Tuple>; + using OutLayout = std::tuple_element_t<5, Tuple>; + using IndexType = ck::index_t; + + std::vector conv_params; + + template + void Run() + { + EXPECT_FALSE(conv_params.empty()); + bool pass = true; + for(auto& param : conv_params) + { + if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a") + { + if(std::is_same::value || + std::is_same::value) + { + printf("Skipping FP8 / BF8 tests on CDNA1/2.\n"); + continue; + } + } + pass = pass && ck::profiler::profile_grouped_conv_fwd_outelementop_impl< + NDimSpatial, + InLayout, + WeiLayout, + OutLayout, + InDataType, + WeiDataType, + OutDataType, + ck::tensor_operation::element_wise::Scale, + InDataType, + InDataType>(true, // do_verification + 1, // init_method: integer value + false, // do_log + true, // time_kernel + param); + } + EXPECT_TRUE(pass); + } +}; + +using namespace ck::tensor_layout::convolution; +using CombConvScaleKernelTypes3d = + ::testing::Types, + std::tuple, + std::tuple>; + +template +class TestGroupedConvndFwdScale3d : public TestGroupedConvndFwdScale +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndFwdScale3d, CombConvScaleKernelTypes3d); + +TYPED_TEST(TestGroupedConvndFwdScale3d, Test3D) +{ + this->conv_params.clear(); + + this->conv_params.push_back( + {3, 3, 5, 96, 200, {1, 1, 1}, {37, 37, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {2, 2, 2}, 
{32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {5, 5, 5}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {9, 9, 9}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {16, 16, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + + this->conv_params.push_back( + {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->conv_params.push_back( + {3, 96, 1, 1, 1, {1, 1, 1}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 96, 1, 1, 1, {3, 3, 3}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->template Run<3>(); +} diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_scaleadd_scaleadd_relu.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_scaleadd_scaleadd_relu.cpp new file mode 100644 index 00000000000..726247dffcb --- /dev/null +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_scaleadd_scaleadd_relu.cpp @@ -0,0 +1,111 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT
+
+#include <cstdlib>
+#include <iostream>
+#include <tuple>
+#include <vector>
+#include <gtest/gtest.h>
+
+#include "ck/utility/common_header.hpp"
+#include "ck/host_utility/device_prop.hpp"
+#include "profiler/profile_grouped_conv_fwd_scaleadd_scaleadd_relu_impl.hpp"
+
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+using I8   = int8_t;
+using F16  = ck::half_t;
+using BF16 = ck::bhalf_t;
+using F32  = float;
+
+template <typename Tuple>
+class TestGroupedConvndFwdScaleAddScaleAddRelu : public ::testing::Test
+{
+    protected:
+    using InDataType  = std::tuple_element_t<0, Tuple>;
+    using WeiDataType = std::tuple_element_t<1, Tuple>;
+    using OutDataType = std::tuple_element_t<2, Tuple>;
+    using InLayout    = std::tuple_element_t<3, Tuple>;
+    using WeiLayout   = std::tuple_element_t<4, Tuple>;
+    using OutLayout   = std::tuple_element_t<5, Tuple>;
+
+    std::vector<ck::utils::conv::ConvParam> conv_params;
+
+    template <ck::index_t NDimSpatial>
+    void Run()
+    {
+        EXPECT_FALSE(conv_params.empty());
+        bool pass = true;
+        for(auto& param : conv_params)
+        {
+            if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")
+            {
+                if(std::is_same<InDataType, ck::f8_t>::value ||
+                   std::is_same<InDataType, ck::bf8_t>::value)
+                {
+                    printf("Skipping FP8 / BF8 tests on CDNA1/2.\n");
+                    continue;
+                }
+            }
+            pass = pass && ck::profiler::profile_grouped_conv_fwd_scaleadd_scaleadd_relu_impl<
+                               NDimSpatial,
+                               InLayout,
+                               WeiLayout,
+                               OutLayout,
+                               InDataType,
+                               WeiDataType,
+                               OutDataType,
+                               ck::tensor_operation::element_wise::ScaleAddScaleAddRelu,
+                               InDataType,
+                               InDataType>(true,  // do_verification
+                                           1,     // init_method: integer value
+                                           false, // do_log
+                                           true,  // time_kernel
+                                           param);
+        }
+        EXPECT_TRUE(pass);
+    }
+};
+
+using namespace ck::tensor_layout::convolution;
+using CombConvScaleAddScaleAddReluKernelTypes3d =
+    ::testing::Types,
+                     std::tuple,
+                     std::tuple>;
+
+template <typename Tuple>
+class TestGroupedConvndFwdScaleAddScaleAddRelu3d
+    : public TestGroupedConvndFwdScaleAddScaleAddRelu<Tuple>
+{
+};
+
+TYPED_TEST_SUITE(TestGroupedConvndFwdScaleAddScaleAddRelu3d,
+                 CombConvScaleAddScaleAddReluKernelTypes3d);
+
+TYPED_TEST(TestGroupedConvndFwdScaleAddScaleAddRelu3d, Test3D)
+{
+    this->conv_params.clear();
+
+    this->conv_params.push_back(
+        {3, 3, 5, 96, 200, {1, 1, 1}, {37, 37, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
+    this->conv_params.push_back(
+        {3, 1, 1, 32, 32, {1, 1, 1}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
+    this->conv_params.push_back(
+        {3, 1, 1, 32, 32, {2, 2, 2}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
+    this->conv_params.push_back(
+        {3, 1, 1, 32, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
+    this->conv_params.push_back(
+        {3, 1, 1, 32, 32, {5, 5, 5}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
+
+    this->conv_params.push_back(
+        {3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
+    this->conv_params.push_back(
+        {3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
+
+    this->conv_params.push_back(
+        {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
+    this->conv_params.push_back(
+        {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
+
+    this->template Run<3>();
+}
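
For reference, every conv_params.push_back({...}) in these tests uses the same positional initializer: {spatial rank, G, N, K, C, filter lengths, input lengths, strides, dilations, left pads, right pads}, matching the argument order of ck::utils::conv::ConvParam. A minimal sketch that decodes one 3D case and computes the implied output lengths; ConvParamSketch is a hypothetical stand-in for illustration, not CK's type:

#include <cstdio>
#include <vector>

// Hypothetical stand-in for ck::utils::conv::ConvParam; the field order below
// is read off the initializer lists in the tests above.
struct ConvParamSketch
{
    int ndim, G, N, K, C;
    std::vector<int> filter, input, stride, dilation, lpad, rpad;
};

// Standard forward-convolution output length for one spatial dimension.
int out_len(const ConvParamSketch& p, int d)
{
    const int eff_filter = p.dilation[d] * (p.filter[d] - 1) + 1;
    return (p.input[d] + p.lpad[d] + p.rpad[d] - eff_filter) / p.stride[d] + 1;
}

int main()
{
    // The case shared by every test: G=2, N=32, K=128, C=256, 3x3x3 filter,
    // 14x14x3 input, unit stride/dilation/padding.
    const ConvParamSketch p{3, 2, 32, 128, 256,
                            {3, 3, 3}, {14, 14, 3},
                            {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}};
    for(int d = 0; d < p.ndim; ++d)
        std::printf("dim %d: out = %d\n", d, out_len(p, d)); // 14, 14, 3
    return 0;
}

For that shape every spatial length is preserved (same-padding), which is why the {14, 14, 3} case recurs across all of the test files as the canonical padded configuration.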
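
The new test files all share one structure: a typed gtest fixture parameterized by a std::tuple of data types and layout tags, unpacked with std::tuple_element_t, plus a Run<NDimSpatial>() that folds per-case results into a single pass flag. A self-contained sketch of that pattern; the type names here are hypothetical stand-ins, and the binary must be linked against gtest_main to run:

#include <gtest/gtest.h>
#include <tuple>
#include <type_traits>

// Hypothetical layout tag standing in for ck::tensor_layout::convolution's.
struct NDHWGC {};

// Same shape as the fixtures above: one std::tuple per tested configuration,
// unpacked into named aliases with std::tuple_element_t.
template <typename Tuple>
class ConvConfigTest : public ::testing::Test
{
    protected:
    using DataType = std::tuple_element_t<0, Tuple>;
    using InLayout = std::tuple_element_t<1, Tuple>;
};

using Configs =
    ::testing::Types<std::tuple<float, NDHWGC>, std::tuple<double, NDHWGC>>;
TYPED_TEST_SUITE(ConvConfigTest, Configs);

TYPED_TEST(ConvConfigTest, RunsOncePerConfig)
{
    // Instantiated twice, once per tuple in Configs; results accumulate into
    // one flag the way Run<NDimSpatial>() folds `pass` above.
    bool pass = std::is_floating_point<typename TestFixture::DataType>::value;
    EXPECT_TRUE(pass);
}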
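
Each Run() repeats the same guard: FP8/BF8 cases are skipped on gfx908 (CDNA1) and gfx90a (CDNA2), which lack native FP8 arithmetic. A sketch of that rule factored into a helper; the helper name and the stub types are hypothetical, only the device names come from the patch:

#include <string>
#include <type_traits>

// Hypothetical stand-ins for ck::f8_t / ck::bf8_t.
struct f8_stub {};
struct bf8_stub {};

// Hypothetical helper capturing the guard each Run() above repeats verbatim.
template <typename InDataType>
bool should_skip_fp8_case(const std::string& device_name)
{
    const bool is_cdna1_or_2 = device_name == "gfx908" || device_name == "gfx90a";
    const bool is_fp8_input  = std::is_same<InDataType, f8_stub>::value ||
                               std::is_same<InDataType, bf8_stub>::value;
    return is_cdna1_or_2 && is_fp8_input;
}

int main()
{
    // An FP8 case is skipped on CDNA2 but not on gfx942 (CDNA3).
    return (should_skip_fp8_case<f8_stub>("gfx90a") &&
            !should_skip_fp8_case<f8_stub>("gfx942"))
               ? 0
               : 1;
}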
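
On the out-elementwise ops under test: as I read CK's element_wise namespace, ConvScale multiplies the f32 accumulator by the product of the input, weight, and output quantization scales before conversion to f8, and ConvInvscale divides by the same product. A sketch under that assumption; the field names are guesses, not CK's exact definition:

#include <cstdio>

// Assumed shape of ConvScale: accumulator times the combined dequant/quant
// scales (the kernel then converts e to the f8 output type).
struct ConvScaleSketch
{
    float scale_in, scale_wei, scale_out;
    void operator()(float& e, const float& c) const
    {
        e = c * scale_in * scale_wei * scale_out;
    }
};

// Assumed shape of ConvInvscale: the inverse scaling, dividing by the product.
struct ConvInvscaleSketch
{
    float scale_in, scale_wei, scale_out;
    void operator()(float& e, const float& c) const
    {
        e = c / (scale_in * scale_wei * scale_out);
    }
};

int main()
{
    float e = 0.f;
    ConvScaleSketch{0.5f, 2.0f, 0.25f}(e, 8.0f);
    std::printf("%g\n", e); // 8 * 0.5 * 2 * 0.25 = 2
    ConvInvscaleSketch{0.5f, 2.0f, 0.25f}(e, 8.0f);
    std::printf("%g\n", e); // 8 / 0.25 = 32
    return 0;
}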
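
The last test drives ScaleAddScaleAddRelu through profile_grouped_conv_fwd_scaleadd_scaleadd_relu_impl with two extra input-typed tensors, which suggests two scaled residual adds followed by a ReLU clamp. The exact argument order and scale placement in CK's op are not visible in this patch, so the sketch below is an assumption:

#include <algorithm>
#include <cstdio>

// Assumed shape of ScaleAddScaleAddRelu: conv accumulator c plus two scaled
// residuals d0/d1, then ReLU. Field names and argument order are assumptions.
struct ScaleAddScaleAddReluSketch
{
    float scale0, scale1;
    void operator()(float& e, const float& c, const float& d0, const float& d1) const
    {
        e = std::max(0.f, c + scale0 * d0 + scale1 * d1);
    }
};

int main()
{
    float e = -1.f;
    ScaleAddScaleAddReluSketch{0.5f, 0.25f}(e, -4.0f, 2.0f, 4.0f);
    std::printf("%g\n", e); // max(0, -4 + 1 + 1) = 0
    return 0;
}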