Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
fc18520
Use public API instead of removed private function in `te_llama.py` (…
janekb04 Jun 9, 2025
ddcda1f
Manage dependencies and add missing `einops` req (#1859)
ksivaman Jun 9, 2025
031c6cf
Python 3.12+ support (#1862)
ksivaman Jun 10, 2025
faee0e8
Support Context Parallel for Multi Latent Attention (MLA) (#1729)
yuzhongw-nvidia Jun 10, 2025
aedd7e1
pyproject.toml (#1852)
ksivaman Jun 10, 2025
0efc7da
[PyTorch] Fix backward compatibility for checkpoint loading (#1868)
ksivaman Jun 12, 2025
c293d3a
[PyTorch] Fix typo in GrouppedLinear (#1867)
pggPL Jun 12, 2025
5d01ef2
[JAX] GroupedDense v.2 without dynamic shape (#1721)
phu0ngng Jun 12, 2025
4d4f1ed
Cpu reload double buffer (#1695)
sanandaraj5597 Jun 12, 2025
c3b7c2a
Revert "[JAX] GroupedDense v.2 without dynamic shape" (#1874)
phu0ngng Jun 12, 2025
c9d7f3f
[JAX] GroupedDense v.2 without dynamic shape (#1875)
phu0ngng Jun 12, 2025
40a30a5
[PyTorch] Support L2Normalization basic op -> use for qk_norm (#1864)
negvet Jun 12, 2025
227961e
[JAX] Distinguish the reasons why fp8 / mxfp8 is not supported in uni…
huanghua1994 Jun 12, 2025
ecaf3e2
Fixes for JIT-able grouped_gemm (#1872)
phu0ngng Jun 12, 2025
d90ced7
Add support for overlapping wgrad NCCL AG with dgrad GEMM (#1849)
djns99 Jun 13, 2025
8d4bdbc
Optimize `/ops/fuser.py` by moving computation from `forward` to `__i…
janekb04 Jun 13, 2025
655512c
[PyTorch] Inference mode disables initializing quantized weights with…
timmoon10 Jun 13, 2025
e963e4a
[PyTorch] Add support for FP8 current scaling in operation-based API …
timmoon10 Jun 13, 2025
7b94bd9
[common] Added support of FP4 data type (#1779)
Oleg-Goncharov Jun 13, 2025
71c76b6
Add support for head_dim > 128 (#1797)
cyanguwa Jun 13, 2025
1ddfa0c
[JAX] Add support for Fused Attn MLA head_dim_qk != head_dim_v (#1851)
KshitijLakhani Jun 13, 2025
980c434
Changed VERSION to 2.5.0
ptrendx Jun 13, 2025
efe19c3
[JAX] Grouped GEMM & Dense support MXFP8 and handle empty matrices (#…
huanghua1994 Jun 16, 2025
4a16c2d
[Pytorch] Bugfix in te fusion ce implementation (#1879)
BestJuly Jun 16, 2025
b894f69
[JAX] Fixes for L0_jax_distributed_unittest (#1884)
phu0ngng Jun 17, 2025
82bff47
[JAX] TensorUsage + FP8 GEMM with all layouts handling on BW (#1844)
phu0ngng Jun 18, 2025
9192fb6
[PyTorch] Use FP16 tols for distributed tests with TF32 compute (#1831)
timmoon10 Jun 19, 2025
1e03882
Fix cppunittest test.sh for editable installs (#1869)
jberchtold-nvidia Jun 25, 2025
6f6951e
[PyTorch][MoE] Reduce CPU Overhead By Fuse Torch Empty Calls (#1793)
zhongbozhu Jun 26, 2025
7b9d9a5
[PyTorch|common] Optimize unpadding kernel for FP8 (#1866)
xiaoxi-wangfj Jun 26, 2025
c42614d
[PyTorch Debug] Fix the issue with PP (#1894)
pggPL Jun 26, 2025
968eb0d
[PyTorch Debug] Fixed the empty tensor bug in statistics computation …
pggPL Jun 26, 2025
866953e
[JAX] Use keyword args for jit in_shardings and out_shardings (#1898)
jberchtold-nvidia Jun 26, 2025
8382eed
[PyTorch] Skip KV cache for sm89 and cuDNN < 9.12 (#1895)
cyanguwa Jun 26, 2025
f05f12c
Fix MLA CP Bugs (#1896)
yuzhongw-nvidia Jun 28, 2025
94ac69f
Extended tensor parallelism support: support ETP+TP comm overlap rela…
fanshiqing Mar 12, 2025
c58598f
fix: fix aggregate=True for AG->Wgrad (layout=NT).
fanshiqing Jul 31, 2025
e23a81c
fix: fix aggregate=True with allgather on A tensor for AG->Wgrad (lay…
fanshiqing Jul 31, 2025
9107742
fix ag_gemm shape for A tensor.
fanshiqing Aug 11, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ jobs:
- name: 'Dependencies'
run: |
apt-get update
apt-get install -y git python3.9 pip ninja-build cudnn9-cuda-12
pip install cmake==3.21.0
apt-get install -y git python3.9 pip cudnn9-cuda-12
pip install cmake==3.21.0 pybind11[global] ninja
- name: 'Checkout'
uses: actions/checkout@v3
with:
Expand All @@ -42,8 +42,8 @@ jobs:
- name: 'Dependencies'
run: |
apt-get update
apt-get install -y git python3.9 pip ninja-build cudnn9-cuda-12
pip install cmake torch pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops
apt-get install -y git python3.9 pip cudnn9-cuda-12
pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops
- name: 'Checkout'
uses: actions/checkout@v3
with:
Expand All @@ -62,6 +62,8 @@ jobs:
image: ghcr.io/nvidia/jax:jax
options: --user root
steps:
- name: 'Dependencies'
run: pip install pybind11[global]
- name: 'Checkout'
uses: actions/checkout@v3
with:
Expand All @@ -87,7 +89,7 @@ jobs:
with:
submodules: recursive
- name: 'Build'
run: pip install --no-build-isolation . -v
run: pip install --no-build-isolation . -v --no-deps
env:
NVTE_FRAMEWORK: all
MAX_JOBS: 1
Expand Down
4 changes: 2 additions & 2 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -216,13 +216,13 @@ Alternatively, install directly from the GitHub repository:

.. code-block:: bash

pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
pip install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@stable

When installing from GitHub, you can explicitly specify frameworks using the environment variable:

.. code-block:: bash

NVTE_FRAMEWORK=pytorch,jax pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
NVTE_FRAMEWORK=pytorch,jax pip install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@stable

conda Installation
^^^^^^^^^^^^^^^^^^
Expand Down
241 changes: 241 additions & 0 deletions benchmarks/linear/benchmark_grouped_linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import argparse
import torch
import torch.utils.benchmark as benchmark
import pandas as pd
import pathlib

from transformer_engine.pytorch.module import GroupedLinear
from transformer_engine.common.recipe import Float8BlockScaling
from transformer_engine.pytorch.fp8 import fp8_autocast
from contextlib import nullcontext

RECIPES = {
"bf16": None,
"fp8_sub_channel": Float8BlockScaling(),
}


def run_linear_multiple_steps(layer, x, m_splits, mode, gradient, run_num_steps=1, recipe=None):
assert mode in ["fwd_only", "fwd_bwd"]
fp8_context = (
fp8_autocast(enabled=True, fp8_recipe=recipe) if recipe is not None else nullcontext()
)
# print(f"fp8_context: {fp8_context} and is it nullcontext? {isinstance(fp8_context, nullcontext)}")

if mode == "fwd_only":
with torch.no_grad(), fp8_context:
for i in range(run_num_steps):
y_q = layer.forward(
x,
m_splits,
is_first_microbatch=(i == 0),
)
return y_q
else:
# reset gradients
layer.zero_grad()
x.grad = None

with fp8_context:
for i in range(run_num_steps):
label = f"step_{i}"
torch.cuda.nvtx.range_push(label)
y_q = layer.forward(
x,
m_splits,
is_first_microbatch=(i == 0),
)
y_q.backward(gradient)
torch.cuda.nvtx.range_pop()

grads_q = []
grads_q.append(x.grad)
# remaining derivatives are in respect to model parameters
for p in layer.parameters():
if p.requires_grad:
grads_q.append(p.grad)

return y_q, grads_q


def benchmark_linear(
x,
ws,
m_splits,
bias,
recipe_name,
mode,
num_gemms=4,
):
params_dtype = torch.bfloat16
recipe = RECIPES[recipe_name]

in_features = x.shape[1]
out_features = ws[0].shape[0]
gradient = torch.ones((x.shape[0], out_features), dtype=torch.bfloat16, device=x.device)

layer = GroupedLinear(
num_gemms,
in_features,
out_features,
bias=bias is not None,
params_dtype=params_dtype,
)

layer = layer.to("cuda")
with torch.no_grad():
for i in range(num_gemms):
weight_i = getattr(layer, f"weight{i}")
weight_i.copy_(ws[i])
if bias is not None:
bias_i = getattr(layer, f"bias{i}")
bias_i.copy_(bias)

num_microbatches = 32

label = f"{recipe_name}_{'grouped'}"
torch.cuda.nvtx.range_push(label)
timing = benchmark.Timer(
stmt=(
"run_linear_multiple_steps(layer, x, m_splits, mode, gradient, num_microbatches,"
" recipe)"
),
globals={
"run_linear_multiple_steps": run_linear_multiple_steps,
"layer": layer,
"x": x,
"m_splits": m_splits,
"mode": mode,
"gradient": gradient,
"num_microbatches": num_microbatches,
"recipe": recipe,
},
num_threads=1,
).blocked_autorange(min_run_time=5)
print(f"{recipe_name}: {timing} \n")
timing_ms = timing.median * 1000 / num_microbatches

return timing_ms


def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4):
data = []
assert not use_bias, "Bias is not supported for GroupedLinear benchmark"

print(f"========== Benchmarking {recipe_name} ==========")
for m, k, n in mkns:
device = "cuda"
x = torch.randn((m, k), dtype=torch.bfloat16, device=device, requires_grad=True)
ws = [torch.randn((n, k), dtype=torch.bfloat16, device=device) for _ in range(num_gemms)]
assert m % num_gemms == 0
m_splits = [m // num_gemms] * num_gemms
# Bias is not supported for GroupedLinear benchmark
bias = None

# Run the benchmark
print(f"fwd_m={m}, fwd_k={k}, fwd_n={n}")

grouped_fwd_bwd_timing_ms = benchmark_linear(
x,
ws,
m_splits,
bias,
recipe_name,
mode="fwd_bwd",
num_gemms=num_gemms,
)

# Append the results
data.append(
[
m,
k,
n,
recipe_name,
num_gemms,
grouped_fwd_bwd_timing_ms,
]
)

df = pd.DataFrame(
data=data,
columns=[
"m",
"k",
"n",
"recipe",
"num_gemms",
"grouped_fwd_bwd_time_ms",
],
)

print(df, "\n")
return df


if __name__ == "__main__":

parser = argparse.ArgumentParser()
parser.add_argument("--profile", action="store_true", help="Enable profiling mode")
parser.add_argument(
"--output_dir",
type=str,
default="benchmark_output/",
help="output path for report",
)
args = parser.parse_args()

use_bias = False
# Set the MKN values to benchmark
mkns = []
for m in [1024]:
# for m in [4096, 8192, 16384]:
# for n in [1024, 2048, 4096, 8192, 16384]:
for n in [3072]:
for k in [4096]:
mkns.append((m, k, n))

# recipe_list = [
# "bf16", "fp8_sub_channel",
# ]
recipe_list = [
"fp8_sub_channel",
]

# num_gemms_list = [16, 32]
num_gemms_list = [4]

if args.profile:
# nsys profile --output=./benchmarks/linear/mkn_4096_4096_4096_numgemm_1_bf16 --trace=cuda,nvtx,cudnn,cublas python benchmarks/linear/benchmark_grouped_linear.py --profile
# nsys profile --output=./benchmarks/linear/mkn_8192_8192_8192_numgemm_32_bf16 --trace=cuda,nvtx,cudnn,cublas python benchmarks/linear/benchmark_grouped_linear.py --profile
# nsys profile --output=./benchmarks/linear/mkn_4096_4096_4096_numgemm_8_fp8_sub_channel --trace=cuda,nvtx,cudnn,cublas python benchmarks/linear/benchmark_grouped_linear.py --profile
# nsys profile --output=./benchmarks/linear/mkn_8192_8192_8192_numgemm_2_fp8_sub_channel --trace=cuda,nvtx,cudnn,cublas python benchmarks/linear/benchmark_grouped_linear.py --profile
mkns = [(4096, 4096, 4096)]
recipe_list = ["fp8_sub_channel"]
# recipe_list = ["bf16"]
num_gemms_list = [8]
torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()

# Initialize a dataframe to store the results
df_linears = pd.DataFrame()

# Run the fp8 benchmarks
for num_gemms in num_gemms_list:
print(f"========== Benchmarking with num_gemms={num_gemms} ==========")
for recipe_name in recipe_list:
df = run_benchmark_linear(
mkns,
recipe_name,
use_bias,
num_gemms=num_gemms,
)
df_linears = pd.concat([df_linears, df])

print(df_linears)

if args.profile:
torch.autograd.profiler.emit_nvtx().__exit__(None, None, None)
2 changes: 1 addition & 1 deletion build_tools/VERSION.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.5.0.dev0
2.5.0
26 changes: 12 additions & 14 deletions build_tools/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

"""JAX related extensions."""
import os
import shutil
from pathlib import Path

import setuptools
Expand All @@ -13,6 +12,16 @@
from typing import List


def install_requirements() -> List[str]:
"""Install dependencies for TE/JAX extensions."""
return ["jax", "flax>=0.7.1"]


def test_requirements() -> List[str]:
"""Test dependencies for TE/JAX extensions."""
return ["numpy"]


def xla_path() -> str:
"""XLA root path lookup.
Throws FileNotFoundError if XLA source is not found."""
Expand Down Expand Up @@ -66,20 +75,9 @@ def setup_jax_extension(
# Define TE/JAX as a Pybind11Extension
from pybind11.setup_helpers import Pybind11Extension

class Pybind11CPPExtension(Pybind11Extension):
"""Modified Pybind11Extension to allow custom CXX flags."""

def _add_cflags(self, flags: List[str]) -> None:
if isinstance(self.extra_compile_args, dict):
cxx_flags = self.extra_compile_args.pop("cxx", [])
cxx_flags += flags
self.extra_compile_args["cxx"] = cxx_flags
else:
self.extra_compile_args[:0] = flags

return Pybind11CPPExtension(
return Pybind11Extension(
"transformer_engine_jax",
sources=[str(path) for path in sources],
include_dirs=[str(path) for path in include_dirs],
extra_compile_args={"cxx": cxx_flags},
extra_compile_args=cxx_flags,
)
16 changes: 16 additions & 0 deletions build_tools/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,22 @@
import setuptools

from .utils import all_files_in_dir, cuda_version, get_cuda_include_dirs, debug_build_enabled
from typing import List


def install_requirements() -> List[str]:
"""Install dependencies for TE/JAX extensions."""
reqs = ["torch>=2.1", "einops"]
reqs.append(
"nvdlfw-inspect @"
" git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect"
)
return reqs


def test_requirements() -> List[str]:
"""Test dependencies for TE/JAX extensions."""
return ["numpy", "torchvision", "transformers"]


def setup_pytorch_extension(
Expand Down
7 changes: 0 additions & 7 deletions build_tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,10 +354,3 @@ def copy_common_headers(
new_path = dst_dir / path.relative_to(src_dir)
new_path.parent.mkdir(exist_ok=True, parents=True)
shutil.copy(path, new_path)


def install_and_import(package):
"""Install a package via pip (if not already installed) and import into globals."""
main_package = package.split("[")[0]
subprocess.check_call([sys.executable, "-m", "pip", "install", package])
globals()[main_package] = importlib.import_module(main_package)
Loading
Loading