Optimizations for index_select_scalar_cumsum_kernel #137
base: will/upstream
Changes from all commits
dbab5d3
fe9bfbc
782bf91
b9f8625
3ef975b
20d1445
27e5537
2182e0b
fe929e5
489e4b6
6083e78
1a75460
2d160b2
8fe70ea
200de28
af5ccbc
@@ -6,6 +6,7 @@
 * LICENSE file in the root directory of this source tree.
 */

#include <type_traits>
#include "common.cuh"

using Tensor = at::Tensor;
@@ -17,7 +18,8 @@ template <
    typename index_t,
    typename acc_t,
    int NUM_THREADS_PER_BLOCK,
    int MAX_ENTRIES_PER_BLOCK>
    int MAX_ENTRIES_PER_BLOCK,
    int ENTRIES_PER_THREAD>
__global__ void index_select_scalar_cumsum_kernel(
    pta::PackedTensorAccessor32<scalar_t, 1, at::RestrictPtrTraits> output,
    pta::PackedTensorAccessor32<acc_t, 1, at::RestrictPtrTraits> output_cumsum,
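The new ENTRIES_PER_THREAD template parameter lets one block of NUM_THREADS_PER_BLOCK threads cover NUM_THREADS_PER_BLOCK * ENTRIES_PER_THREAD entries, so fewer blocks are launched for the same input. A quick standalone C++ illustration of the block-count arithmetic (values hypothetical, not taken from the PR):

#include <cstdio>
#include <initializer_list>

// Ceil-division helper matching the usual launch-size computation.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
  constexpr int kThreadsPerBlock = 256;
  const int num_entries = 4096;
  for (int ept : {1, 2, 4}) {
    const int entries_per_block = kThreadsPerBlock * ept;
    std::printf("EPT=%d -> %d block(s)\n",
                ept, ceil_div(num_entries, entries_per_block));
  }
  // EPT=1 -> 16 block(s), EPT=2 -> 8 block(s), EPT=4 -> 4 block(s)
}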
@@ -31,6 +33,81 @@ __global__ void index_select_scalar_cumsum_kernel(
    acc_t* block_sums) {
  typedef cub::BlockScan<acc_t, NUM_THREADS_PER_BLOCK> BlockScan;
  __shared__ typename BlockScan::TempStorage bs_temp_storage;
  __shared__ acc_t block_prefix;

  // ROCm path
#ifdef USE_ROCM
  const int output_batch_size = indices.size(0);
  const int num_entries = num_batches * output_batch_size;
  const bool multi_block = gridDim.x > 1;
  const int block_entries = blockIdx.x == gridDim.x - 1
      ? last_block_num_entries
      : MAX_ENTRIES_PER_BLOCK;
  const int block_entry_start = blockIdx.x * MAX_ENTRIES_PER_BLOCK;
  const int remaining_entries = num_entries - block_entry_start;
  const int num_entries_per_block = remaining_entries > 0
      ? (remaining_entries < block_entries ? remaining_entries : block_entries)
      : 0;

  const int base_entry = block_entry_start + threadIdx.x * ENTRIES_PER_THREAD;
  acc_t local_data[ENTRIES_PER_THREAD];

#pragma unroll
  for (int i = 0; i < ENTRIES_PER_THREAD; ++i) {
    const int entry = base_entry + i;
    if (entry < num_entries) {
      const int bid = entry / output_batch_size;
      const int idx_in_batch = entry - bid * output_batch_size;
      const int bid_base = bid * input_batch_size;
      const index_t sel_idx = indices[idx_in_batch];
      local_data[i] = __builtin_nontemporal_load(&input[bid_base + sel_idx]);
      output[entry] = local_data[i];
    } else {
      local_data[i] = 0;
    }
  }

  // Faster path for single block
  if (!multi_block) {
Reviewer: As in the other file, you may consider passing
Author: Got it, I will be sure to try it and test the results.
    if (num_entries_per_block > 0) {
      BlockScan(bs_temp_storage).InclusiveSum(local_data, local_data);
    }
    if (base_entry < num_entries) {
#pragma unroll
      for (int i = 0; i < ENTRIES_PER_THREAD; ++i) {
        const int entry = base_entry + i;
        if (entry < num_entries) {
          output_cumsum[entry] = local_data[i];
        }
      }
    }
    return;
  }
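CUB's BlockScan::InclusiveSum over a per-thread array computes an inclusive prefix sum across the whole block in a blocked arrangement: thread t owns entries [t*ENTRIES_PER_THREAD, (t+1)*ENTRIES_PER_THREAD). A minimal sequential C++ sketch of those semantics (an emulation for clarity, not the CUB implementation):

#include <cstdio>
#include <vector>

// Sequential emulation of BlockScan::InclusiveSum over a blocked
// arrangement: the scan runs across all threads' items in thread order.
std::vector<int> inclusive_sum(std::vector<int> items) {
  int running = 0;
  for (int& v : items) {
    running += v;
    v = running;  // inclusive: slot i holds the sum of items[0..i]
  }
  return items;
}

int main() {
  // 4 threads, ENTRIES_PER_THREAD = 2 -> 8 entries in blocked order.
  for (int v : inclusive_sum({1, 2, 3, 4, 5, 6, 7, 8})) {
    std::printf("%d ", v);  // prints: 1 3 6 10 15 21 28 36
  }
}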
    if (num_entries_per_block > 0) {
      inclusive_sum_scan_kernel<acc_t, ENTRIES_PER_THREAD, NUM_THREADS_PER_BLOCK>(
          local_data,
          bs_temp_storage,
          block_flags,
          block_sums,
          &block_prefix,
          num_entries_per_block,
          blockIdx.x,
          multi_block,
          1);
    }
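For the multi-block case, the block_flags and block_sums arguments suggest a stream-scan style hand-off inside inclusive_sum_scan_kernel: each block publishes its block-local total, and later blocks add the totals of their predecessors to form their starting prefix. A sequential sketch of that idea (hypothetical; the actual helper lives elsewhere in the codebase):

#include <cstdio>
#include <vector>

// Hypothetical sketch of a stream-scan hand-off between blocks: each
// block publishes its total in block_sums, and a block's starting
// prefix is the sum of all earlier blocks' totals. In the kernel,
// block_flags would gate these reads until predecessors have published.
int block_prefix(const std::vector<int>& block_sums, int block_id) {
  int prefix = 0;
  for (int b = 0; b < block_id; ++b) {
    prefix += block_sums[b];
  }
  return prefix;
}

int main() {
  std::vector<int> block_sums{10, 7, 12};          // per-block totals
  std::printf("%d\n", block_prefix(block_sums, 2));  // prints: 17
}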
    if (base_entry < num_entries) {
#pragma unroll
      for (int i = 0; i < ENTRIES_PER_THREAD; ++i) {
        const int entry = base_entry + i;
        if (entry < num_entries) {
          output_cumsum[entry] = local_data[i];
        }
      }
    }
#else
  // CUDA path
  __shared__ acc_t smem[MAX_ENTRIES_PER_BLOCK];
  const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int output_batch_size = indices.size(0);
@@ -65,6 +142,7 @@ __global__ void index_select_scalar_cumsum_kernel(
  if (tid < num_batches * output_batch_size) {
    output_cumsum[tid] = *local_data;
  }
#endif
}
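Taken together, the kernel performs a per-batch index_select on the lengths followed by an inclusive cumsum to produce offsets. A host-side C++ reference of the intended result might look like this (names hypothetical; row-major [num_batches, batch_size] layout assumed from the kernel's indexing):

#include <cstdio>
#include <vector>

// Hypothetical host-side reference for index_select_scalar_cumsum:
// select lengths[b][indices[j]] for every batch b and index j, then
// take an inclusive prefix sum over the selected values.
void index_select_cumsum(
    const std::vector<int>& lengths,  // [num_batches * batch_size], row-major
    const std::vector<int>& indices,  // [output_batch_size]
    int num_batches,
    int batch_size,
    std::vector<int>& out,            // selected lengths
    std::vector<int>& out_cumsum) {   // inclusive cumsum of `out`
  int running = 0;
  for (int b = 0; b < num_batches; ++b) {
    for (int idx : indices) {
      const int v = lengths[b * batch_size + idx];
      out.push_back(v);
      running += v;
      out_cumsum.push_back(running);
    }
  }
}

int main() {
  std::vector<int> lengths{3, 1, 4, 1, 5, 9};  // 2 batches x 3 lengths
  std::vector<int> indices{2, 0};
  std::vector<int> out, cumsum;
  index_select_cumsum(lengths, indices, 2, 3, out, cumsum);
  for (size_t i = 0; i < out.size(); ++i)
    std::printf("%d:%d ", out[i], cumsum[i]);  // 4:4 3:7 9:16 1:17
}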
template <
@@ -183,62 +261,151 @@ class KeyedJaggedIndexSelectDim1GPUOp
  const int num_batches = lengths.numel() / batch_size;
  const int num_output_lengths = num_batches * indices.numel();
  const int MAX_CUMSUM_ENTRIES_PER_BLOCK = 256;
#ifdef USE_ROCM
  const int num_entries_per_thread[] = {4, 2, 1};
  int entries_per_thread = 1;
  for (int i : num_entries_per_thread) {
    if (indices.numel() % i == 0) {
      entries_per_thread = i;
      break;
    }
  }
#else
  constexpr int ENTRIES_PER_THREAD = 1;
  auto grid_size = cuda_calc_xblock_count(
      num_output_lengths, MAX_CUMSUM_ENTRIES_PER_BLOCK);
#endif
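The selection loop picks the largest candidate in {4, 2, 1} that evenly divides indices.numel(), falling back to one entry per thread otherwise. A standalone C++ check of that selection (hypothetical driver around the same logic):

#include <cstdio>
#include <initializer_list>

// Pick the largest entries-per-thread candidate that divides n evenly,
// mirroring the selection loop in the ROCm dispatch path.
int pick_entries_per_thread(long n) {
  for (int i : {4, 2, 1}) {
    if (n % i == 0) return i;
  }
  return 1;
}

int main() {
  for (long n : {1024L, 6L, 7L}) {
    std::printf("numel=%ld -> %d\n", n, pick_entries_per_thread(n));
  }
  // numel=1024 -> 4, numel=6 -> 2, numel=7 -> 1
}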
  Tensor output_offsets =
      at::empty({num_batches * indices.numel()}, offsets.options());
  Tensor output_lengths =
      at::empty({num_batches * indices.numel()}, lengths.options());

  Tensor block_flags, block_sums;
  if (grid_size > 1) {
    block_flags = at::zeros({grid_size}, lengths.options().dtype(at::kInt));
    block_sums = at::empty({grid_size}, output_offsets.options());
  }
  // Do index select and cumsum
  AT_DISPATCH_INDEX_TYPES(
      lengths.scalar_type(), "index_select_scalar_cumsum_wrapper_1", [&] {
        using length_t = index_t;
        AT_DISPATCH_INDEX_TYPES(
            offsets.scalar_type(),
            "index_select_scalar_cumsum_wrapper_2",
            [&] {
              using offset_t = index_t;
              AT_DISPATCH_INDEX_TYPES(
                  indices.scalar_type(),
                  "index_select_scalar_cumsum_wrapper_3",
                  [&] {
                    FBGEMM_LAUNCH_KERNEL(
                        (index_select_scalar_cumsum_kernel<
                            length_t,
                            index_t,
                            offset_t,
                            MAX_CUMSUM_ENTRIES_PER_BLOCK,
                            MAX_CUMSUM_ENTRIES_PER_BLOCK>),
                        grid_size,
                        MAX_CUMSUM_ENTRIES_PER_BLOCK,
                        0,
                        at::cuda::getCurrentCUDAStream(),
                        PTA_B(output_lengths, length_t, 1, 32),
                        PTA_B(output_offsets, offset_t, 1, 32),
                        PTA_B(lengths, length_t, 1, 32),
                        PTA_B(indices, index_t, 1, 32),
                        num_batches,
                        batch_size,
                        num_output_lengths -
                            MAX_CUMSUM_ENTRIES_PER_BLOCK * (grid_size - 1),
                        grid_size > 1 ? block_flags.data_ptr<int>() : nullptr,
                        grid_size > 1 ? block_sums.data_ptr<offset_t>()
                                      : nullptr);
                  });
            });
      });
#ifdef USE_ROCM
  // ROCm path
  auto dispatch_cumsum = [&](auto vec_tag) {
    constexpr int ENTRIES_PER_THREAD = decltype(vec_tag)::value;
    constexpr int ENTRIES_PER_BLOCK =
        MAX_CUMSUM_ENTRIES_PER_BLOCK * ENTRIES_PER_THREAD;
    const auto rocm_grid_size =
        (num_output_lengths + ENTRIES_PER_BLOCK - 1) / ENTRIES_PER_BLOCK;

    if (rocm_grid_size == 0)
      return;

    if (rocm_grid_size > 1) {
      block_flags = at::zeros({rocm_grid_size}, lengths.options().dtype(at::kInt));
      block_sums = at::empty({rocm_grid_size}, output_offsets.options());
    }

    AT_DISPATCH_INDEX_TYPES(
        lengths.scalar_type(), "index_select_scalar_cumsum_wrapper_1", [&] {
          using length_t = index_t;
          AT_DISPATCH_INDEX_TYPES(
              offsets.scalar_type(),
              "index_select_scalar_cumsum_wrapper_2",
              [&] {
                using offset_t = index_t;
                AT_DISPATCH_INDEX_TYPES(
                    indices.scalar_type(),
                    "index_select_scalar_cumsum_wrapper_3",
                    [&] {
                      FBGEMM_LAUNCH_KERNEL(
                          (index_select_scalar_cumsum_kernel<
                              length_t,
                              index_t,
                              offset_t,
                              MAX_CUMSUM_ENTRIES_PER_BLOCK,
                              ENTRIES_PER_BLOCK,
                              ENTRIES_PER_THREAD>),
                          rocm_grid_size,
                          MAX_CUMSUM_ENTRIES_PER_BLOCK,
                          0,
                          at::cuda::getCurrentCUDAStream(),
                          PTA_B(output_lengths, length_t, 1, 32),
                          PTA_B(output_offsets, offset_t, 1, 32),
                          PTA_B(lengths, length_t, 1, 32),
                          PTA_B(indices, index_t, 1, 32),
                          num_batches,
                          batch_size,
                          num_output_lengths -
                              ENTRIES_PER_BLOCK * (rocm_grid_size - 1),
                          rocm_grid_size > 1
                              ? block_flags.data_ptr<int>()
                              : nullptr,
                          rocm_grid_size > 1
                              ? block_sums.data_ptr<offset_t>()
                              : nullptr);
                    });
              });
        });
  };
  switch (entries_per_thread) {
    case 4:
      dispatch_cumsum(std::integral_constant<int, 4>{});
      break;
    case 2:
      dispatch_cumsum(std::integral_constant<int, 2>{});
      break;
    default:
      dispatch_cumsum(std::integral_constant<int, 1>{});
      break;
  }
#else
  // CUDA path
  if (grid_size > 1) {
    block_flags = at::zeros({grid_size}, lengths.options().dtype(at::kInt));
    block_sums = at::empty({grid_size}, output_offsets.options());
  }
  AT_DISPATCH_INDEX_TYPES(
      lengths.scalar_type(), "index_select_scalar_cumsum_wrapper_1", [&] {
        using length_t = index_t;
        AT_DISPATCH_INDEX_TYPES(
            offsets.scalar_type(),
            "index_select_scalar_cumsum_wrapper_2",
            [&] {
              using offset_t = index_t;
              AT_DISPATCH_INDEX_TYPES(
                  indices.scalar_type(),
                  "index_select_scalar_cumsum_wrapper_3",
                  [&] {
                    FBGEMM_LAUNCH_KERNEL(
                        (index_select_scalar_cumsum_kernel<
                            length_t,
                            index_t,
                            offset_t,
                            MAX_CUMSUM_ENTRIES_PER_BLOCK,
                            MAX_CUMSUM_ENTRIES_PER_BLOCK,
                            ENTRIES_PER_THREAD>),
                        grid_size,
                        MAX_CUMSUM_ENTRIES_PER_BLOCK,
                        0,
                        at::cuda::getCurrentCUDAStream(),
                        PTA_B(output_lengths, length_t, 1, 32),
                        PTA_B(output_offsets, offset_t, 1, 32),
                        PTA_B(lengths, length_t, 1, 32),
                        PTA_B(indices, index_t, 1, 32),
                        num_batches,
                        batch_size,
                        num_output_lengths -
                            MAX_CUMSUM_ENTRIES_PER_BLOCK * (grid_size - 1),
                        grid_size > 1
                            ? block_flags.data_ptr<int>()
                            : nullptr,
                        grid_size > 1
                            ? block_sums.data_ptr<offset_t>()
                            : nullptr);
                  });
            });
      });
#endif
  const int64_t num_outputs = (selected_lengths_sum.has_value())
      ? selected_lengths_sum.value().guard_int(__FILE__, __LINE__)
      : output_offsets[output_offsets.numel() - 1].item<int64_t>();
  Tensor output = at::empty({num_outputs}, values.options());
  Tensor output_weights;
  if (weights.has_value()) {