diff --git a/include/cell/copy/global_to_register.hpp b/include/cell/copy/global_to_register.hpp index 9a0fe629..482e9dae 100644 --- a/include/cell/copy/global_to_register.hpp +++ b/include/cell/copy/global_to_register.hpp @@ -15,13 +15,7 @@ using namespace traits; * @tparam type Global Layout type. */ template -struct GlobalToRegMatLoader { - using Global = Global_; - using BaseTile = BaseTile_; - using DType = Global::DType; - - DEVICE void operator()(const DType* src, BaseTile& dst); -}; +struct GlobalToRegMatLoader; template struct GlobalToRegMatLoader { diff --git a/include/cell/copy/global_to_shared.hpp b/include/cell/copy/global_to_shared.hpp index 2d6b3c43..bf9d640b 100644 --- a/include/cell/copy/global_to_shared.hpp +++ b/include/cell/copy/global_to_shared.hpp @@ -302,10 +302,6 @@ struct SharedToGlobalStorer : public Base { using WarpLayout = WarpLayout_; using BaseShape = traits::BaseTileShape; - static_assert( - (Shared::kSwizzled && sizeof(DType) == 4 || Shared::kSwizzled == false), - "Not implemented for swizzled layout with 2-byte data types."); - static_assert(Shared::kRows % BaseShape::kRows == 0, "Shared::kRows must be divisible by BaseShape::kRows."); static_assert(Shared::kCols % BaseShape::kCols == 0, diff --git a/include/cell/copy/global_to_shared_2.hpp b/include/cell/copy/global_to_shared_2.hpp new file mode 100644 index 00000000..f4ba25ca --- /dev/null +++ b/include/cell/copy/global_to_shared_2.hpp @@ -0,0 +1,317 @@ +#pragma once + +#include "types/mod.hpp" + +#include + +namespace tiledcuda::cell::copy { +using namespace traits; +namespace tl = tile_layout; +using namespace cute; + +namespace detail { +template +struct GlobalToSharedLoaderImpl2; + +template +struct GlobalToSharedLoaderImpl2 { + using Global = Global_; + using Shared = Shared_; + using WarpLayout = WarpLayout_; + using DType = typename Global::DType; + + static_assert(Global::kRows == Shared::kRows && + Global::kCols == Shared::kCols, + "Global and shared memory should have the same shape."); + static_assert(Global::kType == Shared::kType, + "The layout of Global memory and Shared memory tile should " + "be the same."); + static_assert(Global::kType == tl::Layout::kRowMajor, + "The layout of Global memory and Shared memory tile should " + "be row-major."); + static_assert(std::is_same_v, + "The data type of Shared and Global must be the same."); + + using WarpThreadLayout = tl::ColMajor<16, 2>; + static constexpr int kNumPerAccess = TraitsBase::kNumPerAccess; + + static constexpr int kThreadsRows = + tl::num_rows * tl::num_rows; + static constexpr int kThreadsCols = + tl::num_cols * tl::num_cols; + + static constexpr int kRows = Global::kRows; + static constexpr int kCols = Global::kCols; + + using GlobalLayout = + cute::Layout, Int>, Stride, _1>>; + + using LayoutAtom = cute::Layout, Stride<_16, _1>>; + using SharedLayoutNonSwizzled = decltype(tile_to_shape( + LayoutAtom{}, Shape, Int>{}, cute::Step<_2, _1>{})); + + // this swizzle function works only for 4-byte data types + using LayoutAtomSwizzled = + decltype(composition(Swizzle<2, 3, 3>{}, LayoutAtom{})); + using SharedLayoutSwizzled = decltype(tile_to_shape( + LayoutAtomSwizzled{}, Shape, Int>{}, + cute::Step<_2, _1>{})); + + using SharedLayout = + std::conditional_t; + + using ThreadLayout = + cute::Layout, Int>, + Stride, _1>>; + using ValueLayout = cute::Layout>>; + +#ifdef CP_ASYNC_SM80_ENABLED + using CopyInst = + Copy_Atom, DType>; +#else + using CopyInst = Copy_Atom; +#endif + + using TiledCopy = + decltype(make_tiled_copy(CopyInst{}, ThreadLayout{}, ValueLayout{})); + + DEVICE void operator()(const DType* src_data, DType* dst_data) { + TiledCopy tiled_copy; + + int tid = threadIdx.x; + + auto gtile = make_tensor(make_gmem_ptr(src_data), GlobalLayout{}); + auto stile = make_tensor(make_smem_ptr(dst_data), SharedLayout{}); + + auto loader = tiled_copy.get_thread_slice(tid); + + auto src = loader.partition_S(gtile); + auto dst = loader.partition_D(stile); + +#pragma unroll + for (int i = 0; i < int(size<1>(src)); ++i) +#pragma unroll + for (int j = 0; j < int(size<2>(src)); ++j) + cute::copy(tiled_copy, src(cute::_, i, j), dst(cute::_, i, j)); + } +}; + +template +struct GlobalToSharedLoaderImpl2 { + using Global = Global_; + using Shared = Shared_; + using WarpLayout = WarpLayout_; + using DType = typename Global::DType; + + static_assert(Global::kRows == Shared::kRows && + Global::kCols == Shared::kCols, + "Global and shared memory should have the same shape."); + static_assert(Global::kType == Shared::kType, + "The layout of Global memory and Shared memory tile should " + "be the same."); + static_assert(Global::kType == tl::Layout::kColMajor, + "The layout of Global memory and Shared memory tile should " + "be column-major."); + static_assert(std::is_same_v, + "The data type of Shared and Global must be the same."); + + using WarpThreadLayout = tl::RowMajor<2, 16>; + static constexpr int kNumPerAccess = TraitsBase::kNumPerAccess; + + static constexpr int kThreadsRows = + tl::num_rows * tl::num_rows; + static constexpr int kThreadsCols = + tl::num_cols * tl::num_cols; + + static constexpr int kRows = Global::kRows; + static constexpr int kCols = Global::kCols; + + using GlobalLayout = + cute::Layout, Int>, Stride<_1, Int>>; + + using LayoutAtom = cute::Layout, Stride<_1, _16>>; + // this swizzle function works only for 4-byte data types + using LayoutAtomSwizzled = + decltype(composition(Swizzle<2, 3, 3>{}, LayoutAtom{})); + + using SharedLayoutNonSwizzled = + decltype(tile_to_shape(LayoutAtom{}, Shape, Int>{})); + + using SharedLayoutSwizzled = decltype(tile_to_shape( + LayoutAtomSwizzled{}, Shape, Int>{})); + + using SharedLayout = + std::conditional_t; + + using ThreadLayout = + cute::Layout, Int>, + Stride, _1>>; + using ValueLayout = cute::Layout, _1>, + Stride<_1, Int>>; + +#ifdef CP_ASYNC_SM80_ENABLED + using CopyInst = + Copy_Atom, DType>; +#else + using CopyInst = Copy_Atom; +#endif + + using TiledCopy = + decltype(make_tiled_copy(CopyInst{}, ThreadLayout{}, ValueLayout{})); + + DEVICE void operator()(const DType* src_data, DType* dst_data) { + TiledCopy tiled_copy; + int tid = threadIdx.x; + + auto gtile = make_tensor(make_gmem_ptr(src_data), GlobalLayout{}); + auto stile = make_tensor(make_smem_ptr(dst_data), SharedLayout{}); + + auto loader = tiled_copy.get_thread_slice(tid); + + auto src = loader.partition_S(gtile); + auto dst = loader.partition_D(stile); + +#pragma unroll + for (int i = 0; i < int(size<1>(src)); ++i) +#pragma unroll + for (int j = 0; j < int(size<2>(src)); ++j) + cute::copy(tiled_copy, src(cute::_, i, j), dst(cute::_, i, j)); + } +}; + +template +struct SharedToGlobalStorerImpl2; + +template +struct SharedToGlobalStorerImpl2 { + using Global = Global_; + using Shared = Shared_; + using DType = typename Global::DType; + using WarpLayout = WarpLayout_; + + static_assert(Global::kRows == Shared::kRows && + Global::kCols == Shared::kCols, + "Global and shared memory should have the same shape."); + static_assert(Global::kType == Shared::kType, + "The layout of Global memory and Shared memory tile should " + "be the same."); + static_assert(Global::kType == tl::Layout::kRowMajor, + "The layout of Global memory and Shared memory tile should " + "be row-major."); + static_assert(std::is_same_v, + "The data type of Shared and Global must be the same."); + + static constexpr int kRows = Global::kRows; + static constexpr int kCols = Global::kCols; + + static constexpr int kNumPerAccess = TraitsBase::kNumPerAccess; + using WarpThreadLayout = tl::ColMajor<16, 2>; + + // thread layout for the entire thread block + static constexpr int kThreadsRows = + tl::num_rows * tl::num_rows; + static constexpr int kThreadsCols = + tl::num_cols * tl::num_cols; + + using BaseTileLayout = cute::Layout, Stride<_16, _1>>; + using SharedLayoutNonSwizzled = decltype(tile_to_shape( + BaseTileLayout{}, Shape, Int>{}, + cute::Step<_2, _1>{})); + + // this swizzle function works only for 4-byte data types + using LayoutAtom = + decltype(composition(Swizzle<2, 3, 3>{}, BaseTileLayout{})); + using SharedLayoutSwizzled = decltype(tile_to_shape( + LayoutAtom{}, Shape, Int>{}, cute::Step<_2, _1>{})); + + // source layout + using SharedLayout = + std::conditional_t; + // target layout + using GlobalLayout = + cute::Layout, Int>, Stride, _1>>; + + using ThreadLayout = + cute::Layout, Int>, + Stride, _1>>; + using ValueLayout = cute::Layout>>; + + // transfer data from global memory to shared memory has cp.async, + // while transfer data from shared memory to global memory does not have. + // for the latter case, the copy instruction should be the default one. + using TiledCopy = decltype(make_tiled_copy(Copy_Atom{}, + ThreadLayout{}, ValueLayout{})); + + DEVICE void operator()(const DType* src_data, DType* dst_data) { + TiledCopy tiled_copy; + int tid = threadIdx.x; + + auto stile = make_tensor(make_smem_ptr(src_data), SharedLayout{}); + auto gtile = make_tensor(make_gmem_ptr(dst_data), GlobalLayout{}); + + auto loader = tiled_copy.get_thread_slice(tid); + + auto src = loader.partition_S(stile); + auto dst = loader.partition_D(gtile); + +#pragma unroll + for (int i = 0; i < int(size<1>(src)); ++i) +#pragma unroll + for (int j = 0; j < int(size<2>(src)); ++j) { + cute::copy(tiled_copy, src(cute::_, i, j), dst(cute::_, i, j)); + } + } +}; +} // namespace detail + +template +struct GlobalToSharedLoader2 { + using Shared = Shared_; + using DType = typename Shared::DType; + using WarpLayout = WarpLayout_; + + static_assert(sizeof(DType) != 4, + "Not implemented for data types other than 2 bytes."); + + template + DEVICE void operator()(const Global& src, Shared& dst) { + const DType* src_ptr = src.data(); + DType* dst_ptr = dst.mutable_data(); + + using Loader = + detail::GlobalToSharedLoaderImpl2; + Loader loader; + loader(src_ptr, dst_ptr); + } +}; + +template +struct SharedToGlobalStorer2 { + using Shared = Shared_; + using DType = typename Shared::DType; + using WarpLayout = WarpLayout_; + + static_assert(sizeof(DType) != 4, + "Not implemented for data types other than 2 bytes."); + + template + DEVICE void operator()(const Shared& src_, Global& dst_) { + const DType* src = src_.data(); + DType* dst = dst_.mutable_data(); + using Storer = + detail::SharedToGlobalStorerImpl2; + Storer storer; + storer(src, dst); + } +}; +} // namespace tiledcuda::cell::copy diff --git a/include/cell/copy/shared_to_register.hpp b/include/cell/copy/shared_to_register.hpp index 85ce21fb..550c910f 100644 --- a/include/cell/copy/shared_to_register.hpp +++ b/include/cell/copy/shared_to_register.hpp @@ -271,10 +271,6 @@ struct RegToSharedStorer { static_assert(Shared::kCols % BaseShape::kCols == 0, "The number of shared memory columns must be divisible " "by the base tile column."); - static_assert( - (Shared::kSwizzled && sizeof(DType) == 4 || - Shared::kSwizzled == false), - "Not implemented for swizzled layout with 2-byte data types."); // how many times the 16x16 `BaseTile` is executed along the row and // column direction. diff --git a/include/cell/sync.hpp b/include/cell/sync.hpp index c767a047..19cd0d09 100644 --- a/include/cell/sync.hpp +++ b/include/cell/sync.hpp @@ -21,5 +21,4 @@ DEVICE void __copy_async() { commit_copy_group(); wait_group<0>(); } - } // namespace tiledcuda::cell diff --git a/include/types/layout.hpp b/include/types/layout.hpp index b62078ae..22c5611d 100644 --- a/include/types/layout.hpp +++ b/include/types/layout.hpp @@ -95,6 +95,33 @@ struct SharedLayout { template struct SwizzledRowMajor; +/// @brief Swizzled row-major layout for storing half-typed 16x16 BaseTile. +template <> +struct SwizzledRowMajor<32> { + using BaseShape = traits::BaseTileShape<__half>; + + static constexpr int kB = 2; + static constexpr int kM = 3; + static constexpr int kS = 3; + + static_assert( + BaseShape::kNumel == ((1 << kB) * (1 << kM) * (1 << kS)), + "Swizzling is performed based on the BaseTile, and the number of " + "elements in a BaseTile should be equal to 2^B x 2^S x 2^M."); + + using SwizzledBaseTile = decltype(composition( + cute::Swizzle{}, + cute::Layout, Int>, + Stride, _1>>{})); + + DEVICE SwizzledRowMajor() : swizzled_(SwizzledBaseTile{}){}; + + DEVICE int operator()(int i, int j) const { return swizzled_(i, j); } + + private: + SwizzledBaseTile swizzled_; +}; + template <> struct SwizzledRowMajor<64> { using BaseShape = traits::BaseTileShape; @@ -227,6 +254,15 @@ struct SharedLayoutWrapperImpl { Stride<_1, Int>>; }; +/// @brief Shared memory layout for swizzled row-major layout with 16-bit data +/// type. +template <> +struct SharedLayoutWrapperImpl { + using BaseShape = traits::BaseTileShape<__half>; + + using Layout = SwizzledRowMajor<32>; +}; + /// @brief Shared memory layout for swizzled row-major layout with 16-bit data /// type. template <> @@ -322,6 +358,5 @@ HOST_DEVICE auto make_col_major_layout(const int row, const int col, return cute::make_layout(cute::make_shape(row, col), cute::make_stride(cute::_1{}, stride)); } - } // namespace tile_layout } // namespace tiledcuda::cell diff --git a/tests/cpp/cell/test_g2s_copy_2.cu b/tests/cpp/cell/test_g2s_copy_2.cu new file mode 100644 index 00000000..ddd423a0 --- /dev/null +++ b/tests/cpp/cell/test_g2s_copy_2.cu @@ -0,0 +1,236 @@ +#include "cell/copy/global_to_shared_2.hpp" +#include "cell/copy/mod.hpp" +#include "cell/sync.hpp" +#include "common/test_utils.hpp" +#include "types/mod.hpp" + +#include +#include + +namespace tiledcuda::testing { +using namespace cell; +using namespace copy; + +namespace { +template +__device__ bool is_equal(const Element* data1, const Element* data2, + int numel) { + Element epsilon = static_cast(1e-3); + + for (int i = 0; i < numel; ++i) { + if (data1[i] - data2[i] > epsilon) { + printf("%d-th emelemnt is not equal: %.2f, %.2f\n", i, + __half2float(data1[i]), __half2float(data2[i])); + return false; + } + } + return true; +} + +template +__global__ void copy_g2s(const Element* src, Loader1& loader1, + Loader2& loader2) { + extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; + auto* buf1 = reinterpret_cast(buf_); + auto* buf2 = buf1 + SharedTile::kNumel; + + GlobalTile g_tile(src); + SharedTile s_tile1(buf1); + SharedTile s_tile2(buf2); + + loader1(g_tile, s_tile1); + __copy_async(); + __syncthreads(); + + loader2(g_tile, s_tile2); + __copy_async(); + __syncthreads(); + + if (thread(0)) { +#if defined(DEBUG) + printf("self-implemented\n"); + s_tile1.dump_value(); + + printf("cute\n"); + s_tile2.dump_value(); +#endif + + assert(is_equal(buf1, buf2, SharedTile::kNumel)); + } +} + +template +__global__ void store_s2g(const Element* src, Element* dst) { + extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; + auto* buf = reinterpret_cast(buf_); + + Loader loader; + StorerR2S storer1; + StorerS2G storer2; + + Global g_src_tile(src); + Reg r_tile; + + Shared s_tile(buf); + Global g_dst_tile(dst); + + loader(g_src_tile, r_tile); + __syncthreads(); + + storer1(r_tile, s_tile); + __syncthreads(); + + storer2(s_tile, g_dst_tile); + __syncthreads(); + +#if defined(DEBUG) + if (thread0()) { + printf("\nglobal tile source:\n"); + g_src_tile.dump_value(); + + printf("\nshared tile:\n"); + s_tile.dump_value(); + + printf("\nglobal tile target:\n"); + g_dst_tile.dump_value(); + } +#endif +} + +template +void test_row_major_load() { + using GlobalTile = GlobalTile>; + using SharedTile = + SharedTile, kSwizzled>; + + using Loader1 = copy::GlobalToSharedLoader; + using Loader2 = copy::GlobalToSharedLoader2; + + Loader1 loader1; + Loader2 loader2; + + // threads are arranged as 8 x 4 to perform 2D copy + static const int kThreads = tl::get_numel * 32; + + int numel = kRows * kCols; + thrust::host_vector h_A(numel); + + for (int i = 0; i < h_A.size(); ++i) { + h_A[i] = static_cast(i); + } + thrust::device_vector d_A = h_A; + + dim3 dim_grid(1, 1, 1); + dim3 dim_block(kThreads); + copy_g2s + <<>>( + thrust::raw_pointer_cast(d_A.data()), loader1, loader2); + cudaDeviceSynchronize(); +} + +template +void test_col_major_load() { + using GlobalTile = GlobalTile>; + using SharedTile = + SharedTile, kSwizzled>; + + using Loader1 = copy::GlobalToSharedLoader; + using Loader2 = copy::GlobalToSharedLoader2; + + Loader1 loader1; + Loader2 loader2; + + // threads are arranged as 8 x 4 to perform 2D copy + static const int kThreads = tl::get_numel * 32; + + int numel = kRows * kCols; + thrust::host_vector h_A(numel); + + for (int i = 0; i < h_A.size(); ++i) { + h_A[i] = static_cast(i); + } + thrust::device_vector d_A = h_A; + + dim3 dim_grid(1, 1, 1); + dim3 dim_block(kThreads); + copy_g2s + <<>>( + thrust::raw_pointer_cast(d_A.data()), loader1, loader2); + cudaDeviceSynchronize(); +} + +template +void test_row_major_store() { + using BaseShape = traits::BaseTileShape; + + const int kThreads = tl::get_numel * 32; + + // define tiles + using Global = GlobalTile>; + static constexpr int kRowRepeats = + kRows / tl::num_rows / BaseShape::kTileSize; + static constexpr int kColRepeats = + kCols / tl::num_cols / BaseShape::kTileSize; + + using Reg = RegTile, + tl::RowMajor>; + using Shared = SharedTile, kSwizzled>; + + // define loader and storer + using Loader = GlobalToRegLoader; + using StorerR2S = RegToSharedStorer; + using StorerS2G = SharedToGlobalStorer2; + + int numel = kRows * kCols; + thrust::host_vector h_src(numel); + for (int i = 0; i < h_src.size(); ++i) h_src[i] = static_cast(i); + + thrust::device_vector d_src = h_src; + + thrust::device_vector d_dst(numel); + thrust::fill(d_dst.begin(), d_dst.end(), static_cast(0.)); + + auto test_func = + &store_s2g; + + dim3 dim_grid(1, 1, 1); + dim3 dim_block(kThreads, 1, 1); + int shm_size = Shared::kNumel * sizeof(Element); + + test_func<<>>( + thrust::raw_pointer_cast(d_src.data()), + thrust::raw_pointer_cast(d_dst.data())); + cudaDeviceSynchronize(); + + thrust::host_vector h_dst = d_dst; + + assert_equal(thrust::raw_pointer_cast(h_src.data()), + thrust::raw_pointer_cast(h_dst.data()), numel, 1e-4); + + LOG(INFO) << "[" << kRows << ", " << kCols << "] test passed!" << std::endl; +}; +} // namespace + +TEST(G2S_Copy, load_row_major) { + test_row_major_load<__half, 128, 64, tl::RowMajor<2, 2>, false>(); + test_row_major_load<__half, 128, 64, tl::RowMajor<2, 2>, true>(); +} + +TEST(G2S_Copy, load_col_major) { + test_col_major_load<__half, 128, 64, tl::RowMajor<2, 2>, false>(); + test_col_major_load<__half, 128, 64, tl::RowMajor<2, 2>, true>(); +} + +TEST(S2G_Copy, store_row_major) { + test_row_major_store<__half, 128, 64, tl::RowMajor<2, 2>, false>(); + test_row_major_store<__half, 128, 64, tl::RowMajor<2, 2>, true>(); + + test_row_major_store, false>(); + test_row_major_store, true>(); +} +} // namespace tiledcuda::testing diff --git a/tests/cpp/cell/test_swizzled_copy.cu b/tests/cpp/cell/test_swizzled_copy.cu index d121a47c..59a01ccb 100644 --- a/tests/cpp/cell/test_swizzled_copy.cu +++ b/tests/cpp/cell/test_swizzled_copy.cu @@ -467,6 +467,15 @@ TEST(TestNonSwizzledStore, test_row_major) { TEST(TestSwizzledStored, test_row_major) { static constexpr int kSwizzled = true; + test_row_major_store<__half, tl::RowMajor<1, 1>, 16, 32, kSwizzled>(); + test_row_major_store<__half, tl::RowMajor<1, 1>, 16, 32, kSwizzled>(); + test_row_major_store<__half, tl::RowMajor<1, 1>, 16, 48, kSwizzled>(); + test_row_major_store<__half, tl::RowMajor<2, 1>, 32, 48, kSwizzled>(); + test_row_major_store<__half, tl::RowMajor<1, 1>, 16, 32, kSwizzled>(); + test_row_major_store<__half, tl::RowMajor<2, 1>, 64, 32, kSwizzled>(); + test_row_major_store<__half, tl::RowMajor<1, 2>, 128, 64, kSwizzled>(); + test_row_major_store<__half, tl::RowMajor<2, 2>, 64, 64, kSwizzled>(); + test_row_major_store, 16, 32, kSwizzled>(); test_row_major_store, 16, 48, kSwizzled>(); test_row_major_store, 32, 48, kSwizzled>(); diff --git a/tests/cpp/common/test_utils.cc b/tests/cpp/common/test_utils.cc index 0883ab81..ca59b307 100644 --- a/tests/cpp/common/test_utils.cc +++ b/tests/cpp/common/test_utils.cc @@ -18,6 +18,23 @@ void assert_equal(const __half* v1, const __half* v2, int64_t numel, } } +template <> +void assert_equal(const cutlass::half_t* v1_, const cutlass::half_t* v2_, + int64_t numel, float epsilon) { + const __half* v1 = reinterpret_cast(v1_); + const __half* v2 = reinterpret_cast(v2_); + + float a = 0.f; + float b = 0.f; + for (int i = 0; i < numel; ++i) { + a = __half2float(v1[i]); + b = __half2float(v2[i]); + + EXPECT_NEAR(a, b, epsilon) << "v1[" << i << "] vs. v2[" << i + << "] = " << a << " vs. " << b << std::endl; + } +} + template <> void assert_equal(const float* v1, const float* v2, int64_t numel, float epsilon) {