Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 48 additions & 28 deletions include/ck/tensor_description/tensor_descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,9 @@ struct TensorDescriptor

__host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
{
constexpr auto all_low_dim_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionIdss{});
constexpr auto all_low_dim_ids = unpack_and_merge_sequences(LowerDimensionIdss{});

constexpr auto all_up_dim_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionIdss{});
constexpr auto all_up_dim_ids = unpack_and_merge_sequences(UpperDimensionIdss{});

constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);

Expand Down Expand Up @@ -319,6 +317,41 @@ struct lambda_get_up_dim_num
}
};

// Maps a visible dimension ID to its corresponding hidden dimension ID
template <typename OldTensorDescriptor>
struct convert_visible_to_hidden_id
{
__host__ __device__ constexpr auto operator()(index_t low_dim_visible_id) const
{
return OldTensorDescriptor::GetVisibleDimensionIds().At(low_dim_visible_id);
}
};

// Maps a sequence of visible IDs to their corresponding hidden IDs
template <typename OldTensorDescriptor>
struct convert_visible_ids_to_hidden_ids
{
template <typename LowDimVisibleIds>
__host__ __device__ constexpr auto operator()(LowDimVisibleIds low_dim_visible_ids) const
{
return transform_sequences(convert_visible_to_hidden_id<OldTensorDescriptor>{},
low_dim_visible_ids);
}
};

// Generates consecutive ranges of hidden dimension IDs for each transform's upper dimensions
template <index_t OldHiddenDimNumber, typename UpDimNumbersScan>
struct generate_arithmetic_sequence_from_scan
{
template <typename I>
__host__ __device__ constexpr auto operator()(I) const
{
constexpr index_t start = OldHiddenDimNumber + UpDimNumbersScan{}.At(I{});
constexpr index_t end = OldHiddenDimNumber + UpDimNumbersScan{}.At(I{} + Number<1>{});
return typename arithmetic_sequence_gen<start, end, 1>::type{};
}
};

template <typename OldTensorDescriptor,
typename NewTransforms,
typename NewLowerDimensionOldVisibleIdss,
Expand All @@ -335,11 +368,11 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
NewTransforms::Size() == NewUpperDimensionNewVisibleIdss::Size(),
"wrong! inconsitent number of transform");

constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
NewLowerDimensionOldVisibleIdss{});
constexpr auto all_old_top_ids =
unpack_and_merge_sequences(NewLowerDimensionOldVisibleIdss{});

constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
NewUpperDimensionNewVisibleIdss{});
constexpr auto all_new_top_ids =
unpack_and_merge_sequences(NewUpperDimensionNewVisibleIdss{});

static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
is_valid_sequence_map<decltype(all_new_top_ids)>::value,
Expand All @@ -349,17 +382,9 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
// lower dimension's hidden idss
// convert lower dimension visible idss (tuple of sequences) to hidden idss (tuple of
// sequences)
constexpr auto low_dim_hidden_idss = transform_tuples(
// convert lower dimension visible ids (a sequence) to hidden ids (a sequence)
[](auto low_dim_visible_ids) constexpr {
return transform_sequences(
// convert lower dimension visible id to hidden id
[](auto low_dim_visible_id) constexpr {
return OldTensorDescriptor::GetVisibleDimensionIds()[low_dim_visible_id];
},
low_dim_visible_ids);
},
NewLowerDimensionOldVisibleIdss{});
constexpr auto low_dim_hidden_idss =
transform_tuples(convert_visible_ids_to_hidden_ids<OldTensorDescriptor>{},
NewLowerDimensionOldVisibleIdss{});

constexpr index_t num_new_transform = NewTransforms::Size();

Expand All @@ -372,22 +397,17 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
constexpr auto up_dim_numbers_scan = merge_sequences(
Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus<index_t>{}, Number<0>{}));

using UpDimNumbersScanType = remove_cvref_t<decltype(up_dim_numbers_scan)>;
constexpr auto up_dim_hidden_idss = generate_tuple(
[old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr {
return
typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
old_hidden_dim_number + up_dim_numbers_scan[i + 1],
1>::type{};
},
generate_arithmetic_sequence_from_scan<old_hidden_dim_number, UpDimNumbersScanType>{},
Number<num_new_transform>{});

// new visible dimension's hidden ids
constexpr auto unordered_new_visible_dim_hidden_ids =
unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
unpack_and_merge_sequences(up_dim_hidden_idss);

constexpr auto new_visible_dim_unordered2ordered =
unpack([](auto... xs) constexpr { return merge_sequences(xs...); },
NewUpperDimensionNewVisibleIdss{});
unpack_and_merge_sequences(NewUpperDimensionNewVisibleIdss{});

constexpr auto new_visible_dim_hidden_ids =
unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered);
Expand Down
7 changes: 2 additions & 5 deletions include/ck/tensor_operation/gpu/device/matrix_padder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,8 @@ PadTensorDescriptor(const TensorDesc& desc, const TileLengths& tile_lengths, DoP
},
Number<num_dim>{});

// lower dimension Id
const auto lower_dimss =
generate_tuple([&](auto idim) { return Sequence<idim.value>{}; }, Number<num_dim>{});

// upper dimension Id
// lower/upper dimension Ids
const auto lower_dimss = generate_identity_sequences<num_dim>();
const auto upper_dimss = lower_dimss;

return transform_tensor_descriptor(desc, transforms, lower_dimss, upper_dimss);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -739,8 +739,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -894,8 +894,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down Expand Up @@ -944,8 +943,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down Expand Up @@ -993,8 +991,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -833,8 +833,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down Expand Up @@ -892,8 +891,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -692,8 +692,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down Expand Up @@ -744,8 +743,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -514,8 +514,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down Expand Up @@ -563,8 +562,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -657,8 +657,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down Expand Up @@ -707,8 +706,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -548,8 +548,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down Expand Up @@ -598,8 +597,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down
18 changes: 18 additions & 0 deletions include/ck/utility/sequence_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#pragma once

#include "ck/utility/functional4.hpp"
#include "ck/utility/tuple.hpp"

namespace ck {
Expand Down Expand Up @@ -34,4 +35,21 @@ __host__ __device__ constexpr auto to_sequence(Tuple<Number<Is>...>)
return Sequence<Is...>{};
}

// Functor wrapper for merge_sequences to enable reuse across call sites
struct merge_sequences_functor
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one thing to consider is whether those new helper functors are the implementation detail and not the part of the header interface

{
template <typename... Seqs>
__host__ __device__ constexpr auto operator()(Seqs... seqs) const
{
return merge_sequences(seqs...);
}
};

// Unpacks tuple of sequences and merges them into a single sequence
template <typename TupleOfSequences>
__host__ __device__ constexpr auto unpack_and_merge_sequences(TupleOfSequences tuple_of_sequences)
{
return unpack(merge_sequences_functor{}, tuple_of_sequences);
}

} // namespace ck
21 changes: 21 additions & 0 deletions include/ck/utility/tuple_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,27 @@ __host__ __device__ constexpr auto generate_tie(F&& f, Number<N>)
typename arithmetic_sequence_gen<0, N, 1>::type{});
}

// Creates Tuple<Sequence<0>, Sequence<1>, ..., Sequence<N-1>>
namespace detail {
template <index_t... Is>
__host__ __device__ constexpr auto make_identity_sequences_impl(Sequence<Is...>)
{
return make_tuple(Sequence<Is>{}...);
}
} // namespace detail

template <index_t N>
__host__ __device__ constexpr auto generate_identity_sequences()
{
return detail::make_identity_sequences_impl(make_index_sequence<N>{});
}

template <index_t N>
__host__ __device__ constexpr auto generate_identity_sequences(Number<N>)
{
return generate_identity_sequences<N>();
}

// tx and ty are tuple of references, return type of will tuple of referennce (not rvalue)
template <typename... X, typename... Y>
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>& tx,
Expand Down
3 changes: 1 addition & 2 deletions include/ck/wrapper/layout.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,7 @@ struct Layout
const auto lower_dims =
generate_tuple([&](auto i) { return GenerateLowerDim<Number<i>>(shape); },
Number<Tuple<ShapeDims...>::Size()>{});
const auto upper_dims = generate_tuple([&](auto i) { return Sequence<i.value>{}; },
Number<Tuple<ShapeDims...>::Size()>{});
const auto upper_dims = generate_identity_sequences<Tuple<ShapeDims...>::Size()>();

return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
}
Expand Down
3 changes: 1 addition & 2 deletions include/ck/wrapper/operations/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,7 @@ make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor)
const auto partition_desc = BlockwiseGemmXdlops::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
layout(c_local_tile_tensor).GetUnrolledDescriptor());

const auto lower_upper_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<8>{});
const auto lower_upper_dims = generate_identity_sequences<8>();

auto sliced_desc = transform_tensor_descriptor(
partition_desc,
Expand Down
3 changes: 1 addition & 2 deletions include/ck/wrapper/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,7 @@ __host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>&
const auto transforms = GenerateSliceTransforms(idx, shape);
using TransformsTupleType = decltype(transforms);

const auto lower_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
const auto lower_dims = generate_identity_sequences<old_shape_dims>();
const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){};
return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
}
Expand Down
6 changes: 2 additions & 4 deletions include/ck/wrapper/utils/layout_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,7 @@ __host__ __device__ constexpr auto get(const Layout<Shape, UnrolledDesc>& layout
},
Number<old_shape_dims>{});

const auto lower_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
const auto lower_dims = generate_identity_sequences<old_shape_dims>();
const auto upper_dims = generate_tuple(
[&](auto i) {
if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
Expand Down Expand Up @@ -492,8 +491,7 @@ __host__ __device__ constexpr auto unmerge(const Layout<Shape, UnrolledDesc>& la
},
Number<dims>{});

constexpr auto lower_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<dims>{});
constexpr auto lower_dims = generate_identity_sequences<dims>();
constexpr auto upper_dims = generate_tuple(
[&](auto i) {
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, NewIdxs>>::value)
Expand Down
3 changes: 1 addition & 2 deletions include/ck/wrapper/utils/tensor_partition.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,7 @@ make_local_partition(TensorType& tensor,
},
Number<remove_reference_t<decltype(tensor_shape)>::Size()>{});
const auto lower_upper_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; },
Number<remove_reference_t<decltype(tensor_shape)>::Size()>{});
generate_identity_sequences<remove_reference_t<decltype(tensor_shape)>::Size()>();
auto sliced_desc =
transform_tensor_descriptor(unrolled_desc, transforms, lower_upper_dims, lower_upper_dims);
// Create layout
Expand Down
Loading