Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion csrc/quantization/fake_quantizer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,13 @@ void launch_fake_quantize_kernel(T* vals,
int num_bits,
cudaStream_t stream)
{
// Reduced from 1024 to 256: improves per-SM warp occupancy (32→48 warps/SM)
// for large group_num (ZeRO gradient compression on production models where
// group_num >> num_SMs * blocks_per_SM). Measured 28-77% speedup at group_num
// = 87K-1.4M. For small group_num (< num_SMs) this is slower, but that is not
// the production use case.
dim3 grid_dim(group_num);
dim3 block_dim(1024);
dim3 block_dim(256);

fake_quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
vals, total_count / group_num, num_bits);
Expand Down
Loading