Skip to content
2 changes: 2 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -1090,6 +1090,8 @@ void group_index_select_or_add_cuda(

int get_group_index_select_cols_per_warp();

int get_group_index_select_unroll_factor();

std::vector<at::Tensor> jagged_index_select_2d(
const at::Tensor& values,
const at::Tensor& lengths,
Expand Down
92 changes: 74 additions & 18 deletions fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ int get_group_index_select_cols_per_warp() {
return GROUP_INDEX_SELECT_COLS_PER_WARP;
}

// Returns GROUP_INDEX_SELECT_UNROLL_FACTOR, the compile-time per-thread unroll
// factor used by the group_index_select_or_add_2d kernel. Exposed as a host
// accessor so launch-configuration code (e.g. the warps-needed computation in
// group_index_select_dim0_forward_impl_gpu) can stay consistent with the
// device-side constant.
int get_group_index_select_unroll_factor() {
return GROUP_INDEX_SELECT_UNROLL_FACTOR;
}

template <
typename index_t,
typename scalar_t,
Expand Down Expand Up @@ -82,28 +86,80 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
// All columns are the same
member_id = warp_id / (warps_per_row * num_work_rows);
member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
#ifdef USE_ROCM
if (num_cols < COLS_PER_WARP && num_cols >= UNROLL_FACTOR) {
// Need to ensure that [member_id] and [member_warp_id] are calculated correctly
// for the small embedding dimension path below
const auto rows_per_warp = COLS_PER_WARP / num_cols;
const auto warps_per_member = (num_work_rows + rows_per_warp - 1) / rows_per_warp;
member_id = warp_id / warps_per_member;
member_warp_id = warp_id % warps_per_member;
}
#endif // USE_ROCM
}
const auto row = member_warp_id / warps_per_row;
const auto col_offset =
((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) +
(threadIdx.x * UNROLL_FACTOR);
scalar_t* input =
reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
scalar_t* output =
reinterpret_cast<scalar_t*>(output_ptrs[member_id]) + col_offset;

index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
const index_t idx = indices[row];

#ifdef USE_ROCM
if (num_cols < COLS_PER_WARP && num_cols >= UNROLL_FACTOR) {
// Optimized path for small embedding dimensions
// Each warp processes 'rows_per_warp' rows
const auto rows_per_warp = COLS_PER_WARP / num_cols;
const int64_t start_row = member_warp_id * rows_per_warp;

// Since we are processing multiple rows within the warp, we need to
// map each lane to a specific row, in addition to the column
const auto local_row = (threadIdx.x * UNROLL_FACTOR) / num_cols; // the row ID within the set of rows handled by this warp
const auto col_offset = (threadIdx.x * UNROLL_FACTOR) % num_cols;
const int64_t current_row = start_row + local_row; // the actual row within the table processed by this lane

// local_row may be out of bounds for the last few lanes in the warp if [COLS_PER_WARP % num_cols != 0]
// and we also need to confirm that we are within num_work_rows
if (local_row < rows_per_warp && current_row < num_work_rows) {
scalar_t* input =
reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
scalar_t* output =
reinterpret_cast<scalar_t*>(output_ptrs[member_id]) + col_offset;

index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
const index_t idx = indices[current_row];
#pragma unroll
for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
// Compile time conditional
if constexpr (USE_INDEX_SELECT) {
output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
} else {
gpuAtomicAddNoReturn(
&output[idx * num_cols + i], input[row * num_cols + i]);
for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
// Compile time conditional
if constexpr (USE_INDEX_SELECT) {
output[current_row * num_cols + i] = LDG(&input[idx * num_cols + i]);
} else {
gpuAtomicAddNoReturn(
&output[idx * num_cols + i], input[current_row * num_cols + i]);
}
}
}
} else {
// Large embedding dimensions use >= 1 warp per row
// which is the default codepath for non-ROCm as well
#endif // USE_ROCM
const auto row = member_warp_id / warps_per_row;
const auto col_offset =
((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) +
(threadIdx.x * UNROLL_FACTOR);
scalar_t* input =
reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
scalar_t* output =
reinterpret_cast<scalar_t*>(output_ptrs[member_id]) + col_offset;

index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
const index_t idx = indices[row];
#pragma unroll
for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
// Compile time conditional
if constexpr (USE_INDEX_SELECT) {
output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
} else {
gpuAtomicAddNoReturn(
&output[idx * num_cols + i], input[row * num_cols + i]);
}
}
#ifdef USE_ROCM
}
#endif // USE_ROCM
}
}

Expand Down
20 changes: 20 additions & 0 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
Tensor input_reshaped = first_input.reshape({num_input_rows, -1});
const int num_cols = input_reshaped.size(1);
const int cols_per_warp = get_group_index_select_cols_per_warp();
const int unroll_factor = get_group_index_select_unroll_factor();
int64_t warp_offset = 0;
bool use_var_cols = false;

Expand Down Expand Up @@ -303,7 +304,22 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(

// Number of columns can be different
auto num_cols_ = input_reshaped_.size(1);

#ifdef USE_ROCM
int64_t warps_needed;
if (num_cols_ < cols_per_warp && num_cols_ >= unroll_factor) {
// Optimization: Pack multiple rows into one warp
int rows_per_warp = cols_per_warp / num_cols_;
warps_needed = (num_output_rows_ + rows_per_warp - 1) / rows_per_warp;
} else {
// Standard: One or more warps per row
int warps_per_row = (num_cols_ + cols_per_warp - 1) / cols_per_warp;
warps_needed = warps_per_row * num_output_rows_;
}
#else
// Standard: One or more warps per row
auto warps_per_row = (num_cols_ + cols_per_warp - 1) / cols_per_warp;
#endif // USE_ROCM

if (num_cols != num_cols_) {
use_var_cols = true;
Expand All @@ -329,7 +345,11 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
warp_offsets_group[i] = warp_offset;
num_cols_group[i] = num_cols_;

#ifdef USE_ROCM
warp_offset += warps_needed;
#else
warp_offset += warps_per_row * num_output_rows;
#endif // USE_ROCM
}

// Store the last offset
Expand Down