[tx] Implement cutlass kernel for ragged_dot with group_offset #896
base: main
Conversation
Code Review
This pull request introduces a custom CUTLASS kernel for ragged_dot with group_offset to improve performance, which is a great addition. The implementation includes the CUDA kernel, Python FFI bindings, and integration into the existing ragged_dot utility.
My review has found a few issues:
- A critical bug in the CUDA kernel's pointer arithmetic that will lead to incorrect results.
- A potential thread-safety issue in the device property caching function.
- An error in the build instructions in the README file.
I've provided detailed comments and suggestions for each of these points. Once these are addressed, this will be a solid performance enhancement.
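For context, here is a rough reference of what ragged_dot with group_offset computes, written as a plain C++ loop. The shapes (lhs [m_total, k], rhs [g_local, k, n], both row-major, with group_offset naming the first global group covered by the local rhs shard) and the output row indexing are inferred from the review comments below rather than spelled out by the PR, and the function and parameter names are illustrative only.

```
#include <cstdint>

// Assumed reference semantics (not taken verbatim from the PR): each global
// group g owns group_sizes[g] consecutive rows of lhs; the local rhs shard
// holds the [k, n] matrices for groups group_offset .. group_offset+g_local-1.
void ragged_dot_reference(const float* lhs, const float* rhs,
                          const int32_t* group_sizes, int32_t group_offset,
                          int32_t g_local, int64_t k, int64_t n, float* out) {
  int64_t row = 0;
  for (int32_t g = 0; g < group_offset; ++g) row += group_sizes[g];  // skip groups before the shard
  for (int32_t g = 0; g < g_local; ++g) {
    const int64_t m = group_sizes[group_offset + g];
    const float* A = lhs + row * k;                           // this group's rows of lhs
    const float* B = rhs + static_cast<int64_t>(g) * k * n;   // this group's [k, n] matrix
    for (int64_t i = 0; i < m; ++i)
      for (int64_t j = 0; j < n; ++j) {
        float acc = 0.0f;
        for (int64_t kk = 0; kk < k; ++kk) acc += A[i * k + kk] * B[kk * n + j];
        out[(row + i) * n + j] = acc;
      }
    row += m;
  }
}
```

Under that reading, the kernel's A_ptrs/B_ptrs arrays hold one pointer per local group (the start of the group's lhs row block and of its rhs matrix), which is the usual way a CUTLASS grouped GEMM is driven, replacing the loops above with a single launch.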
skyrl-tx/tx/ffi/ragged_dot_ffi.cu
Outdated
```
int32_t m = group_offsets_cumsum[global] - start;

A_ptrs[tid] = A + static_cast<int64_t>(start) * k;
B_ptrs[tid] = B + static_cast<int64_t>(tid) * n * k;
```
There appears to be a bug in the pointer arithmetic for B_ptrs. The rhs tensor has dimensions [g_local, k, n]. To get the pointer to the tid-th matrix, the offset should be tid * k * n. The current code uses tid * n * k, which will lead to incorrect memory access and wrong results.
```
B_ptrs[tid] = B + static_cast<int64_t>(tid) * k * n;
```
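For reference, a minimal sketch of the row-major indexing this comment relies on; `rhs_offset` is an illustrative helper, not part of the PR:

```
#include <cstdint>

// For a row-major tensor of shape [g_local, k, n], element rhs[g][i][j] sits
// at this linear offset, so the g-th [k, n] matrix starts g * k * n elements
// past the base pointer.
inline int64_t rhs_offset(int64_t g, int64_t i, int64_t j, int64_t k, int64_t n) {
  return g * k * n + i * n + j;
}
```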
skyrl-tx/tx/ffi/README.md
Outdated
1) Build the shared library (requires CUDA nvcc with C++17 support):
```
export CUTLASS_DIR=/path/to/cutlass
export NVCC_ARCH=sm_90a # for H100, adjust if needed
```
The instruction for setting NVCC_ARCH is inconsistent with how it's used in the build script. The script prepends sm_ to the NVCC_ARCH variable. Therefore, you should set NVCC_ARCH to the architecture number (e.g., 90a) instead of the full sm_ string (e.g., sm_90a).
```diff
-export NVCC_ARCH=sm_90a # for H100, adjust if needed
+export NVCC_ARCH=90a # for H100, adjust if needed
```
```
static int get_sm_count() {
  int device = 0;
  if (cudaGetDevice(&device) != cudaSuccess || device < 0) {
    return 0;
  }
  if (static_cast<size_t>(device) >= g_device_props.size()) {
    g_device_props.resize(device + 1);
  }
  cudaDeviceProp& props = g_device_props[device];
  if (!props.multiProcessorCount) {
    cudaGetDeviceProperties(&props, device);
  }
  return props.multiProcessorCount;
}
```
The get_sm_count function uses a global static std::vector (g_device_props) to cache device properties. This is not thread-safe. If multiple host threads call this function concurrently, it could lead to a race condition when resizing the vector. Using a static local mutex will make it thread-safe.
Please also add #include <mutex> at the top of the file.
```
static int get_sm_count() {
  static std::mutex mtx;
  std::lock_guard<std::mutex> lock(mtx);
  int device = 0;
  if (cudaGetDevice(&device) != cudaSuccess || device < 0) {
    return 0;
  }
  if (static_cast<size_t>(device) >= g_device_props.size()) {
    g_device_props.resize(device + 1);
  }
  cudaDeviceProp& props = g_device_props[device];
  if (!props.multiProcessorCount) {
    cudaGetDeviceProperties(&props, device);
  }
  return props.multiProcessorCount;
}
```
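A lighter-weight alternative, sketched here only as a design note rather than a suggested change: if the FFI call only ever queries the device that is current the first time it runs, the SM count can be cached in a function-local static, whose initialization C++11 guarantees to be thread-safe, so no per-call lock is needed. `get_sm_count_once` is a hypothetical name.

```
static int get_sm_count_once() {
  // C++11 "magic statics": the lambda runs exactly once, even with
  // concurrent callers; later calls just read the cached value.
  static const int sm_count = [] {
    int device = 0;
    if (cudaGetDevice(&device) != cudaSuccess || device < 0) return 0;
    cudaDeviceProp props{};
    if (cudaGetDeviceProperties(&props, device) != cudaSuccess) return 0;
    return props.multiProcessorCount;
  }();
  return sm_count;
}
```

The trade-off is that this bakes in whichever device is current on first use, whereas the mutex version above handles multiple devices.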
This brings down the step time from 40s to 30s.