45 changes: 43 additions & 2 deletions fbgemm_gpu/include/fbgemm_gpu/utils/inclusive_sum_scan.cuh
@@ -69,11 +69,52 @@ __inline__ __device__ void inclusive_sum_scan_kernel(
const int block_id,
const bool is_multi_block,
const int signal) {
// ROCm path
#ifdef USE_ROCM
// Perform scan within a block
cub::BlockScan<scalar_t, NUM_THREADS_PER_BLOCK>(temp_storage)
.InclusiveSum(arr, arr);

// Perform scan across blocks
if (is_multi_block) {
const bool is_last_thread =
threadIdx.x == (num_entries_per_block - 1) / ITEMS_PER_THREAD;
// The thread that holds the last entry in the block does synchronization
if (is_last_thread) {
scalar_t block_prev_local = 0;
if (block_id != 0) {
// Spin wait for the previous block to write the sum value
while (atomicAdd(&block_flags[block_id - 1], 0) < signal)
;

// Get sum from the previous block
*block_prev = block_prev_local = block_sums[block_id - 1];
}

// Write sum to global memory for the next block to consume
const int scope = (num_entries_per_block - 1) % ITEMS_PER_THREAD;
block_sums[block_id] = block_prev_local + arr[scope];
__threadfence();
// Set a flag to notify the next block
atomicExch(&block_flags[block_id], signal);
}

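// Make the block prefix published in *block_prev by the last thread visible
// to every thread in the block before it is applied below.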
__syncthreads();

if (block_id != 0) {
scalar_t block_prev_local = *block_prev;
for (int i = 0; i < ITEMS_PER_THREAD; i++) {
arr[i] += block_prev_local;
}
}
}
#else
// CUDA path
// Perform scan within a block
cub::BlockScan<scalar_t, NUM_THREADS_PER_BLOCK>(temp_storage)
.InclusiveSum(arr, arr);

// Perform scan across blocks
if (is_multi_block) {
// The thread that holds the last entry in the block does synchronization
if (threadIdx.x == (num_entries_per_block - 1) / ITEMS_PER_THREAD) {
@@ -104,6 +145,6 @@ __inline__ __device__ void inclusive_sum_scan_kernel(
}
}
}
#endif
}

} // namespace fbgemm_gpu
259 changes: 213 additions & 46 deletions fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
@@ -6,6 +6,7 @@
* LICENSE file in the root directory of this source tree.
*/

#include <type_traits>
#include "common.cuh"

using Tensor = at::Tensor;
@@ -17,7 +18,8 @@ template <
typename index_t,
typename acc_t,
int NUM_THREADS_PER_BLOCK,
int MAX_ENTRIES_PER_BLOCK>
int MAX_ENTRIES_PER_BLOCK,
int ENTRIES_PER_THREAD>
__global__ void index_select_scalar_cumsum_kernel(
pta::PackedTensorAccessor32<scalar_t, 1, at::RestrictPtrTraits> output,
pta::PackedTensorAccessor32<acc_t, 1, at::RestrictPtrTraits> output_cumsum,
@@ -31,6 +33,81 @@ __global__ void index_select_scalar_cumsum_kernel(
acc_t* block_sums) {
typedef cub::BlockScan<acc_t, NUM_THREADS_PER_BLOCK> BlockScan;
__shared__ typename BlockScan::TempStorage bs_temp_storage;
__shared__ acc_t block_prefix;

// ROCm path
#ifdef USE_ROCM
const int output_batch_size = indices.size(0);
const int num_entries = num_batches * output_batch_size;
const bool multi_block = gridDim.x > 1;
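// The last block may hold fewer than MAX_ENTRIES_PER_BLOCK entries; clamp the
// per-block count to the entries actually remaining past this block's start.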
const int block_entries = blockIdx.x == gridDim.x - 1
? last_block_num_entries
: MAX_ENTRIES_PER_BLOCK;
const int block_entry_start = blockIdx.x * MAX_ENTRIES_PER_BLOCK;
const int remaining_entries = num_entries - block_entry_start;
const int num_entries_per_block = remaining_entries > 0
? (remaining_entries < block_entries ? remaining_entries : block_entries)
: 0;

const int base_entry = block_entry_start + threadIdx.x * ENTRIES_PER_THREAD;
acc_t local_data[ENTRIES_PER_THREAD];

#pragma unroll
for (int i = 0; i < ENTRIES_PER_THREAD; ++i) {
const int entry = base_entry + i;
if (entry < num_entries) {
const int bid = entry / output_batch_size;
const int idx_in_batch = entry - bid * output_batch_size;
const int bid_base = bid * input_batch_size;
const index_t sel_idx = indices[idx_in_batch];
local_data[i] = __builtin_nontemporal_load(&input[bid_base + sel_idx]);
output[entry] = local_data[i];
} else {
local_data[i] = 0;
}
}

// Faster path for single block
if (!multi_block) {
Reviewer comment: As in the other file, you may consider passing multi_block as a compile-time parameter, or splitting the function and dispatching the appropriate one at runtime.

Author reply: Got it, I will try it and test the results.
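A minimal sketch of the reviewer's suggestion, using a stand-in kernel rather than this PR's kernel: multi_block becomes a bool template parameter and the host picks the specialization from the launch grid size. The names here (cumsum_kernel_sketch, launch_cumsum_sketch) are illustrative only.

```cpp
#include <cuda_runtime.h>

// Sketch: compile-time multi_block dispatch. The kernel body is a placeholder;
// the real kernel would keep its existing arguments minus the runtime flag.
template <bool MULTI_BLOCK>
__global__ void cumsum_kernel_sketch(const int* in, int* out, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) {
    return;
  }
  if constexpr (MULTI_BLOCK) {
    // Cross-block path: the block_flags/block_sums handshake would live here.
    out[i] = in[i];
  } else {
    // Single-block fast path: no inter-block synchronization is needed.
    out[i] = in[i];
  }
}

void launch_cumsum_sketch(const int* in, int* out, int n, cudaStream_t stream) {
  constexpr int kThreads = 256;
  const int grid = (n + kThreads - 1) / kThreads;
  if (grid > 1) {
    cumsum_kernel_sketch<true><<<grid, kThreads, 0, stream>>>(in, out, n);
  } else {
    cumsum_kernel_sketch<false><<<1, kThreads, 0, stream>>>(in, out, n);
  }
}
```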

if (num_entries_per_block > 0) {
BlockScan(bs_temp_storage).InclusiveSum(local_data, local_data);
}
if (base_entry < num_entries) {
#pragma unroll
for (int i = 0; i < ENTRIES_PER_THREAD; ++i) {
const int entry = base_entry + i;
if (entry < num_entries) {
output_cumsum[entry] = local_data[i];
}
}
}
return;
}

if (num_entries_per_block > 0) {
inclusive_sum_scan_kernel<acc_t, ENTRIES_PER_THREAD, NUM_THREADS_PER_BLOCK>(
local_data,
bs_temp_storage,
block_flags,
block_sums,
&block_prefix,
num_entries_per_block,
blockIdx.x,
multi_block,
1);
}

if (base_entry < num_entries) {
#pragma unroll
for (int i = 0; i < ENTRIES_PER_THREAD; ++i) {
const int entry = base_entry + i;
if (entry < num_entries) {
output_cumsum[entry] = local_data[i];
}
}
}
#else
// CUDA path
__shared__ acc_t smem[MAX_ENTRIES_PER_BLOCK];
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
const int output_batch_size = indices.size(0);
@@ -65,6 +142,7 @@ __global__ void index_select_scalar_cumsum_kernel(
if (tid < num_batches * output_batch_size) {
output_cumsum[tid] = *local_data;
}
#endif
}

template <
@@ -183,62 +261,151 @@ class KeyedJaggedIndexSelectDim1GPUOp
const int num_batches = lengths.numel() / batch_size;
const int num_output_lengths = num_batches * indices.numel();
const int MAX_CUMSUM_ENTRIES_PER_BLOCK = 256;
#ifdef USE_ROCM
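// Select the vectorization factor: the largest value in {4, 2, 1} that evenly
// divides the number of selected indices per batch.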
const int num_entries_per_thread[] = {4, 2, 1};
int entries_per_thread = 1;
for (int i : num_entries_per_thread) {
if (indices.numel() % i == 0) {
entries_per_thread = i;
break;
}
}
#else
constexpr int ENTRIES_PER_THREAD = 1;
auto grid_size = cuda_calc_xblock_count(
num_output_lengths, MAX_CUMSUM_ENTRIES_PER_BLOCK);
#endif

Tensor output_offsets =
at::empty({num_batches * indices.numel()}, offsets.options());
Tensor output_lengths =
at::empty({num_batches * indices.numel()}, lengths.options());

Tensor block_flags, block_sums;
if (grid_size > 1) {
block_flags = at::zeros({grid_size}, lengths.options().dtype(at::kInt));
block_sums = at::empty({grid_size}, output_offsets.options());
}

// Do index select and cumsum
AT_DISPATCH_INDEX_TYPES(
lengths.scalar_type(), "index_select_scalar_cumsum_wrapper_1", [&] {
using length_t = index_t;
AT_DISPATCH_INDEX_TYPES(
offsets.scalar_type(),
"index_select_scalar_cumsum_wrapper_2",
[&] {
using offset_t = index_t;
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(),
"index_select_scalar_cumsum_wrapper_3",
[&] {
FBGEMM_LAUNCH_KERNEL(
(index_select_scalar_cumsum_kernel<
length_t,
index_t,
offset_t,
MAX_CUMSUM_ENTRIES_PER_BLOCK,
MAX_CUMSUM_ENTRIES_PER_BLOCK>),
grid_size,
MAX_CUMSUM_ENTRIES_PER_BLOCK,
0,
at::cuda::getCurrentCUDAStream(),
PTA_B(output_lengths, length_t, 1, 32),
PTA_B(output_offsets, offset_t, 1, 32),
PTA_B(lengths, length_t, 1, 32),
PTA_B(indices, index_t, 1, 32),
num_batches,
batch_size,
num_output_lengths -
MAX_CUMSUM_ENTRIES_PER_BLOCK * (grid_size - 1),
grid_size > 1 ? block_flags.data_ptr<int>() : nullptr,
grid_size > 1 ? block_sums.data_ptr<offset_t>()
: nullptr);
});
});
});
#ifdef USE_ROCM
// ROCm path
auto dispatch_cumsum = [&](auto vec_tag) {
constexpr int ENTRIES_PER_THREAD = decltype(vec_tag)::value;
constexpr int ENTRIES_PER_BLOCK =
MAX_CUMSUM_ENTRIES_PER_BLOCK * ENTRIES_PER_THREAD;
const auto rocm_grid_size =
(num_output_lengths + ENTRIES_PER_BLOCK - 1) / ENTRIES_PER_BLOCK;

if (rocm_grid_size == 0)
return;

if (rocm_grid_size > 1) {
block_flags = at::zeros({rocm_grid_size}, lengths.options().dtype(at::kInt));
block_sums = at::empty({rocm_grid_size}, output_offsets.options());
}

AT_DISPATCH_INDEX_TYPES(
lengths.scalar_type(), "index_select_scalar_cumsum_wrapper_1", [&] {
using length_t = index_t;
AT_DISPATCH_INDEX_TYPES(
offsets.scalar_type(),
"index_select_scalar_cumsum_wrapper_2",
[&] {
using offset_t = index_t;
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(),
"index_select_scalar_cumsum_wrapper_3",
[&] {
FBGEMM_LAUNCH_KERNEL(
(index_select_scalar_cumsum_kernel<
length_t,
index_t,
offset_t,
MAX_CUMSUM_ENTRIES_PER_BLOCK,
ENTRIES_PER_BLOCK,
ENTRIES_PER_THREAD>),
rocm_grid_size,
MAX_CUMSUM_ENTRIES_PER_BLOCK,
0,
at::cuda::getCurrentCUDAStream(),
PTA_B(output_lengths, length_t, 1, 32),
PTA_B(output_offsets, offset_t, 1, 32),
PTA_B(lengths, length_t, 1, 32),
PTA_B(indices, index_t, 1, 32),
num_batches,
batch_size,
num_output_lengths -
ENTRIES_PER_BLOCK * (rocm_grid_size - 1),
rocm_grid_size > 1
? block_flags.data_ptr<int>()
: nullptr,
rocm_grid_size > 1
? block_sums.data_ptr<offset_t>()
: nullptr);
});
});
});
};

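// entries_per_thread is a runtime value; map it onto a compile-time constant
// so the kernel template can be instantiated for each supported factor.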
switch (entries_per_thread) {
case 4:
dispatch_cumsum(std::integral_constant<int, 4>{});
break;
case 2:
dispatch_cumsum(std::integral_constant<int, 2>{});
break;
default:
dispatch_cumsum(std::integral_constant<int, 1>{});
break;
}
#else
// CUDA path
if (grid_size > 1) {
block_flags = at::zeros({grid_size}, lengths.options().dtype(at::kInt));
block_sums = at::empty({grid_size}, output_offsets.options());
}

AT_DISPATCH_INDEX_TYPES(
lengths.scalar_type(), "index_select_scalar_cumsum_wrapper_1", [&] {
using length_t = index_t;
AT_DISPATCH_INDEX_TYPES(
offsets.scalar_type(),
"index_select_scalar_cumsum_wrapper_2",
[&] {
using offset_t = index_t;
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(),
"index_select_scalar_cumsum_wrapper_3",
[&] {
FBGEMM_LAUNCH_KERNEL(
(index_select_scalar_cumsum_kernel<
length_t,
index_t,
offset_t,
MAX_CUMSUM_ENTRIES_PER_BLOCK,
MAX_CUMSUM_ENTRIES_PER_BLOCK,
ENTRIES_PER_THREAD>),
grid_size,
MAX_CUMSUM_ENTRIES_PER_BLOCK,
0,
at::cuda::getCurrentCUDAStream(),
PTA_B(output_lengths, length_t, 1, 32),
PTA_B(output_offsets, offset_t, 1, 32),
PTA_B(lengths, length_t, 1, 32),
PTA_B(indices, index_t, 1, 32),
num_batches,
batch_size,
num_output_lengths -
MAX_CUMSUM_ENTRIES_PER_BLOCK * (grid_size - 1),
grid_size > 1
? block_flags.data_ptr<int>()
: nullptr,
grid_size > 1
? block_sums.data_ptr<offset_t>()
: nullptr);
});
});
});
#endif

const int64_t num_outputs = (selected_lengths_sum.has_value())
? selected_lengths_sum.value().guard_int(__FILE__, __LINE__)
: output_offsets[output_offsets.numel() - 1].item<int64_t>();
Tensor output = at::empty({num_outputs}, values.options());
Tensor output_weights;
if (weights.has_value()) {