
Conversation

@pcmoritz
Collaborator

This brings down the step time of

```
uv run --with wandb --with tinker==0.3.0 sl_loop.py base_url=http://localhost:8000 model_name=Qwen/Qwen3-30B-A3B lora_rank=1 max_length=512
```

running against

```
uv run --extra gpu --extra tinker -m tx.tinker.api --base-model Qwen/Qwen3-30B-A3B --backend-config '{"max_lora_adapters": 2, "max_lora_rank": 1, "expert_parallel_size": 8, "train_micro_batch_size": 1, "shard_attention_heads": false}'
```

from 40s to 30s.

@pcmoritz added the tx label on Jan 19, 2026
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces a custom CUTLASS kernel for ragged_dot with group_offset to improve performance, which is a great addition. The implementation includes the CUDA kernel, Python FFI bindings, and integration into the existing ragged_dot utility.

My review has found a few issues:

  • A critical bug in the CUDA kernel's pointer arithmetic that will lead to incorrect results.
  • A potential thread-safety issue in the device property caching function.
  • An error in the build instructions in the README file.

I've provided detailed comments and suggestions for each of these points. Once these are addressed, this will be a solid performance enhancement.
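
For context on the operation being optimized, here is a minimal plain-C++ reference loop for ragged_dot with a group offset. The semantics are inferred from the snippets quoted below and are assumptions, not the PR's CUTLASS implementation: lhs is assumed row-major [m_total, k] with the rows of all groups concatenated, rhs is assumed row-major [g_local, k, n] holding only the locally sharded groups, and group_offsets_cumsum is assumed to hold global cumulative row counts.

```
#include <cstdint>

// Hypothetical reference semantics for ragged_dot with group_offset.
// Assumptions only -- not the CUTLASS kernel added in this PR.
// lhs:  [m_total, k], rows of all groups concatenated (row-major)
// rhs:  [g_local, k, n], weight matrices for the locally held groups (row-major)
// group_offsets_cumsum: global cumulative row counts, one entry per group
// out:  [m_total, n]
void ragged_dot_reference(const float* lhs, const float* rhs,
                          const int32_t* group_offsets_cumsum,
                          int g_local, int group_offset,
                          int64_t k, int64_t n, float* out) {
  for (int g = 0; g < g_local; ++g) {
    int global = g + group_offset;
    int32_t start = (global == 0) ? 0 : group_offsets_cumsum[global - 1];
    int32_t end = group_offsets_cumsum[global];
    const float* B = rhs + static_cast<int64_t>(g) * k * n;  // g-th [k, n] block
    for (int32_t row = start; row < end; ++row) {
      for (int64_t j = 0; j < n; ++j) {
        float acc = 0.0f;
        for (int64_t kk = 0; kk < k; ++kk) {
          acc += lhs[row * k + kk] * B[kk * n + j];
        }
        out[static_cast<int64_t>(row) * n + j] = acc;
      }
    }
  }
}
```

The PR's kernel presumably replaces these inner loops with tiled CUTLASS GEMMs; the loop above is only meant to pin down the assumed input and output layouts.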

```
int32_t m = group_offsets_cumsum[global] - start;

A_ptrs[tid] = A + static_cast<int64_t>(start) * k;
B_ptrs[tid] = B + static_cast<int64_t>(tid) * n * k;
```

critical

There appears to be a bug in the pointer arithmetic for B_ptrs. The rhs tensor has dimensions [g_local, k, n]. To get the pointer to the tid-th matrix, the offset should be tid * k * n. The current code uses tid * n * k, which will lead to incorrect memory access and wrong results.

```
B_ptrs[tid] = B + static_cast<int64_t>(tid) * k * n;
```
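
For reference, grouped GEMMs in the CUTLASS style typically consume arrays of per-group pointers and problem sizes. A simplified, hypothetical version of the pointer-setup kernel quoted above (the parameter names, the group_offset handling, and the float dtype are assumptions, not the PR's actual code), with the [g_local, k, n] offset written as k * n:

```
#include <cstdint>

// Hypothetical, simplified pointer-setup kernel for a grouped GEMM.
// Assumes lhs is [m_total, k] (row-major), rhs is [g_local, k, n] (row-major),
// and group_offsets_cumsum holds global cumulative row counts per group.
__global__ void fill_group_pointers(const float* A, const float* B,
                                    const int32_t* group_offsets_cumsum,
                                    int g_local, int group_offset,
                                    int64_t k, int64_t n,
                                    const float** A_ptrs, const float** B_ptrs,
                                    int32_t* m_per_group) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= g_local) return;
  int global = tid + group_offset;
  int32_t start = (global == 0) ? 0 : group_offsets_cumsum[global - 1];
  // Rows owned by this group.
  m_per_group[tid] = group_offsets_cumsum[global] - start;
  // Group tid's rows start at row `start` of the ragged lhs.
  A_ptrs[tid] = A + static_cast<int64_t>(start) * k;
  // Group tid's weight matrix is the tid-th contiguous [k, n] block of rhs.
  B_ptrs[tid] = B + static_cast<int64_t>(tid) * k * n;
}
```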

1) Build the shared library (requires CUDA nvcc with C++17 support):
```
export CUTLASS_DIR=/path/to/cutlass
export NVCC_ARCH=sm_90a # for H100, adjust if needed
```

medium

The instruction for setting NVCC_ARCH is inconsistent with how it's used in the build script. The script prepends sm_ to the NVCC_ARCH variable. Therefore, you should set NVCC_ARCH to the architecture number (e.g., 90a) instead of the full sm_ string (e.g., sm_90a).

Suggested change:

```
-export NVCC_ARCH=sm_90a # for H100, adjust if needed
+export NVCC_ARCH=90a # for H100, adjust if needed
```

Comment on lines +31 to +44
```
static int get_sm_count() {
  int device = 0;
  if (cudaGetDevice(&device) != cudaSuccess || device < 0) {
    return 0;
  }
  if (static_cast<size_t>(device) >= g_device_props.size()) {
    g_device_props.resize(device + 1);
  }
  cudaDeviceProp& props = g_device_props[device];
  if (!props.multiProcessorCount) {
    cudaGetDeviceProperties(&props, device);
  }
  return props.multiProcessorCount;
}
```

medium

The get_sm_count function uses a global static std::vector (g_device_props) to cache device properties. This is not thread-safe. If multiple host threads call this function concurrently, it could lead to a race condition when resizing the vector. Using a static local mutex will make it thread-safe.

Please also add #include <mutex> at the top of the file.

```
static int get_sm_count() {
  static std::mutex mtx;
  std::lock_guard<std::mutex> lock(mtx);

  int device = 0;
  if (cudaGetDevice(&device) != cudaSuccess || device < 0) {
    return 0;
  }
  if (static_cast<size_t>(device) >= g_device_props.size()) {
    g_device_props.resize(device + 1);
  }
  cudaDeviceProp& props = g_device_props[device];
  if (!props.multiProcessorCount) {
    cudaGetDeviceProperties(&props, device);
  }
  return props.multiProcessorCount;
}
```
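
An alternative sketch that avoids locking a mutex on every call is to cache the per-device SM count behind std::call_once (kMaxDevices and the zero fallback here are assumptions, not part of this PR):

```
#include <array>
#include <mutex>
#include <cuda_runtime.h>

// Alternative sketch: per-device caching with std::call_once, so the common
// path after initialization does not take a lock.
static int get_sm_count() {
  constexpr int kMaxDevices = 64;                        // assumed upper bound
  static std::array<std::once_flag, kMaxDevices> flags;
  static std::array<int, kMaxDevices> sm_counts{};       // zero-initialized

  int device = 0;
  if (cudaGetDevice(&device) != cudaSuccess || device < 0 || device >= kMaxDevices) {
    return 0;
  }
  std::call_once(flags[device], [device] {
    cudaDeviceProp props{};
    if (cudaGetDeviceProperties(&props, device) == cudaSuccess) {
      sm_counts[device] = props.multiProcessorCount;
    }
  });
  return sm_counts[device];
}
```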
