diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index 944a8f96bf5..24a4106ae72 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -18,6 +18,7 @@ add_example_executable(example_moe_gemm1_xdl_fp8 moe_gemm1_xdl_fp8.cpp) add_example_executable(example_moe_gemm2_xdl_fp8 moe_gemm2_xdl_fp8.cpp) add_example_executable(example_moe_gemm2_xdl_fp8_blockscale moe_gemm2_xdl_fp8_blockscale.cpp) add_example_executable(example_moe_gemm1_xdl_fp8_blockscale moe_gemm1_xdl_fp8_blockscale.cpp) +add_example_executable(example_moe_gemm1_xdl_fp8_blockscale_splitk moe_gemm1_xdl_fp8_blockscale_splitk.cpp) list(APPEND gpu_list gfx942 gfx950 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1153 gfx1200 gfx1201 gfx11-generic gfx12-generic) set(target 0) diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp index fdaef8ec3ea..ecc3034bba5 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp @@ -171,7 +171,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceM // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| CShuffleMXDLPerWave, CShuffleNXDLPerWave, S<1, 32, 1, 8>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, false, MulRoutedWeight, int32_t, A0DataType>; #else static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale< Row, Col, DsLayout, ELayout, @@ -185,7 +185,7 @@ static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, false, MulRoutedWeight, int32_t, A0DataType>; #endif // clang-format on diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale_splitk.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale_splitk.cpp new file mode 100644 index 00000000000..08a1d6614e7 --- /dev/null +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale_splitk.cpp @@ -0,0 +1,535 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
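For reference, the two hunks above add one more boolean template argument to the existing gemm1 instantiations: the split-K switch introduced later in this patch, which the non-split-K examples now pass explicitly as `false` in front of `MulRoutedWeight`. A minimal sketch of that pattern with a toy stand-in template (the names and the simplified trait below are illustrative, not the real `DeviceMoeGemmBlockScale` signature):

```cpp
#include <cstdio>

// Toy stand-in: the IsSplitK flag sits between IsInputGemm and MulRoutedWeight,
// matching where the examples above insert the extra `false`.
template <bool IsInputGemm, bool IsSplitK, bool MulRoutedWeight>
struct ToyMoeGemm
{
    // Mirrors the store-mode selection changed further down in this patch:
    // gemm1 without split-K keeps plain stores, everything else uses atomics.
    static constexpr bool UsesAtomicAdd = !(IsInputGemm && !IsSplitK);
};

int main()
{
    constexpr bool plain  = ToyMoeGemm<true, false, true>::UsesAtomicAdd;  // gemm1, no split-K
    constexpr bool splitk = ToyMoeGemm<true, true, false>::UsesAtomicAdd;  // split-K gemm1
    std::printf("%d %d\n", plain, splitk); // prints: 0 1
    return 0;
}
```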
+// SPDX-License-Identifier: MIT + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale_splitk.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +using ::ck::DeviceMem; +using ::ck::HostTensorDescriptor; +using ::ck::Tensor; + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F8 = ck::f8_t; +using F32 = float; +using I64 = int64_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + +using A0DataType = F8; +using A1DataType = F32; +using B0DataType = F8; +using B1DataType = F32; +using EDataType = F32; +using AccDataType = F32; +using CShuffleDataType = EDataType; +using D2DataType = F32; +using DsDataType = ck::Tuple; + +using A0Layout = Row; +using B0Layout = Col; +using ELayout = Row; +using D0Layout = Row; +using D1Layout = Col; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; + +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void + operator()(EDataType& e, const EDataType& c, const float& d2) const + { + (void)d2; + e = ck::type_convert(c); + } +}; + +void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl) +{ + int KPack = 16 / sizeof(B0DataType); + int NLane = NXdl; + int KLane = 64 / NLane; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(I64 n = 0; n < N; ++n) + { + for(I64 k = 0; k < K; ++k) + { + I64 n0 = n / NLane; + I64 n1 = n % NLane; + + I64 k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + I64 k1 = tempk / KPack; + I64 k2 = tempk % KPack; + + I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * static_cast(K) + k]; + } + } +} +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MulABScaleExpertWeight; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr ck::index_t Scale_Block_M = 1; +static constexpr ck::index_t Scale_Block_N = 128; +static constexpr ck::index_t Scale_Block_K = 128; + +static constexpr ck::index_t Nswizzle = false; +static constexpr ck::index_t IsInputGemm = true; // splitk gemm1 goes to gemm2 pipeline. +static constexpr ck::index_t IsSplitK = true; // splitk gemm1 +static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul +static constexpr bool MulRoutedWeight = false; // splitk gemm1 does not do routedWeight. 
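The `preShuffleBuffer` helper above rearranges each expert's B operand from row-major (N, K) storage into the (N0, K0, KLane, NLane, KPack) tiling that the preshuffled B pipeline reads. A self-contained sketch of the same index mapping on plain `int` data; the sizes and the choice to pass `NLane`/`KPack` as parameters are illustrative only:

```cpp
#include <cstdio>
#include <vector>

// Same mapping as preShuffleBuffer: N -> (N0, NLane), K -> (K0, KLane, KPack),
// output laid out as [N0][K0][KLane][NLane][KPack].
void preshuffle(const int* src, int* dst, int N, int K, int NLane, int KPack)
{
    const int KLane = 64 / NLane;        // one wavefront covers KLane x NLane lanes
    const int K0    = K / (KLane * KPack);
    for(long n = 0; n < N; ++n)
        for(long k = 0; k < K; ++k)
        {
            const long n0 = n / NLane, n1 = n % NLane;
            const long k0 = k / (KLane * KPack);
            const long r  = k % (KLane * KPack);
            const long k1 = r / KPack, k2 = r % KPack;
            const long out =
                ((n0 * K0 + k0) * KLane + k1) * NLane * KPack + n1 * KPack + k2;
            dst[out] = src[n * K + k];
        }
}

int main()
{
    const int N = 32, K = 128, NLane = 16, KPack = 16; // illustrative sizes
    std::vector<int> src(N * K), dst(N * K);
    for(int i = 0; i < N * K; ++i) src[i] = i;
    preshuffle(src.data(), dst.data(), N, K, NLane, KPack);
    // The first KPack outputs are the first KPack K-elements of row n = 0.
    std::printf("dst[0]=%d dst[15]=%d\n", dst[0], dst[15]);
    return 0;
}
```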
+ +#if 1 +static constexpr ck::index_t MPerBlock = 32; +static constexpr ck::index_t NPerBlock = 128; +static constexpr ck::index_t MNPerXDL = 16; +static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * 1); +static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * 4); +static constexpr ck::index_t CShuffleMXDLPerWave = MXDLPerWave; +static constexpr ck::index_t CShuffleNXDLPerWave = NXDLPerWave; +static constexpr ck::index_t BLOCKSIZE = 256; + +static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); +static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); +static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); +static constexpr ck::index_t EVec = 16 / sizeof(EDataType); +static constexpr ck::index_t D0Vec = 1; +static constexpr ck::index_t D1Vec = 1; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale + // clang-format off + < Row, Col, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + //threadnum, mblock, nblock, kblock + BLOCKSIZE, Scale_Block_M, Scale_Block_N, Scale_Block_K, + MPerBlock, NPerBlock, KPerBlock, + // ak1, bk1 + AK1, BK1, + // mn_perxdl + MNPerXDL, MNPerXDL, + // mn_xdlperwave + MXDLPerWave, NXDLPerWave, + // a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, + // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + CShuffleMXDLPerWave, CShuffleNXDLPerWave, S<1, 32, 1, 8>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight, int32_t, A0DataType>; +#else + +static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale< + Row, Col, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + 256, Scale_Block_M, Scale_Block_N, Scale_Block_K, + MPerBlock, 128, 128, + 16, 16, + 16, 16, + 4, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight, int32_t, A0DataType>; +#endif +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; +#if 1 + // GEMM shape + ck::index_t N = 4096; + ck::index_t K = 6144; + // ck::index_t N = 128; + // ck::index_t K = 512; + ck::index_t experts = 8; + ck::index_t topk = 2; + // ck::index_t sorted_tile_num = 515; + // ck::index_t valid_tile_num = 512; + // ck::index_t tokens = 208; + // ck::index_t sorted_tile_num = 15; + // ck::index_t valid_tile_num = 13; + // ck::index_t sorted_tile_num = 259; + // ck::index_t valid_tile_num = 256; + // ck::index_t tokens = 4096; + ck::index_t sorted_tile_num = 2; + ck::index_t valid_tile_num = 2; + ck::index_t tokens = 32; +#else + // deepseek + ck::index_t N = 2048; + ck::index_t K = 7168; + ck::index_t experts = 256; + ck::index_t topk = 8; + ck::index_t tokens = 
4096; + ck::index_t sorted_tile_num = 261; + ck::index_t valid_tile_num = 256; +#endif + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + // use default case + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + } + else if(argc == 9) + { + + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + sorted_tile_num = std::stoi(argv[7]); + valid_tile_num = std::stoi(argv[8]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: N, K, tokens\n"); + exit(0); + } + + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; + if(tokens * topk > valid_size) + { + printf("err config, tokens * topk > valid_size\n"); + exit(-1); + } + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N * 2; + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto StrideDs = std::array{0}; + ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K; + ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + ck::index_t Scale_Stride_B = (N + Scale_Block_N - 1) / Scale_Block_N * 2; + + ck::index_t KBatch = 6; + + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); + Tensor max_token_id(HostTensorDescriptor({1 + sorted_tile_num})); + max_token_id.mData = {valid_size}; + // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts); + } + + int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; + int tokenid = 0; + + for(int i = 0; i < sorted_size; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile && tokenid < tokens * topk) + { + sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = tokens; + } + } + Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); + Tensor a1_t_k(HostTensorDescriptor( + {tokens, (K + Scale_Block_K - 1) / Scale_Block_K}, {Scale_Stride_AM, 1}, Row{})); + Tensor b0_e_n_k( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); + Tensor b1_e_n_k( + HostTensorDescriptor({experts, + (K + Scale_Block_K - 1) / Scale_Block_K, + (N + Scale_Block_N - 1) / Scale_Block_N * 2}, + {(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); + Tensor b0_preshuffled( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); + Tensor e_t_n_host_result( + HostTensorDescriptor({tokens, topk, N * 2}, {topk * N * 2, N * 2, 1}, Row{})); + Tensor e_t_n_device_result( + HostTensorDescriptor({tokens, topk, N * 2}, {topk * N * 2, N * 2, 1}, Row{})); + e_t_n_device_result.SetZero(); + std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; + std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << 
b0_e_n_k.mDesc << std::endl; + std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl; + std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 3: + a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 4: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 5: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 6: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + break; + default: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + } + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * + sorted_token_ids.mDesc.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); + DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(A1DataType) * a1_t_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_e_n_k.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); + + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + max_token_id_dev.ToDevice(max_token_id.mData.data()); + a0_device_buf.ToDevice(a0_t_k.mData.data()); + a1_device_buf.ToDevice(a1_t_k.mData.data()); + b1_device_buf.ToDevice(b1_e_n_k.mData.data()); + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + int NPerXdl = device_op.GetPreShuffleParameters(); + + preShuffleBuffer( + b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * 2 * experts, K, NPerXdl); + + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(), 
+ expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{nullptr}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + a1_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + if(time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K; + std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K + + sizeof(B0DataType) * K * N * 2 * experts + + sizeof(EDataType) * valid_tile_num * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s.\n" + << device_op.GetTypeString() << std::endl; + } + + if(do_verification) + { + // use atomic, so need to reinit outputs + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); + + Tensor a_t_k({tokens, K}); + Tensor b_e_n_k({experts, K, N * 2}); + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + + Tensor c_t_k_n({tokens, topk, N * 2}, {topk * N * 2, N * 2, 1}, Row{}); + + // handle scale before ref. + for(int t = 0; t < tokens; ++t) + { + for(int k = 0; k < K; ++k) + { + a_t_k(t, k) = ck::type_convert(a0_t_k(t, k)) * a1_t_k(t, k / Scale_Block_K); + } + } + + for(int e = 0; e < experts; ++e) + { + for(int k = 0; k < K; ++k) + { + for(int n = 0; n < N * 2; ++n) + { + b_e_n_k(e, k, n) = ck::type_convert(b0_e_n_k(e, k, n)) * + b1_e_n_k(e, k / Scale_Block_K, n / Scale_Block_N); + } + } + } + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceMoeGemm1BlockScaleSplitK; + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a_t_k, + b_e_n_k, + c_t_k_n, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + for(int m = 0; m < valid_size; ++m) + { + + const int fuse_t = sorted_token_ids.mData[m]; + const int t = fuse_t & 0xffffff; + const int topk_id = (fuse_t & 0xff000000) >> 24; + + if(t >= tokens) + { + continue; + } + for(int n = 0; n < 2 * N; ++n) + { + e_t_n_host_result(t, topk_id, n) = + ck::type_convert(c_t_k_n(t, topk_id, n)); + } + } + + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + + auto status = + ck::utils::check_err( + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1) + ? 
0 + : 1; + if(status == 0) + { + printf("Validation Pass.\n"); + } + return status; + } + + return 0; +} diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp index 35c9d3d7884..fb5e3b64567 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp @@ -165,7 +165,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale< S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, 2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, int32_t, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, false, MulRoutedWeight, int32_t, A0DataType>; #else static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale< @@ -180,7 +180,7 @@ static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, int32_t, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, false, MulRoutedWeight, int32_t, A0DataType>; #endif // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp index feaace8919d..df7179efe50 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp @@ -74,6 +74,7 @@ template 1) - hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, - 0, - arg_.M * arg_.N * sizeof(CDataType), - stream_config.stream_id_)); + // if(arg_.KBatch > 1) + // hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + // 0, + // arg_.M * arg_.N * sizeof(CDataType) + // * (IsInputGemm && IsSplitK ? 2 : 1), + // stream_config.stream_id_)); }; ave_time = ck::utility::launch_and_time_kernel_with_preprocess( @@ -267,11 +270,12 @@ struct DeviceMoeGemmBlockScale } else { - if(arg.KBatch > 1) - hipGetErrorString(hipMemsetAsync(arg.p_c_grid, - 0, - arg.M * arg.N * sizeof(CDataType), - stream_config.stream_id_)); + // if(arg.KBatch > 1) + // hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + // 0, + // arg.M * arg.N * sizeof(CDataType) * + // (IsInputGemm && IsSplitK ? 2 : 1), + // stream_config.stream_id_)); ave_time = launch_and_time_kernel( stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); @@ -289,8 +293,9 @@ struct DeviceMoeGemmBlockScale constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2; - constexpr auto MemoryDataOp = - IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd; + constexpr auto MemoryDataOp = (IsInputGemm && !IsSplitK) + ? 
InMemoryDataOperationEnum::Set + : InMemoryDataOperationEnum::AtomicAdd; if(has_main_k_block_loop) { @@ -416,8 +421,8 @@ struct DeviceMoeGemmBlockScale static bool IsSupportedArgument(const Argument& arg) { - // only impl kbatch 1 now - if(arg.KBatch > 1) + // only impl kbatch 1 for fp32 + if(arg.KBatch > 1 && !std::is_same_v) { return false; } @@ -441,6 +446,11 @@ struct DeviceMoeGemmBlockScale { return false; } + if(arg.KBatch > 1 && arg.K % (KPerBlock * arg.KBatch) != 0) + { + // Not support Kpadding with KBatch > 1 + return false; + } if(get_warp_size() == 64) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp index 0a565bf17e8..4ba8ab00ea3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp @@ -60,8 +60,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) karg.p_b_grid + splitk_batch_offset.b_k_split_offset, karg.p_ds_grid, karg.p_c_grid, - karg.p_a_scale_grid, - karg.p_b_scale_grid, + karg.p_a_scale_grid + splitk_batch_offset.ascale_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.bscale_k_split_offset, p_shared, karg, karg.a_element_op, @@ -101,8 +101,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) karg.p_b_grid + splitk_batch_offset.b_k_split_offset, karg.p_ds_grid, karg.p_c_grid, - karg.p_a_scale_grid, - karg.p_b_scale_grid, + karg.p_a_scale_grid + splitk_batch_offset.ascale_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.bscale_k_split_offset, p_shared, p_shared1, karg, @@ -167,6 +167,7 @@ template {}, Sequence<0>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - return a_grid_desc_ak0_m_ak1; } } @@ -741,35 +748,42 @@ struct GridwiseMoeGemmBlockScale { if constexpr(is_same_v) { - a_k_split_offset = k_id * karg.KRead / APackedSize; + a_k_split_offset = k_id * karg.KRead / APackedSize; + ascale_k_split_offset = math::integer_divide_floor(a_k_split_offset, ScaleBlockK); } else if constexpr(is_same_v) { - a_k_split_offset = k_id * karg.KRead * karg.StrideA; + a_k_split_offset = k_id * karg.KRead * karg.StrideA; + ascale_k_split_offset = math::integer_divide_floor(a_k_split_offset, ScaleBlockK); } if constexpr(is_same_v) { - b_k_split_offset = k_id * karg.KRead * karg.StrideB; + b_k_split_offset = k_id * karg.KRead * karg.StrideB; + bscale_k_split_offset = math::integer_divide_floor(b_k_split_offset, ScaleBlockK); } else if constexpr(is_same_v) { // KPack * NLane * KLane * K0 * N0 b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize; + bscale_k_split_offset = + k_id * karg.KRead / ScaleBlockK + k_id * NLane / ScaleBlockN; } - if(k_id < karg.KBatch - 1) - { - karg.K = karg.KRead; - } - else - { - karg.K = karg.K - karg.KRead * (karg.KBatch - 1); - } + // if(k_id < karg.KBatch - 1) + // { + // karg.K = karg.KRead; + // } + // else + // { + // karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + // } } index_t a_k_split_offset; index_t b_k_split_offset; + index_t ascale_k_split_offset; + index_t bscale_k_split_offset; }; __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() @@ -912,8 +926,8 @@ struct GridwiseMoeGemmBlockScale } using BlockwiseGemmPipe = - remove_cvref_t())>; + IsInputGemm && !IsSplitK > ())>; __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { @@ -1189,9 +1203,9 @@ struct GridwiseMoeGemmBlockScale BElementwiseOperation b_element_op, CElementwiseOperation 
c_element_op) { - ignore = b_element_op; - index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); - index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); + ignore = b_element_op; + index_t BN0Shuffled = CalculateBN0Shuffled(problem.N * (IsInputGemm && IsSplitK ? 2 : 1)); + index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, problem.MPadded, @@ -1204,8 +1218,8 @@ struct GridwiseMoeGemmBlockScale const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens, problem.MPadded, - problem.N, - problem.NPadded, + problem.N * (IsInputGemm && IsSplitK ? 2 : 1), + problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1), problem.StrideC); const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( @@ -1215,7 +1229,8 @@ struct GridwiseMoeGemmBlockScale math::integer_divide_ceil(problem.K, ScaleBlockK)), make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1)); const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( - make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN), + make_tuple(math::integer_divide_ceil(problem.N * (IsInputGemm && IsSplitK ? 2 : 1), + ScaleBlockN), math::integer_divide_ceil(problem.K, ScaleBlockK)), make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1)); @@ -1371,9 +1386,10 @@ struct GridwiseMoeGemmBlockScale decltype(c_thread_buf) c_thread_buf_up; const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( - (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / - KPerBlock); - + problem.KBatch == 1 + ? (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock + : problem.KBatch); constexpr index_t ScaleSliceSizeM = MXdlPerWave; constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN); constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK); @@ -1447,7 +1463,7 @@ struct GridwiseMoeGemmBlockScale constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK); constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock); - if constexpr(IsInputGemm) + if constexpr(IsInputGemm && !IsSplitK) { const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; const auto b_grid_buf_up = make_dynamic_buffer( @@ -1606,7 +1622,7 @@ struct GridwiseMoeGemmBlockScale blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( make_tuple(m0, n0, n2 * N4 + n4)); constexpr auto cidx = Number{}; - if constexpr(IsInputGemm) // gu fusion, elementwise + if constexpr(IsInputGemm && !IsSplitK) // gu fusion, elementwise { if constexpr(ActivationOperation == Activation::silu_and_mul) { @@ -1743,8 +1759,12 @@ struct GridwiseMoeGemmBlockScale using EDataType = CDataType; - const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( - problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + const auto ds_grid_desc_m_n = + MakeDsGridDescriptor_M_N(problem.M, + problem.MPadded, + problem.N * (IsInputGemm && IsSplitK ? 2 : 1), + problem.NPadded * (IsInputGemm && IsSplitK ? 
2 : 1), + problem.StrideDs); const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( @@ -1875,7 +1895,8 @@ struct GridwiseMoeGemmBlockScale { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - scatter_offsets(m0) = token_offset * problem.N; + scatter_offsets(m0) = + token_offset * problem.N * (IsInputGemm && IsSplitK ? 2 : 1); }); block_sync_lds(); @@ -1953,8 +1974,8 @@ struct GridwiseMoeGemmBlockScale const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens, problem.MPadded, - problem.N, - problem.NPadded, + problem.N * (IsInputGemm && IsSplitK ? 2 : 1), + problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1), problem.StrideC); const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( @@ -2125,8 +2146,10 @@ struct GridwiseMoeGemmBlockScale decltype(c_thread_buf) c_thread_buf_up; const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( - (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / - KPerBlock); + problem.KBatch == 1 + ? (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock + : problem.KBatch); // scale constexpr index_t ScaleSliceSizeM = MXdlPerWave; @@ -2202,7 +2225,7 @@ struct GridwiseMoeGemmBlockScale constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK); constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock); - if constexpr(IsInputGemm) + if constexpr(IsInputGemm && !IsSplitK) { const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; const auto b_grid_buf_up = make_dynamic_buffer( @@ -2352,7 +2375,7 @@ struct GridwiseMoeGemmBlockScale blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( make_tuple(m0, n0, n2 * N4 + n4)); constexpr auto cidx = Number{}; - if constexpr(IsInputGemm) // gu fusion, elementwise + if constexpr(IsInputGemm && !IsSplitK) // gu fusion, elementwise { if constexpr(ActivationOperation == Activation::silu_and_mul) { @@ -2619,7 +2642,8 @@ struct GridwiseMoeGemmBlockScale { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - scatter_offsets(m0) = token_offset * problem.N; + scatter_offsets(m0) = + token_offset * problem.N * (IsInputGemm && IsSplitK ? 2 : 1); }); block_sync_lds(); diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale_splitk.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale_splitk.hpp new file mode 100644 index 00000000000..9d9b8a62f59 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale_splitk.hpp @@ -0,0 +1,232 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
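Alongside the data pointers, the gridwise changes above now advance the A/B block-scale pointers per k-batch (`ascale_k_split_offset` / `bscale_k_split_offset`). A simplified illustration of that arithmetic for row-major A with packing factors ignored; the sizes match the example's defaults (K = 6144, KBatch = 6, ScaleBlockK = 128) and assume the new divisibility requirement `K % (KPerBlock * KBatch) == 0`:

```cpp
#include <cstdio>

int main()
{
    const int K = 6144, KBatch = 6, ScaleBlockK = 128; // example defaults
    const int KRead = K / KBatch;                      // K elements read per k-batch

    for(int k_id = 0; k_id < KBatch; ++k_id)
    {
        // Data offset into A for this k-batch (APackedSize omitted here).
        const int a_k_split_offset = k_id * KRead;
        // Matching offset into the per-128-element A scales, as in the hunk above:
        // integer_divide_floor(a_k_split_offset, ScaleBlockK).
        const int ascale_k_split_offset = a_k_split_offset / ScaleBlockK;
        std::printf("k_id=%d  A offset=%5d  A-scale offset=%2d\n",
                    k_id, a_k_split_offset, ascale_k_split_offset);
    }
    return 0;
}
```

With each k-batch covering a whole number of scale blocks (KRead = 1024 = 8 x 128 here), every partial GEMM reads only the scales belonging to its own K slice, and the split-K path can then combine the partial results via the AtomicAdd store mode selected above.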
+// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include + +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceMoeGemm1BlockScaleSplitK : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& sorted_token_ids, + const Tensor& expert_ids, + const Tensor& max_token_id, + const index_t sorted_tile_size, + const Tensor& a_t_k, + const Tensor& b_e_n_k, + Tensor& c_t_k_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : sorted_token_ids_{sorted_token_ids}, + expert_ids_{expert_ids}, + max_token_id_{max_token_id}, + sorted_tile_size_{sorted_tile_size}, + a_t_k_{a_t_k}, + b_e_n_k_{b_e_n_k}, + c_t_k_n_{c_t_k_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& sorted_token_ids_; + const Tensor& expert_ids_; + const Tensor& max_token_id_; + index_t sorted_tile_size_; + const Tensor& a_t_k_; + const Tensor& b_e_n_k_; + Tensor& c_t_k_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceMoeGemm1BlockScaleSplitK::Argument; + + float Run(const Argument& arg) + { + const int full_n = arg.c_t_k_n_.mDesc.GetLengths()[2]; + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = arg.a_t_k_.mDesc.GetLengths()[1]; + AccDataType v_acc{0}; + + ComputeTypeA v_a{0}; + ComputeTypeB v_b{0}; + + const int t = arg.sorted_token_ids_(m) & 0xffffff; + const int topk_id = (arg.sorted_token_ids_(m) & 0xff000000) >> 24; + const int e = arg.expert_ids_(m / arg.sorted_tile_size_); + const int token_cnt = arg.a_t_k_.mDesc.GetLengths()[0]; + if(t < token_cnt) + { + for(int k = 0; k < K; ++k) + { + if constexpr(is_same_v) + { + uint8_t i4x2 = arg.a_t_k_(t, k).data; + uint8_t i4 = 0; + if(k % 2 == 1) + i4 = (i4x2 >> 0) & 0xf; + else + i4 = (i4x2 >> 4) & 0xf; +#if CK_USE_PK4_LAYOUT_SHUFFLE + v_a = i4_to_f32_gfx9(i4); +#else + v_a = i4 - 8; +#endif + } + else + { + arg.a_element_op_(v_a, arg.a_t_k_(t, k)); + } + // same for B matrix + if constexpr(is_same_v) + { + uint8_t i4x2 = arg.b_e_n_k_(e, k, n).data; + uint8_t i4 = 0; + if(k % 2 == 1) + { + i4 = (i4x2 >> 0) & 0xf; + } + else + { + i4 = (i4x2 >> 4) & 0xf; + } +#if CK_USE_PK4_LAYOUT_SHUFFLE + v_b = i4_to_f32_gfx9(i4); +#else + v_b = i4 - 8; +#endif + } + else + { + arg.b_element_op_(v_b, arg.b_e_n_k_(e, k, n)); + } + + v_acc += + ck::type_convert(v_a) * ck::type_convert(v_b); + } + CDataType v_c{0}; + + arg.c_element_op_(v_c, v_acc); + arg.c_t_k_n_(t, topk_id, n) = v_c; + } + }; + + const ck::index_t max_token_id = arg.max_token_id_(0); + make_ParallelTensorFunctor(f_mk_kn_mn, max_token_id, full_n)( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& 
sorted_token_ids, + const Tensor& expert_ids, + const Tensor& max_token_id, + const index_t sorted_tile_size, + const Tensor& a_t_k, + const Tensor& b_e_n_k, + Tensor& c_t_k_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{sorted_token_ids, + expert_ids, + max_token_id, + sorted_tile_size, + a_t_k, + b_e_n_k, + c_t_k_n, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceMoeGemm1BlaockScaleSplitK" + << std::endl; + // clang-format on + + return str.str(); + } + + static float i4_to_f32_gfx9(uint8_t i4) + { + static std::unordered_map u = {{0b1000, -0.5000f}, + {0b1001, -0.4375f}, + {0b1010, -0.3750f}, + {0b1011, -0.3125f}, + {0b1100, -0.2500f}, + {0b1101, -0.1875f}, + {0b1110, -0.1250f}, + {0b1111, -0.0625f}, + {0b0, +0.0000f}, + {0b1, +0.0625f}, + {0b10, +0.1250f}, + {0b11, +0.1875f}, + {0b100, +0.2500f}, + {0b101, +0.3125f}, + {0b110, +0.3750f}, + {0b111, +0.4375f}}; + + return u[i4]; + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck
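Both the example's `sorted_token_ids` setup and the reference `Run` loop above use the same packed routing entry: the low 24 bits hold the token id (with the value `tokens` marking a padding slot) and the high 8 bits hold the top-k position, which is also how the gridwise kernel rebuilds its scatter offsets (`token_offset * problem.TopK + (fused_token >> 24)`). A minimal round-trip sketch of that encoding:

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

// Pack a routed token as [31:24] top-k slot | [23:0] token id.
static std::uint32_t pack_routed(std::uint32_t token, std::uint32_t topk_slot)
{
    return (token & 0xffffffu) | (topk_slot << 24);
}

int main()
{
    const std::uint32_t fused = pack_routed(17u, 1u);
    const std::uint32_t token = fused & 0xffffffu;           // -> 17
    const std::uint32_t slot  = (fused & 0xff000000u) >> 24; // -> 1
    assert(token == 17u && slot == 1u);
    std::printf("token=%u topk_slot=%u\n", token, slot);
    return 0;
}
```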