Changes from all commits
37 commits
ad748da
GEMM reference HIP implementation
matthiasdiener Dec 9, 2025
11e090b
blockwise amax
matthiasdiener Dec 11, 2025
9006224
Merge branch 'dev' into compute-ref-offload
matthiasdiener Dec 18, 2025
3ecea7f
Change to use Tensor arguments, combine mxfp8/non-mxfp8 paths
matthiasdiener Jan 13, 2026
cafee59
Merge remote-tracking branch 'origin/dev' into compute-ref-offload
matthiasdiener Jan 14, 2026
86fbbac
skip on SwizzleScale limitation on gfx950
matthiasdiener Jan 14, 2026
54de3db
Revert "skip on SwizzleScale limitation on gfx950"
matthiasdiener Jan 14, 2026
311ddfe
MXFP8 fix
matthiasdiener Jan 14, 2026
306e432
Merge remote-tracking branch 'origin/dev' into compute-ref-offload
matthiasdiener Jan 15, 2026
445e64f
correct scale_inv packing and exp2(biased−127) conversion
matthiasdiener Jan 15, 2026
462945f
cleanups
matthiasdiener Jan 15, 2026
e32fb3d
Merge branch 'dev' into compute-ref-offload
matthiasdiener Jan 19, 2026
7bf8adb
Merge remote-tracking branch 'origin/dev' into compute-ref-offload
matthiasdiener Jan 22, 2026
e11e400
use Tensor class for more device objects
matthiasdiener Jan 22, 2026
325ece6
Pass D Tensor into run_reference and move RefD allocation into Perfor…
matthiasdiener Jan 23, 2026
fc64b8c
[WIP] proof-of-concept: grouped GEMM with ck_tile
matthiasdiener Jan 26, 2026
134b350
Merge branch 'dev' into ck-grouped-gemm
matthiasdiener Jan 28, 2026
9091e6c
restructure and enable tests
matthiasdiener Jan 29, 2026
7435062
Merge remote-tracking branch 'origin/dev' into ck-grouped-gemm
matthiasdiener Jan 29, 2026
a00a1c8
Merge remote-tracking branch 'origin/dev' into ck-grouped-gemm
matthiasdiener Jan 30, 2026
4e9ead9
grid improvements
matthiasdiener Jan 30, 2026
259645c
restructure
matthiasdiener Feb 3, 2026
9986bd4
reduce code duplication & simplify
matthiasdiener Feb 4, 2026
355ec2f
make the code more similar to nv, check empty gelu/bias
matthiasdiener Feb 4, 2026
df5e3ea
Merge branch 'dev' into ck-grouped-gemm
matthiasdiener Feb 4, 2026
a42f7ca
further simplify & make closer to nv
matthiasdiener Feb 4, 2026
fac7c11
add ck_tile reference
matthiasdiener Feb 4, 2026
71b97e0
rename in error messages
matthiasdiener Feb 4, 2026
dd3ed2f
allow flattened higher-D tensors
matthiasdiener Feb 4, 2026
7b0413e
Merge remote-tracking branch 'origin/dev' into ck-grouped-gemm
matthiasdiener Feb 5, 2026
ebc005f
relax tolerance on gfx942
matthiasdiener Feb 5, 2026
c0bf502
enable more tests
matthiasdiener Feb 5, 2026
0b16287
return early when num_gemms<=0
matthiasdiener Feb 5, 2026
58b34e7
simplify normalization
matthiasdiener Feb 5, 2026
74f229a
Merge remote-tracking branch 'origin/dev' into ck-grouped-gemm
matthiasdiener Feb 10, 2026
e28c801
run hipblaslt for num_gemms==1
matthiasdiener Feb 11, 2026
6151b96
Merge remote-tracking branch 'origin/dev' into ck-grouped-gemm
matthiasdiener Feb 12, 2026
90 changes: 90 additions & 0 deletions gmm2.py
@@ -0,0 +1,90 @@
import os
import time
import torch
import transformer_engine.pytorch as te

torch.manual_seed(0)

os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"
os.environ["NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK"] = "1"

device = "cuda"
dtype = torch.bfloat16

E = 4
K = 1024
N = 2048
m_splits = [128, 64, 0, 256]
M_total = sum(m_splits)

x = torch.randn(M_total, K, device=device, dtype=dtype)

# Timing helper
def bench_cuda(fn, warmup=20, iters=100):
# Warmup
for _ in range(warmup):
fn()
torch.cuda.synchronize()

# Timed
start = time.time()
for _ in range(iters):
fn()
torch.cuda.synchronize()
end = time.time()

avg_ms = (end - start) * 1000.0 / iters
return avg_ms

# TE GroupedLinear
glinear = te.GroupedLinear(E, K, N, bias=False).to(device=device, dtype=dtype)

def te_run():
return glinear(x, m_splits=m_splits)

te_ms = bench_cuda(te_run, warmup=20, iters=100)

# Grab weights for reference path
Ws = [getattr(glinear, f"weight{e}") for e in range(E)] # each [N, K]
W = torch.stack(Ws, dim=0) # [E, N, K]
assert W.shape == (E, N, K), f"Unexpected weight shape: {W.shape}"

# Torch reference (group loop)
offsets = []
off = 0
for m in m_splits:
offsets.append(off)
off += m

y_ref_buf = torch.empty((M_total, N), device=device, dtype=dtype)

def torch_run():
# Fill the preallocated buffer
for e, m in enumerate(m_splits):
if m == 0:
continue
o = offsets[e]
y_ref_buf[o:o+m].copy_(x[o:o+m] @ W[e].transpose(0, 1))
return y_ref_buf

torch_ms = bench_cuda(torch_run, warmup=20, iters=100)

# Compare outputs
y_te = te_run()
y_ref = torch_run().clone()

diff = (y_te.float() - y_ref.float())
max_abs = diff.abs().max().item()
rel = (diff.abs() / (y_ref.float().abs() + 1e-6)).max().item()

print(f"Errors:")
print(f" {y_te.shape=}, {y_ref.shape=}")
print(" max_abs_err:", max_abs)
print(" max_rel_err:", rel)

torch.testing.assert_close(y_te.float(), y_ref.float(), rtol=3e-2, atol=3e-2)

print(f"\nTiming:")
print(f" TE avg: {te_ms:.3f} ms")
print(f" Torch avg: {torch_ms:.3f} ms")
print(f" Speedup: {torch_ms/te_ms:.2f}x (Torch / TE)")
21 changes: 13 additions & 8 deletions tests/pytorch/test_numerics.py
@@ -28,7 +28,7 @@
     is_bf16_compatible,
 )
 if IS_HIP_EXTENSION:
-    from transformer_engine.pytorch.utils import is_mi200, is_mi308
+    from transformer_engine.pytorch.utils import is_mi200, is_mi308, is_mi300_class
 
 from transformer_engine.pytorch import (
     DotProductAttention,
@@ -148,7 +148,7 @@ def rocm_attn_backend() -> tuple[bool, bool, bool]:
 
 use_cutlass_grouped_gemm = [False]
 # Only enable cutlass grouped gemm on Hopper
-if torch.cuda.get_device_capability() == (9, 0):
+if torch.cuda.get_device_capability() == (9, 0) or IS_HIP_EXTENSION:
     use_cutlass_grouped_gemm.append(True)
 
 
@@ -1386,7 +1386,7 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_
 
     if IS_HIP_EXTENSION:
         if dtype not in (torch.float32,) and fuse_wgrad_accumulation and bias:
-            pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.")
+            pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.")
 
     te_linear_ref = Linear(
         config.hidden_size,
@@ -1678,7 +1678,7 @@ def test_layernorm_linear_accuracy_delay_wgrad_compute(
 ):
     if IS_HIP_EXTENSION:
         if dtype not in (torch.float32,) and fuse_wgrad_accumulation and bias:
-            pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.")
+            pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.")
     config = model_configs[model]
 
     ln_linear_ref = LayerNormLinear(
@@ -1892,7 +1892,7 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(
 
     if IS_HIP_EXTENSION:
         if dtype not in (torch.float32,) and fuse_wgrad_accumulation and bias:
-            pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.")
+            pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.")
 
     ln_mlp = LayerNormMLP(
         hidden_size=config.hidden_size,
@@ -2042,7 +2042,7 @@ def test_grouped_linear_accuracy(
 
     if IS_HIP_EXTENSION:
         if dtype not in (torch.float32,) and fuse_wgrad_accumulation and not fp8:
-            pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.")
+            pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.")
     if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
 
@@ -2121,6 +2121,8 @@ def test_grouped_linear_accuracy(
     atol, rtol = 0, 0
     if use_cutlass:
         atol, rtol = 1e-3, 1e-3
+    if IS_HIP_EXTENSION and is_mi300_class():
+        atol, rtol = 3e-2, 3e-2
     if use_triton:
         atol, rtol = get_tolerances(dtype)
     if dtype == torch.float32:
@@ -2131,7 +2133,7 @@
 
 
 @pytest.mark.skipif(
-    torch.cuda.get_device_capability() != (9, 0),
+    torch.cuda.get_device_capability() != (9, 0) and not IS_HIP_EXTENSION,
     reason="Only enable CUTLASS grouped gemm on Hopper",
 )
 @pytest.mark.parametrize("dtype", param_types, ids=str)
@@ -2936,7 +2938,10 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
         # cublas implementation should be bit-wise match
         torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
     else:
-        torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2)
+        if IS_HIP_EXTENSION and is_mi300_class():
+            torch.testing.assert_close(o, o_ref, rtol=2.0e-2, atol=3.0e-2)
+        else:
+            torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2)
 
     if use_cutlass:
         os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)
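The is_mi300_class() gate used in the hunks above comes from transformer_engine.pytorch.utils; its implementation is not shown in this diff. The snippet below is only a hedged sketch of the general pattern of architecture-gated tolerances — the gcnArchName-based check and the _tolerances_for_device helper are illustrative assumptions, not TE's actual code.

# Hypothetical sketch (not code from this PR): pick looser tolerances on
# MI300-class GPUs. Assumes a ROCm build of PyTorch, where device properties
# expose gcnArchName (e.g. "gfx942"); the attribute is absent on CUDA builds.
import torch

def _is_mi300_like() -> bool:
    arch = getattr(torch.cuda.get_device_properties(0), "gcnArchName", "")
    return arch.startswith("gfx94")  # gfx940/941/942 are MI300-class

def _tolerances_for_device(default=(1.5e-2, 1.5e-2), mi300=(2.0e-2, 3.0e-2)):
    rtol, atol = mi300 if _is_mi300_like() else default
    return rtol, atol

rtol, atol = _tolerances_for_device()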
4 changes: 4 additions & 0 deletions transformer_engine/common/CMakeLists.txt
@@ -203,6 +203,7 @@ else()
         fused_attn_rocm/fused_attn_ck.cpp
         fused_attn_rocm/utils.cpp
         gemm/rocm_gemm.cu
+        gemm/ck_grouped_gemm.cpp
         amd_detail/system.cpp)
 
 # process source code files
@@ -251,6 +252,9 @@ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
   else()
     message(FATAL_ERROR "cutlass gemm/cutlass_grouped_gemm.cu kernel required sm 90a")
   endif()
+else()
+  set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/aiter/3rdparty/composable_kernel)
+  target_include_directories(transformer_engine PRIVATE ${CK_ROOT}/include)
 endif() #USE_CUDA
 
 # Configure dependencies