From bce6ec11cd231e9c1ee6245e67a8394df9e46c79 Mon Sep 17 00:00:00 2001
From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
Date: Thu, 29 Jan 2026 14:26:43 -0700
Subject: [PATCH] Optimize tensor descriptor functor template instantiation

Replace inline lambdas with named functor structs in transform_tensor_descriptor
to reduce template instantiation overhead and improve compile times.

Changes:
- Add three named functors in tensor_descriptor.hpp:
  - convert_visible_to_hidden_id: maps a visible dimension ID to its hidden ID
  - convert_visible_ids_to_hidden_ids: maps a sequence of visible IDs to hidden IDs
  - generate_arithmetic_sequence_from_scan: generates consecutive hidden dim ID ranges
- Add utility functions in sequence_helper.hpp and tuple_helper.hpp:
  - unpack_and_merge_sequences(): unpacks a tuple of sequences and merges them
  - generate_identity_sequences(): creates Tuple<Sequence<0>, Sequence<1>, ...>
- Update 14 call sites across threadwise transfer, wrapper, and device files to
  use generate_identity_sequences() instead of generate_tuple with lambdas
- Add unit tests:
  - unit_sequence_helper.cpp: tests for the new utility functions
  - unit_tensor_descriptor_functors.cpp: tests for the new functors

Co-Authored-By: Claude
---
 .../tensor_description/tensor_descriptor.hpp  |  76 ++--
 .../gpu/device/matrix_padder.hpp              |   7 +-
 .../threadwise_tensor_slice_transfer_v3r1.hpp |   3 +-
 ...ise_tensor_slice_transfer_v3r1_dequant.hpp |   9 +-
 ...wise_tensor_slice_transfer_v3r1_gather.hpp |   6 +-
 .../threadwise_tensor_slice_transfer_v3r2.hpp |   6 +-
 .../threadwise_tensor_slice_transfer_v7r2.hpp |   6 +-
 .../threadwise_tensor_slice_transfer_v7r3.hpp |   6 +-
 ...ise_tensor_slice_transfer_v7r3_scatter.hpp |   6 +-
 include/ck/utility/sequence_helper.hpp        |  18 +
 include/ck/utility/tuple_helper.hpp           |  21 +
 include/ck/wrapper/layout.hpp                 |   3 +-
 include/ck/wrapper/operations/gemm.hpp        |   3 +-
 include/ck/wrapper/tensor.hpp                 |   3 +-
 include/ck/wrapper/utils/layout_utils.hpp     |   6 +-
 include/ck/wrapper/utils/tensor_partition.hpp |   3 +-
 test/util/CMakeLists.txt                      |  10 +
 test/util/unit_sequence_helper.cpp            |  65 ++++
 test/util/unit_tensor_descriptor_functors.cpp | 365 ++++++++++++++++++
 19 files changed, 549 insertions(+), 73 deletions(-)
 create mode 100644 test/util/unit_sequence_helper.cpp
 create mode 100644 test/util/unit_tensor_descriptor_functors.cpp

diff --git a/include/ck/tensor_description/tensor_descriptor.hpp b/include/ck/tensor_description/tensor_descriptor.hpp
index 2437132d114..fa526690e48 100644
--- a/include/ck/tensor_description/tensor_descriptor.hpp
+++ b/include/ck/tensor_description/tensor_descriptor.hpp
@@ -36,11 +36,9 @@ struct TensorDescriptor
     __host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
     {
-        constexpr auto all_low_dim_ids = unpack(
-            [](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionIdss{});
+        constexpr auto all_low_dim_ids = unpack_and_merge_sequences(LowerDimensionIdss{});
 
-        constexpr auto all_up_dim_ids = unpack(
-            [](auto&&... 
xs) constexpr { return merge_sequences(xs...); }, UpperDimensionIdss{}); + constexpr auto all_up_dim_ids = unpack_and_merge_sequences(UpperDimensionIdss{}); constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids); @@ -311,6 +309,41 @@ struct lambda_get_up_dim_num } }; +// Maps a visible dimension ID to its corresponding hidden dimension ID +template +struct convert_visible_to_hidden_id +{ + __host__ __device__ constexpr auto operator()(index_t low_dim_visible_id) const + { + return OldTensorDescriptor::GetVisibleDimensionIds().At(low_dim_visible_id); + } +}; + +// Maps a sequence of visible IDs to their corresponding hidden IDs +template +struct convert_visible_ids_to_hidden_ids +{ + template + __host__ __device__ constexpr auto operator()(LowDimVisibleIds low_dim_visible_ids) const + { + return transform_sequences(convert_visible_to_hidden_id{}, + low_dim_visible_ids); + } +}; + +// Generates consecutive ranges of hidden dimension IDs for each transform's upper dimensions +template +struct generate_arithmetic_sequence_from_scan +{ + template + __host__ __device__ constexpr auto operator()(I) const + { + constexpr index_t start = OldHiddenDimNumber + UpDimNumbersScan{}.At(I{}); + constexpr index_t end = OldHiddenDimNumber + UpDimNumbersScan{}.At(I{} + Number<1>{}); + return typename arithmetic_sequence_gen::type{}; + } +}; + template ::value && is_valid_sequence_map::value, @@ -341,17 +374,9 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, // lower dimension's hidden idss // convert lower dimension visible idss (tuple of sequences) to hidden idss (tuple of // sequences) - constexpr auto low_dim_hidden_idss = transform_tuples( - // convert lower dimension visible ids (a sequence) to hidden ids (a sequence) - [](auto low_dim_visible_ids) constexpr { - return transform_sequences( - // convert lower dimension visible id to hidden id - [](auto low_dim_visible_id) constexpr { - return OldTensorDescriptor::GetVisibleDimensionIds()[low_dim_visible_id]; - }, - low_dim_visible_ids); - }, - NewLowerDimensionOldVisibleIdss{}); + constexpr auto low_dim_hidden_idss = + transform_tuples(convert_visible_ids_to_hidden_ids{}, + NewLowerDimensionOldVisibleIdss{}); constexpr index_t num_new_transform = NewTransforms::Size(); @@ -364,22 +389,17 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, constexpr auto up_dim_numbers_scan = merge_sequences( Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus{}, Number<0>{})); + using UpDimNumbersScanType = remove_cvref_t; constexpr auto up_dim_hidden_idss = generate_tuple( - [old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr { - return - typename arithmetic_sequence_gen::type{}; - }, + generate_arithmetic_sequence_from_scan{}, Number{}); // new visible dimension's hidden ids constexpr auto unordered_new_visible_dim_hidden_ids = - unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss); + unpack_and_merge_sequences(up_dim_hidden_idss); constexpr auto new_visible_dim_unordered2ordered = - unpack([](auto... 
xs) constexpr { return merge_sequences(xs...); }, - NewUpperDimensionNewVisibleIdss{}); + unpack_and_merge_sequences(NewUpperDimensionNewVisibleIdss{}); constexpr auto new_visible_dim_hidden_ids = unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered); diff --git a/include/ck/tensor_operation/gpu/device/matrix_padder.hpp b/include/ck/tensor_operation/gpu/device/matrix_padder.hpp index 6ead8a955cc..270f1061142 100644 --- a/include/ck/tensor_operation/gpu/device/matrix_padder.hpp +++ b/include/ck/tensor_operation/gpu/device/matrix_padder.hpp @@ -43,11 +43,8 @@ PadTensorDescriptor(const TensorDesc& desc, const TileLengths& tile_lengths, DoP }, Number{}); - // lower dimension Id - const auto lower_dimss = - generate_tuple([&](auto idim) { return Sequence{}; }, Number{}); - - // upper dimension Id + // lower/upper dimension Ids + const auto lower_dimss = generate_identity_sequences(); const auto upper_dimss = lower_dimss; return transform_tensor_descriptor(desc, transforms, lower_dimss, upper_dimss); diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index 27c22f32b5a..7b9d136068f 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -739,8 +739,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 }, Number{}); - constexpr auto up_dim_idss = - generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + constexpr auto up_dim_idss = generate_identity_sequences(); return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); } diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp index 6eb4b21e216..2ddb34671a6 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp @@ -894,8 +894,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant }, Number{}); - constexpr auto up_dim_idss = - generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + constexpr auto up_dim_idss = generate_identity_sequences(); return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); } @@ -944,8 +943,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant }, Number{}); - constexpr auto up_dim_idss = - generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + constexpr auto up_dim_idss = generate_identity_sequences(); return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); } @@ -993,8 +991,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant }, Number{}); - constexpr auto up_dim_idss = - generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + constexpr auto up_dim_idss = generate_identity_sequences(); return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); } diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp index 2077eeebd79..e080d7eeac7 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp +++ 
b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp @@ -833,8 +833,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather }, Number{}); - constexpr auto up_dim_idss = - generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + constexpr auto up_dim_idss = generate_identity_sequences(); return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); } @@ -892,8 +891,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather }, Number{}); - constexpr auto up_dim_idss = - generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + constexpr auto up_dim_idss = generate_identity_sequences(); return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); } diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp index 56ae553f2f0..3c7291cca31 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp @@ -692,8 +692,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2 }, Number{}); - constexpr auto up_dim_idss = - generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + constexpr auto up_dim_idss = generate_identity_sequences(); return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); } @@ -744,8 +743,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2 }, Number{}); - constexpr auto up_dim_idss = - generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + constexpr auto up_dim_idss = generate_identity_sequences(); return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); } diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp index 87cecc75740..6326f6cbda2 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp @@ -514,8 +514,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2 }, Number{}); - constexpr auto up_dim_idss = - generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + constexpr auto up_dim_idss = generate_identity_sequences(); return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); } @@ -563,8 +562,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2 }, Number{}); - constexpr auto up_dim_idss = - generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + constexpr auto up_dim_idss = generate_identity_sequences(); return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); } diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp index cbca8629c38..9ad9c940f4f 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp @@ -656,8 +656,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3 }, Number{}); - constexpr auto up_dim_idss = - generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + constexpr auto up_dim_idss = generate_identity_sequences(); return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); } @@ -706,8 +705,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3 }, 
Number{}); - constexpr auto up_dim_idss = - generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + constexpr auto up_dim_idss = generate_identity_sequences(); return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); } diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp index fe975f4e36b..732922c1576 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp @@ -548,8 +548,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter }, Number{}); - constexpr auto up_dim_idss = - generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + constexpr auto up_dim_idss = generate_identity_sequences(); return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); } @@ -598,8 +597,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter }, Number{}); - constexpr auto up_dim_idss = - generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + constexpr auto up_dim_idss = generate_identity_sequences(); return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); } diff --git a/include/ck/utility/sequence_helper.hpp b/include/ck/utility/sequence_helper.hpp index 35a6a486324..6f096aef74c 100644 --- a/include/ck/utility/sequence_helper.hpp +++ b/include/ck/utility/sequence_helper.hpp @@ -3,6 +3,7 @@ #pragma once +#include "ck/utility/functional4.hpp" #include "ck/utility/tuple.hpp" namespace ck { @@ -34,4 +35,21 @@ __host__ __device__ constexpr auto to_sequence(Tuple...>) return Sequence{}; } +// Functor wrapper for merge_sequences to enable reuse across call sites +struct merge_sequences_functor +{ + template + __host__ __device__ constexpr auto operator()(Seqs... 
seqs) const + { + return merge_sequences(seqs...); + } +}; + +// Unpacks tuple of sequences and merges them into a single sequence +template +__host__ __device__ constexpr auto unpack_and_merge_sequences(TupleOfSequences tuple_of_sequences) +{ + return unpack(merge_sequences_functor{}, tuple_of_sequences); +} + } // namespace ck diff --git a/include/ck/utility/tuple_helper.hpp b/include/ck/utility/tuple_helper.hpp index 52ca5e91266..c71ef8427ad 100644 --- a/include/ck/utility/tuple_helper.hpp +++ b/include/ck/utility/tuple_helper.hpp @@ -37,6 +37,27 @@ __host__ __device__ constexpr auto generate_tie(F&& f, Number) typename arithmetic_sequence_gen<0, N, 1>::type{}); } +// Creates Tuple, Sequence<1>, ..., Sequence> +namespace detail { +template +__host__ __device__ constexpr auto make_identity_sequences_impl(Sequence) +{ + return make_tuple(Sequence{}...); +} +} // namespace detail + +template +__host__ __device__ constexpr auto generate_identity_sequences() +{ + return detail::make_identity_sequences_impl(make_index_sequence{}); +} + +template +__host__ __device__ constexpr auto generate_identity_sequences(Number) +{ + return generate_identity_sequences(); +} + // tx and ty are tuple of references, return type of will tuple of referennce (not rvalue) template __host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple& tx, diff --git a/include/ck/wrapper/layout.hpp b/include/ck/wrapper/layout.hpp index 334d5851db0..2f81a44f399 100644 --- a/include/ck/wrapper/layout.hpp +++ b/include/ck/wrapper/layout.hpp @@ -242,8 +242,7 @@ struct Layout const auto lower_dims = generate_tuple([&](auto i) { return GenerateLowerDim>(shape); }, Number::Size()>{}); - const auto upper_dims = generate_tuple([&](auto i) { return Sequence{}; }, - Number::Size()>{}); + const auto upper_dims = generate_identity_sequences::Size()>(); return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims); } diff --git a/include/ck/wrapper/operations/gemm.hpp b/include/ck/wrapper/operations/gemm.hpp index d328ac7d42f..46142bd1de1 100644 --- a/include/ck/wrapper/operations/gemm.hpp +++ b/include/ck/wrapper/operations/gemm.hpp @@ -259,8 +259,7 @@ make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor) const auto partition_desc = BlockwiseGemmXdlops::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2( layout(c_local_tile_tensor).GetUnrolledDescriptor()); - const auto lower_upper_dims = - generate_tuple([&](auto i) { return Sequence{}; }, Number<8>{}); + const auto lower_upper_dims = generate_identity_sequences<8>(); auto sliced_desc = transform_tensor_descriptor( partition_desc, diff --git a/include/ck/wrapper/tensor.hpp b/include/ck/wrapper/tensor.hpp index 9f8278a3578..120f0c694db 100644 --- a/include/ck/wrapper/tensor.hpp +++ b/include/ck/wrapper/tensor.hpp @@ -187,8 +187,7 @@ __host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple& const auto transforms = GenerateSliceTransforms(idx, shape); using TransformsTupleType = decltype(transforms); - const auto lower_dims = - generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + const auto lower_dims = generate_identity_sequences(); const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){}; return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims); } diff --git a/include/ck/wrapper/utils/layout_utils.hpp b/include/ck/wrapper/utils/layout_utils.hpp index 8dd111b8721..e9686de6e79 100644 --- a/include/ck/wrapper/utils/layout_utils.hpp +++ 
b/include/ck/wrapper/utils/layout_utils.hpp @@ -186,8 +186,7 @@ __host__ __device__ constexpr auto get(const Layout& layout }, Number{}); - const auto lower_dims = - generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + const auto lower_dims = generate_identity_sequences(); const auto upper_dims = generate_tuple( [&](auto i) { if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims) @@ -492,8 +491,7 @@ __host__ __device__ constexpr auto unmerge(const Layout& la }, Number{}); - constexpr auto lower_dims = - generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + constexpr auto lower_dims = generate_identity_sequences(); constexpr auto upper_dims = generate_tuple( [&](auto i) { if constexpr(is_detected>::value) diff --git a/include/ck/wrapper/utils/tensor_partition.hpp b/include/ck/wrapper/utils/tensor_partition.hpp index 5099f35cdab..34986c270bc 100644 --- a/include/ck/wrapper/utils/tensor_partition.hpp +++ b/include/ck/wrapper/utils/tensor_partition.hpp @@ -293,8 +293,7 @@ make_local_partition(TensorType& tensor, }, Number::Size()>{}); const auto lower_upper_dims = - generate_tuple([&](auto i) { return Sequence{}; }, - Number::Size()>{}); + generate_identity_sequences::Size()>(); auto sliced_desc = transform_tensor_descriptor(unrolled_desc, transforms, lower_upper_dims, lower_upper_dims); // Create layout diff --git a/test/util/CMakeLists.txt b/test/util/CMakeLists.txt index bf0a444f18b..c4d91326e20 100644 --- a/test/util/CMakeLists.txt +++ b/test/util/CMakeLists.txt @@ -5,3 +5,13 @@ add_gtest_executable(unit_sequence unit_sequence.cpp) if(result EQUAL 0) target_link_libraries(unit_sequence PRIVATE utility) endif() + +add_gtest_executable(unit_sequence_helper unit_sequence_helper.cpp) +if(result EQUAL 0) + target_link_libraries(unit_sequence_helper PRIVATE utility) +endif() + +add_gtest_executable(unit_tensor_descriptor_functors unit_tensor_descriptor_functors.cpp) +if(result EQUAL 0) + target_link_libraries(unit_tensor_descriptor_functors PRIVATE utility) +endif() diff --git a/test/util/unit_sequence_helper.cpp b/test/util/unit_sequence_helper.cpp new file mode 100644 index 00000000000..4f8740f7997 --- /dev/null +++ b/test/util/unit_sequence_helper.cpp @@ -0,0 +1,65 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include "ck/utility/sequence_helper.hpp" +#include "ck/utility/tuple_helper.hpp" + +using namespace ck; + +// Tests for generate_identity_sequences (PR #3588) +TEST(GenerateIdentitySequences, Size5) +{ + auto result = generate_identity_sequences(Number<5>{}); + auto expected = + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}); + EXPECT_TRUE((is_same::value)); +} + +TEST(GenerateIdentitySequences, Size1) +{ + auto result = generate_identity_sequences(Number<1>{}); + auto expected = make_tuple(Sequence<0>{}); + EXPECT_TRUE((is_same::value)); +} + +TEST(GenerateIdentitySequences, Size0) +{ + auto result = generate_identity_sequences(Number<0>{}); + auto expected = make_tuple(); + EXPECT_TRUE((is_same::value)); +} + +TEST(GenerateIdentitySequences, WithNumber) +{ + constexpr auto result = generate_identity_sequences(Number<3>{}); + EXPECT_EQ(result.Size(), 3); + EXPECT_TRUE((is_same{})), const Sequence<0>&>::value)); + EXPECT_TRUE((is_same{})), const Sequence<1>&>::value)); + EXPECT_TRUE((is_same{})), const Sequence<2>&>::value)); +} + +// Tests for unpack_and_merge_sequences (PR #3589) +TEST(UnpackAndMergeSequences, MergeMultipleSequences) +{ + auto input = make_tuple(Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5, 6>{}); + auto result = unpack_and_merge_sequences(input); + auto expected = Sequence<1, 2, 3, 4, 5, 6>{}; + EXPECT_TRUE((is_same::value)); +} + +TEST(UnpackAndMergeSequences, SingleSequence) +{ + auto input = make_tuple(Sequence<10, 20, 30>{}); + auto result = unpack_and_merge_sequences(input); + auto expected = Sequence<10, 20, 30>{}; + EXPECT_TRUE((is_same::value)); +} + +TEST(UnpackAndMergeSequences, TwoSequences) +{ + auto input = make_tuple(Sequence<100>{}, Sequence<200, 300>{}); + auto result = unpack_and_merge_sequences(input); + auto expected = Sequence<100, 200, 300>{}; + EXPECT_TRUE((is_same::value)); +} diff --git a/test/util/unit_tensor_descriptor_functors.cpp b/test/util/unit_tensor_descriptor_functors.cpp new file mode 100644 index 00000000000..3e2cbcb3577 --- /dev/null +++ b/test/util/unit_tensor_descriptor_functors.cpp @@ -0,0 +1,365 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" + +using namespace ck; + +// ============================================================================= +// Tests for convert_visible_to_hidden_id functor +// ============================================================================= + +TEST(ConvertVisibleToHiddenId, SimplePackedDescriptor3D) +{ + // For a 3D packed descriptor created via make_naive_tensor_descriptor_packed, + // the visible dimension IDs are Sequence<1, 2, 3> (hidden dim 0 is the element space) + constexpr auto desc = + make_naive_tensor_descriptor_packed(make_tuple(Number<4>{}, Number<8>{}, Number<16>{})); + + using DescType = decltype(desc); + + // Verify the visible dimension IDs for a packed descriptor + constexpr auto visible_ids = DescType::GetVisibleDimensionIds(); + using ExpectedVisibleIds = Sequence<1, 2, 3>; + EXPECT_TRUE((is_same::value)); + + // Test the functor + constexpr auto functor = convert_visible_to_hidden_id{}; + + // Visible ID 0 -> Hidden ID 1 + constexpr auto hidden_0 = functor(Number<0>{}); + EXPECT_EQ(hidden_0, 1); + + // Visible ID 1 -> Hidden ID 2 + constexpr auto hidden_1 = functor(Number<1>{}); + EXPECT_EQ(hidden_1, 2); + + // Visible ID 2 -> Hidden ID 3 + constexpr auto hidden_2 = functor(Number<2>{}); + EXPECT_EQ(hidden_2, 3); +} + +TEST(ConvertVisibleToHiddenId, SimplePackedDescriptor2D) +{ + constexpr auto desc = + make_naive_tensor_descriptor_packed(make_tuple(Number<32>{}, Number<64>{})); + + using DescType = decltype(desc); + + constexpr auto functor = convert_visible_to_hidden_id{}; + + // For 2D packed: visible IDs are Sequence<1, 2> + constexpr auto hidden_0 = functor(Number<0>{}); + constexpr auto hidden_1 = functor(Number<1>{}); + + EXPECT_EQ(hidden_0, 1); + EXPECT_EQ(hidden_1, 2); +} + +TEST(ConvertVisibleToHiddenId, SimplePackedDescriptor1D) +{ + constexpr auto desc = make_naive_tensor_descriptor_packed(make_tuple(Number<128>{})); + + using DescType = decltype(desc); + + constexpr auto functor = convert_visible_to_hidden_id{}; + + // For 1D packed: visible IDs are Sequence<1> + constexpr auto hidden_0 = functor(Number<0>{}); + EXPECT_EQ(hidden_0, 1); +} + +TEST(ConvertVisibleToHiddenId, SimplePackedDescriptor4D) +{ + constexpr auto desc = make_naive_tensor_descriptor_packed( + make_tuple(Number<2>{}, Number<4>{}, Number<8>{}, Number<16>{})); + + using DescType = decltype(desc); + + constexpr auto functor = convert_visible_to_hidden_id{}; + + // For 4D packed: visible IDs are Sequence<1, 2, 3, 4> + EXPECT_EQ(functor(Number<0>{}), 1); + EXPECT_EQ(functor(Number<1>{}), 2); + EXPECT_EQ(functor(Number<2>{}), 3); + EXPECT_EQ(functor(Number<3>{}), 4); +} + +// ============================================================================= +// Tests for convert_visible_ids_to_hidden_ids functor +// ============================================================================= + +TEST(ConvertVisibleIdsToHiddenIds, SingleElement) +{ + constexpr auto desc = + make_naive_tensor_descriptor_packed(make_tuple(Number<4>{}, Number<8>{}, Number<16>{})); + + using DescType = decltype(desc); + + constexpr auto functor = convert_visible_ids_to_hidden_ids{}; + + // Convert single visible ID + constexpr auto result = functor(Sequence<0>{}); + using ExpectedResult = Sequence<1>; + EXPECT_TRUE((is_same::value)); +} + +TEST(ConvertVisibleIdsToHiddenIds, MultipleElements) +{ + constexpr auto desc = + 
make_naive_tensor_descriptor_packed(make_tuple(Number<4>{}, Number<8>{}, Number<16>{})); + + using DescType = decltype(desc); + + constexpr auto functor = convert_visible_ids_to_hidden_ids{}; + + // Convert multiple visible IDs + constexpr auto result = functor(Sequence<0, 2>{}); + using ExpectedResult = Sequence<1, 3>; + EXPECT_TRUE((is_same::value)); +} + +TEST(ConvertVisibleIdsToHiddenIds, AllElements) +{ + constexpr auto desc = + make_naive_tensor_descriptor_packed(make_tuple(Number<4>{}, Number<8>{}, Number<16>{})); + + using DescType = decltype(desc); + + constexpr auto functor = convert_visible_ids_to_hidden_ids{}; + + // Convert all visible IDs + constexpr auto result = functor(Sequence<0, 1, 2>{}); + using ExpectedResult = Sequence<1, 2, 3>; + EXPECT_TRUE((is_same::value)); +} + +TEST(ConvertVisibleIdsToHiddenIds, ReversedOrder) +{ + constexpr auto desc = + make_naive_tensor_descriptor_packed(make_tuple(Number<4>{}, Number<8>{}, Number<16>{})); + + using DescType = decltype(desc); + + constexpr auto functor = convert_visible_ids_to_hidden_ids{}; + + // Convert visible IDs in reverse order + constexpr auto result = functor(Sequence<2, 1, 0>{}); + using ExpectedResult = Sequence<3, 2, 1>; + EXPECT_TRUE((is_same::value)); +} + +TEST(ConvertVisibleIdsToHiddenIds, EmptySequence) +{ + constexpr auto desc = + make_naive_tensor_descriptor_packed(make_tuple(Number<4>{}, Number<8>{}, Number<16>{})); + + using DescType = decltype(desc); + + constexpr auto functor = convert_visible_ids_to_hidden_ids{}; + + // Convert empty sequence + constexpr auto result = functor(Sequence<>{}); + using ExpectedResult = Sequence<>; + EXPECT_TRUE((is_same::value)); +} + +TEST(ConvertVisibleIdsToHiddenIds, HighDimensional) +{ + constexpr auto desc = make_naive_tensor_descriptor_packed( + make_tuple(Number<2>{}, Number<4>{}, Number<8>{}, Number<16>{}, Number<32>{})); + + using DescType = decltype(desc); + + constexpr auto functor = convert_visible_ids_to_hidden_ids{}; + + // Convert subset of visible IDs + constexpr auto result = functor(Sequence<1, 3>{}); + using ExpectedResult = Sequence<2, 4>; + EXPECT_TRUE((is_same::value)); +} + +// ============================================================================= +// Tests for generate_arithmetic_sequence_from_scan functor +// ============================================================================= + +TEST(GenerateArithmeticSequenceFromScan, SingleRange) +{ + // Scan sequence: <0, 3> means: + // - Index 0: range from (base + 0) to (base + 3) = Sequence + using ScanSeq = Sequence<0, 3>; + constexpr index_t base = 5; + + constexpr auto functor = generate_arithmetic_sequence_from_scan{}; + + constexpr auto result = functor(Number<0>{}); + using ExpectedResult = Sequence<5, 6, 7>; + EXPECT_TRUE((is_same::value)); +} + +TEST(GenerateArithmeticSequenceFromScan, MultipleRanges) +{ + // Scan sequence: <0, 2, 5> means: + // - Index 0: range from (base + 0) to (base + 2) = 2 elements + // - Index 1: range from (base + 2) to (base + 5) = 3 elements + using ScanSeq = Sequence<0, 2, 5>; + constexpr index_t base = 3; + + constexpr auto functor = generate_arithmetic_sequence_from_scan{}; + + // First range: base + [0, 2) = Sequence<3, 4> + constexpr auto result_0 = functor(Number<0>{}); + using ExpectedResult0 = Sequence<3, 4>; + EXPECT_TRUE((is_same::value)); + + // Second range: base + [2, 5) = Sequence<5, 6, 7> + constexpr auto result_1 = functor(Number<1>{}); + using ExpectedResult1 = Sequence<5, 6, 7>; + EXPECT_TRUE((is_same::value)); +} + 
+TEST(GenerateArithmeticSequenceFromScan, SingleElementRanges) +{ + // Scan sequence with single-element ranges: <0, 1, 2, 3> + using ScanSeq = Sequence<0, 1, 2, 3>; + constexpr index_t base = 10; + + constexpr auto functor = generate_arithmetic_sequence_from_scan{}; + + // Each range contains exactly one element + constexpr auto result_0 = functor(Number<0>{}); + using ExpectedResult0 = Sequence<10>; + EXPECT_TRUE((is_same::value)); + + constexpr auto result_1 = functor(Number<1>{}); + using ExpectedResult1 = Sequence<11>; + EXPECT_TRUE((is_same::value)); + + constexpr auto result_2 = functor(Number<2>{}); + using ExpectedResult2 = Sequence<12>; + EXPECT_TRUE((is_same::value)); +} + +TEST(GenerateArithmeticSequenceFromScan, ZeroBase) +{ + // Test with base = 0 + using ScanSeq = Sequence<0, 2, 4>; + constexpr index_t base = 0; + + constexpr auto functor = generate_arithmetic_sequence_from_scan{}; + + constexpr auto result_0 = functor(Number<0>{}); + using ExpectedResult0 = Sequence<0, 1>; + EXPECT_TRUE((is_same::value)); + + constexpr auto result_1 = functor(Number<1>{}); + using ExpectedResult1 = Sequence<2, 3>; + EXPECT_TRUE((is_same::value)); +} + +TEST(GenerateArithmeticSequenceFromScan, VariableSizeRanges) +{ + // Scan sequence with variable-size ranges: <0, 1, 4, 6, 10> + using ScanSeq = Sequence<0, 1, 4, 6, 10>; + constexpr index_t base = 2; + + constexpr auto functor = generate_arithmetic_sequence_from_scan{}; + + // Range 0: [2+0, 2+1) = Sequence<2> + constexpr auto result_0 = functor(Number<0>{}); + using ExpectedResult0 = Sequence<2>; + EXPECT_TRUE((is_same::value)); + + // Range 1: [2+1, 2+4) = Sequence<3, 4, 5> + constexpr auto result_1 = functor(Number<1>{}); + using ExpectedResult1 = Sequence<3, 4, 5>; + EXPECT_TRUE((is_same::value)); + + // Range 2: [2+4, 2+6) = Sequence<6, 7> + constexpr auto result_2 = functor(Number<2>{}); + using ExpectedResult2 = Sequence<6, 7>; + EXPECT_TRUE((is_same::value)); + + // Range 3: [2+6, 2+10) = Sequence<8, 9, 10, 11> + constexpr auto result_3 = functor(Number<3>{}); + using ExpectedResult3 = Sequence<8, 9, 10, 11>; + EXPECT_TRUE((is_same::value)); +} + +TEST(GenerateArithmeticSequenceFromScan, LargeBase) +{ + // Test with a larger base value + using ScanSeq = Sequence<0, 3>; + constexpr index_t base = 100; + + constexpr auto functor = generate_arithmetic_sequence_from_scan{}; + + constexpr auto result = functor(Number<0>{}); + using ExpectedResult = Sequence<100, 101, 102>; + EXPECT_TRUE((is_same::value)); +} + +// ============================================================================= +// Integration tests - verify functors work together in transform_tensor_descriptor context +// ============================================================================= + +TEST(TensorDescriptorFunctorsIntegration, TransformPreservesMapping) +{ + // Create a simple packed 2D descriptor + constexpr auto desc = make_naive_tensor_descriptor_packed(make_tuple(Number<4>{}, Number<8>{})); + + using DescType = decltype(desc); + + // Verify the descriptor structure + EXPECT_EQ(DescType::GetNumOfVisibleDimension(), 2); + EXPECT_EQ(DescType::GetNumOfHiddenDimension(), 3); // 1 (element space) + 2 (visible dims) + + // Test that convert_visible_to_hidden_id preserves the expected mapping + constexpr auto functor = convert_visible_to_hidden_id{}; + + // The hidden IDs should be consecutive starting from 1 + constexpr auto hidden_0 = functor(Number<0>{}); + constexpr auto hidden_1 = functor(Number<1>{}); + + EXPECT_EQ(hidden_0, 1); + EXPECT_EQ(hidden_1, 2); + 
+    EXPECT_EQ(hidden_1 - hidden_0, 1); // Should be consecutive
+}
+
+TEST(TensorDescriptorFunctorsIntegration, ConvertIdsMatchDirectAccess)
+{
+    // Verify that the functor produces the same result as direct access
+    constexpr auto desc =
+        make_naive_tensor_descriptor_packed(make_tuple(Number<2>{}, Number<4>{}, Number<8>{}));
+
+    using DescType = decltype(desc);
+
+    constexpr auto visible_ids = DescType::GetVisibleDimensionIds();
+    constexpr auto functor = convert_visible_to_hidden_id<DescType>{};
+
+    // Functor result should match direct sequence access
+    EXPECT_EQ(functor(Number<0>{}), visible_ids.At(Number<0>{}));
+    EXPECT_EQ(functor(Number<1>{}), visible_ids.At(Number<1>{}));
+    EXPECT_EQ(functor(Number<2>{}), visible_ids.At(Number<2>{}));
+}
+
+TEST(TensorDescriptorFunctorsIntegration, BatchConvertMatchesSingle)
+{
+    // Verify that batch conversion produces the same result as individual conversions
+    constexpr auto desc =
+        make_naive_tensor_descriptor_packed(make_tuple(Number<4>{}, Number<8>{}, Number<16>{}));
+
+    using DescType = decltype(desc);
+
+    constexpr auto single_functor = convert_visible_to_hidden_id<DescType>{};
+    constexpr auto batch_functor = convert_visible_ids_to_hidden_ids<DescType>{};
+
+    constexpr auto batch_result = batch_functor(Sequence<0, 1, 2>{});
+
+    // Each element should match the single conversion
+    EXPECT_EQ(batch_result.At(Number<0>{}), single_functor(Number<0>{}));
+    EXPECT_EQ(batch_result.At(Number<1>{}), single_functor(Number<1>{}));
+    EXPECT_EQ(batch_result.At(Number<2>{}), single_functor(Number<2>{}));
+}
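
Usage sketch (illustrative only; not part of the patch). The snippet below contrasts the
lambda-based call-site pattern removed by this change with the new helpers, assuming the
shapes described in the commit message: generate_identity_sequences<N>() yields
Tuple<Sequence<0>, ..., Sequence<N-1>>, and unpack_and_merge_sequences() flattens a tuple
of Sequences into one Sequence. The namespace, function names, and concrete values below
are arbitrary and chosen only for illustration.

    #include "ck/utility/sequence_helper.hpp"
    #include "ck/utility/tuple_helper.hpp"
    #include "ck/utility/type.hpp"

    namespace usage_sketch {
    using namespace ck;

    // Old call-site pattern: every call site instantiates its own closure type.
    constexpr auto old_identity_dims()
    {
        return generate_tuple([](auto i) { return Sequence<i.value>{}; }, Number<3>{});
    }

    // New pattern: one shared helper, no per-call-site lambda.
    constexpr auto new_identity_dims() { return generate_identity_sequences<3>(); }

    // Both forms are expected to produce Tuple<Sequence<0>, Sequence<1>, Sequence<2>>.
    static_assert(is_same<decltype(old_identity_dims()), decltype(new_identity_dims())>::value,
                  "identity-sequence helper should match the lambda-based pattern");

    // unpack_and_merge_sequences() replaces
    // unpack([](auto&&... xs) { return merge_sequences(xs...); }, tuple_of_sequences).
    static_assert(is_same<decltype(unpack_and_merge_sequences(
                              make_tuple(Sequence<1, 2>{}, Sequence<3, 4>{}))),
                          Sequence<1, 2, 3, 4>>::value,
                  "merging Sequence<1, 2> and Sequence<3, 4> should yield Sequence<1, 2, 3, 4>");
    } // namespace usage_sketch

Because the named functors and helpers are ordinary templates rather than per-call-site
closures, identical instantiations can be shared across the 14 updated call sites, which is
where the compile-time saving described in the subject line is expected to come from.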