Skip to content
Open
1 change: 1 addition & 0 deletions fbgemm_gpu/FbgemmGpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ if(NOT FBGEMM_BUILD_VARIANT STREQUAL BUILD_VARIANT_CPU)
src/quantize_ops/quantize_msfp.cu
src/quantize_ops/quantize_padded_fp8_rowwise.cu
src/quantize_ops/quantize_mx.cu
src/sparse_ops/utils/rocm/sparse_group_utils.cu
src/sparse_ops/sparse_async_batched_cumsum.cu
src/sparse_ops/sparse_block_bucketize_features.cu
src/sparse_ops/sparse_bucketize_features.cu
Expand Down
7 changes: 6 additions & 1 deletion fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -1078,6 +1078,8 @@ void group_index_select_or_add_cuda(
const int64_t* input_ptrs,
const int64_t* output_ptrs,
const int64_t* indices_ptrs,
const int64_t* sorted_indices_ptrs,
const int64_t* reverse_indices_ptrs,
const int64_t* warp_offsets_group,
const int32_t* num_cols_group,
const c10::ScalarType& input_scalar_type,
Expand All @@ -1087,7 +1089,10 @@ void group_index_select_or_add_cuda(
const int64_t total_num_warps,
const int group_size,
const bool use_index_select,
const bool use_var_cols);
const bool use_var_cols,
const bool use_contiguous_warps,
const bool use_cache,
const bool use_packed_rows);

int get_group_index_select_cols_per_warp();

Expand Down
63 changes: 62 additions & 1 deletion fbgemm_gpu/include/fbgemm_gpu/utils/rocm/sparse_group_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,26 @@
#include <cstdint>
#include <limits>

#include <ATen/Dispatch.h>
#include <ATen/ATen.h>

#include <hip/hip_runtime.h>
#include <rocprim/device/device_radix_sort.hpp>
#include <rocprim/device/device_segmented_radix_sort.hpp>

#include "fbgemm_gpu/utils/cuda_prelude.cuh"
#include "fbgemm_gpu/utils/function_types.h"

namespace fbgemm_gpu::rocm {
// Selected empirically: rocprim uses merge sort when num_items < this threshold,
// which is faster for small inputs. Must match across sizing and sort calls.
constexpr unsigned int k_sort_merge_threshold = 400'000;
// Shared radix-sort configuration: default configs with only the merge-sort
// size threshold overridden, so the temp-storage sizing helpers and the
// actual sort calls below all agree on the small-input fallback cutoff.
// NOTE(review): the meaning of each template slot depends on the rocprim
// version in use — confirm the parameter order against the pinned headers.
using sort_config = rocprim::radix_sort_config<
rocprim::default_config,
rocprim::default_config,
rocprim::default_config,
k_sort_merge_threshold>;

namespace {

template <typename scalar_t, int kLogicalWarpSize = kWarpSize>
Expand Down Expand Up @@ -67,6 +84,50 @@ __device__ __forceinline__ void warp_upper_bound(
*found = result;
*cached_boundary = cached_result;
}

} // namespace

// Returns temp storage size for a single-segment sort of num_items elements.
//
// num_items   : element count of the (single) segment to be sorted
// scalar_type : element type of the keys being sorted (drives dispatch)
// stream      : HIP stream the eventual sort will run on
//
// The returned size must be used with sorts issued via the same sort_config,
// so the merge-sort threshold stays consistent between sizing and sorting.
// NOTE(review): presumably implemented by a rocprim size query (null
// temp-storage pointer call) — confirm against the .cu implementation.
size_t get_sort_temp_storage_bytes(
const size_t num_items,
const c10::ScalarType scalar_type,
const at::cuda::CUDAStream& stream);
// Returns temp storage size for segmented sort of num_groups segments each
// with num_items_per_segment elements.
//
// num_items_per_segment : element count of every segment (uniform per group)
// num_groups            : number of segments sorted in one call
// scalar_type           : element type of the keys being sorted
// stream                : HIP stream the eventual segmented sort will run on
//
// Pair this with sort_indices_segmented_rocprim(), which consumes a buffer
// of exactly this size as its temp_storage argument.
size_t get_segmented_sort_temp_storage_bytes(
const size_t num_items_per_segment,
const int64_t num_groups,
const c10::ScalarType scalar_type,
const at::cuda::CUDAStream& stream);
// Sort all groups' indices with one rocprim::segmented_radix_sort_pairs call,
// eliminating all per-group CPU launch overhead.
//
// Inputs must be contiguous across groups:
// all_keys_in : [num_groups * num_items_per_segment] — packed input indices
// all_values_in : [num_groups * num_items_per_segment] — tiled 0..N-1 per segment
// segment_offsets: [num_groups + 1] device tensor — [0, N, 2N, ..., K*N]
// all_keys_out / all_values_out: pre-allocated output buffers (same shape)
// temp_storage : pre-allocated via get_segmented_sort_temp_storage_bytes()
// num_items_per_segment / num_groups: must match the sizes the temp storage
//   and segment_offsets were built for
// stream : HIP stream the segmented sort is enqueued on (call is async
//   with respect to the host; synchronize on this stream before reading
//   the outputs)
void sort_indices_segmented_rocprim(
const at::Tensor& all_keys_in,
at::Tensor& all_keys_out,
const at::Tensor& all_values_in,
at::Tensor& all_values_out,
const at::Tensor& segment_offsets,
const size_t num_items_per_segment,
const int64_t num_groups,
at::Tensor& temp_storage,
const at::cuda::CUDAStream& stream);
// Sort all groups in a batch with one AT_DISPATCH and one stream lookup.
// Uses radix_sort_pairs<sort_config> per group, preserving the merge sort
// fallback for small segment sizes (num_items < k_sort_merge_threshold).
//
// keys_in_ptrs    : per-group table of device pointers to each group's keys
//   — NOTE(review): assumed from the plural "_ptrs" naming and the scalar
//   num_items; confirm layout against the .cu implementation.
// keys_out_base   : base of the packed output keys buffer (void* because the
//   key element type is selected at runtime via scalar_type)
// values_out_base : base of the packed output values buffer
// values_in       : input values shared across groups — presumably an iota
//   0..num_items-1 reused per group; verify against callers.
// num_items       : element count per group (uniform across the batch)
// num_groups      : number of independent per-group sorts issued
// temp_storage    : pre-allocated via get_sort_temp_storage_bytes()
// scalar_type     : runtime key element type, dispatched once for the batch
// stream          : HIP stream all per-group sorts are enqueued on
void sort_indices_batch_rocprim(
const int64_t* keys_in_ptrs,
void* keys_out_base,
int64_t* values_out_base,
const int64_t* values_in,
const size_t num_items,
const int64_t num_groups,
at::Tensor& temp_storage,
const c10::ScalarType scalar_type,
const at::cuda::CUDAStream& stream);
} // namespace fbgemm_gpu::rocm
Loading